Python Forum
How do I split a dataset into test/train/validation according to a particular group?
#1
I have written the following code to split my dataset into a training, validation and test set. The dataset consists of multiple audio files from 39 participants. To prevent data leakage, I need to ensure that each participant's audio files all land in the same set (i.e. no participant is divided across train/validation/test when the dataset is split). I also want the ratio of the participants' diagnoses to be evenly split between the three sets. I know it's possible to use GroupShuffleSplit, however I can't find any examples that don't use DataFrames. How do I modify my code to do this?

import numpy as np

# create storage for train, validation, test sets and their indices
train_set, valid_set, test_set = [], [], []
X_train, X_valid, X_test = [], [], []
y_train, y_valid, y_test = [], [], []

# convert waveforms to array for processing
waveforms = np.array(waveforms)

# process each diagnosis separately to make sure we build balanced train/valid/test sets
for diagnosis_num in range(len(diagnosis_dict)):
        
    # find all indices of a single unique diagnosis
    diagnosis_indices = [index for index, diagnosis in enumerate(diagnoses) if diagnosis==diagnosis_num]
    print(diagnosis_indices)

    # seed for reproducibility 
    np.random.seed(69)
    # shuffle indices
    diagnosis_indices = np.random.permutation(diagnosis_indices)

    # store dim (length) of the diagnosis list to make indices
    dim = len(diagnosis_indices)

    # store indices of training, validation and test sets in 80/10/10 proportion
    # train set is first 80%
    train_indices = diagnosis_indices[:int(0.8*dim)]
    # validation set is next 10% (between 80% and 90%)
    valid_indices = diagnosis_indices[int(0.8*dim):int(0.9*dim)]
    # test set is last 10% (between 90% - end/100%)
    test_indices = diagnosis_indices[int(0.9*dim):]

    # create train waveforms/labels sets
    X_train.append(waveforms[train_indices,:])
    y_train.append(np.array([diagnosis_num]*len(train_indices),dtype=np.int32))
    # create validation waveforms/labels sets
    X_valid.append(waveforms[valid_indices,:])
    y_valid.append(np.array([diagnosis_num]*len(valid_indices),dtype=np.int32))
    # create test waveforms/labels sets
    X_test.append(waveforms[test_indices,:])
    y_test.append(np.array([diagnosis_num]*len(test_indices),dtype=np.int32))

    # store indices for each diagnosis set to verify uniqueness between sets
    train_set.append(train_indices)
    valid_set.append(valid_indices)
    test_set.append(test_indices)

# concatenate, in order, all waveforms back into one array 
X_train = np.concatenate(X_train,axis=0)
X_valid = np.concatenate(X_valid,axis=0)
X_test = np.concatenate(X_test,axis=0)

# concatenate, in order, all diagnoses back into one array 
y_train = np.concatenate(y_train,axis=0)
y_valid = np.concatenate(y_valid,axis=0)
y_test = np.concatenate(y_test,axis=0)

# combine and store indices for all diagnoses train, validation, test sets to verify uniqueness of sets
train_set = np.concatenate(train_set,axis=0)
valid_set = np.concatenate(valid_set,axis=0)
test_set = np.concatenate(test_set,axis=0)

# check shape of each set
print(f'Training waveforms:{X_train.shape}, y_train:{y_train.shape}')
print(f'Validation waveforms:{X_valid.shape}, y_valid:{y_valid.shape}')
print(f'Test waveforms:{X_test.shape}, y_test:{y_test.shape}')

# make sure train, validation, test sets have no overlap/are unique
# get all unique indices across all sets and how many times each index appears (count)
uniques, count = np.unique(np.concatenate([train_set,test_set,valid_set],axis=0), return_counts=True)

# if each index appears exactly once, and the number of unique indices equals the total number of samples, the sets are disjoint
if sum(count==1) == len(diagnoses):
    print(f'\nSets are unique: {sum(count==1)} samples out of {len(diagnoses)} are unique')
else:
    print(f'\nSets are NOT unique: {sum(count==1)} samples out of {len(diagnoses)} are unique')    
#2
Looks like you're doing that shuffle/split inside the loop over diagnoses, and most of it doesn't appear to be DataFrame-specific. The actual shuffle (the np.random.permutation() call) could be replaced with something like random.shuffle(), assuming you are starting with a list.
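
As for GroupShuffleSplit, it accepts plain NumPy arrays; nothing about it requires a DataFrame. Here's a minimal sketch under some assumptions: the participant_ids array is hypothetical (one entry per audio file identifying its participant; your loading code would have to produce it), and the stand-in data at the top exists only so the snippet runs on its own.

import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# stand-in data so the snippet is self-contained; swap in your real arrays
rng = np.random.default_rng(69)
waveforms = rng.normal(size=(390, 16000))        # 390 audio files
participant_ids = np.repeat(np.arange(39), 10)   # hypothetical: 10 files per participant
diagnoses = participant_ids % 2                  # toy binary diagnosis per file

# first split: 80% train vs. 20% holdout, keeping each participant's files together
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=69)
train_idx, holdout_idx = next(gss.split(waveforms, diagnoses, groups=participant_ids))

# second split: divide the 20% holdout evenly into validation and test, again by participant
gss2 = GroupShuffleSplit(n_splits=1, test_size=0.5, random_state=69)
v_rel, t_rel = next(gss2.split(waveforms[holdout_idx], diagnoses[holdout_idx],
                               groups=participant_ids[holdout_idx]))
valid_idx, test_idx = holdout_idx[v_rel], holdout_idx[t_rel]

X_train, y_train = waveforms[train_idx], diagnoses[train_idx]
X_valid, y_valid = waveforms[valid_idx], diagnoses[valid_idx]
X_test, y_test = waveforms[test_idx], diagnoses[test_idx]

One caveat: GroupShuffleSplit only keeps groups intact; it does not stratify, so the diagnosis ratio can drift between the sets. If you need both, StratifiedGroupKFold (scikit-learn 1.0+) combines grouping with stratification. A sketch of one way to get an approximate 80/10/10 split from it, using the same arrays as above:

from sklearn.model_selection import StratifiedGroupKFold

# 10 folds of roughly 10% each, stratified by diagnosis and grouped by participant;
# the 10 test folds partition the data, so use one as test, one as validation, rest as train
sgkf = StratifiedGroupKFold(n_splits=10, shuffle=True, random_state=69)
folds = [fold for _, fold in sgkf.split(waveforms, diagnoses, groups=participant_ids)]
test_idx, valid_idx = folds[0], folds[1]
train_idx = np.concatenate(folds[2:])

With whole participants per group, the proportions won't come out at exactly 80/10/10; both splitters keep groups intact at the cost of exact split sizes.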

