Python Forum

How do I split a dataset into test/train/validation according to a particular group?
I have written the following code to split my dataset into training, validation and test sets. The dataset consists of multiple audio files from 39 participants. To prevent data leakage, I need to ensure that all of a participant's audio files end up in only one of the train/validation/test sets (i.e. a participant's files are not divided when the dataset is split). I also want the ratio of the participants' diagnoses to be roughly the same across the three sets. I know it's possible to use GroupShuffleSplit, however I can't find any examples that don't use dataframes. How do I modify my code to do this?

import numpy as np

# create storage for train, validation, test sets and their indices
train_set,valid_set,test_set = [],[],[]
X_train,X_valid,X_test = [],[],[]
y_train,y_valid,y_test = [],[],[]

# convert waveforms to array for processing
waveforms = np.array(waveforms)

# process each diagnosis separately to make sure we build balanced train/valid/test sets
for diagnosis_num in range(len(diagnosis_dict)):
        
    # find all indices of a single unique diagnosis
    diagnosis_indices = [index for index, diagnosis in enumerate(diagnoses) if diagnosis==diagnosis_num]
    print(diagnosis_indices)

    # seed for reproducibility 
    np.random.seed(69)
    # shuffle indices
    diagnosis_indices = np.random.permutation(diagnosis_indices)

    # store dim (length) of the diagnosis list to make indices
    dim = len(diagnosis_indices)

    # store indices of training, validation and test sets in 80/10/10 proportion
    # train set is first 80%
    train_indices = diagnosis_indices[:int(0.8*dim)]
    # validation set is next 10% (between 80% and 90%)
    valid_indices = diagnosis_indices[int(0.8*dim):int(0.9*dim)]
    # test set is last 10% (between 90% - end/100%)
    test_indices = diagnosis_indices[int(0.9*dim):]

    # create train waveforms/labels sets
    X_train.append(waveforms[train_indices,:])
    y_train.append(np.array([diagnosis_num]*len(train_indices),dtype=np.int32))
    # create validation waveforms/labels sets
    X_valid.append(waveforms[valid_indices,:])
    y_valid.append(np.array([diagnosis_num]*len(valid_indices),dtype=np.int32))
    # create test waveforms/labels sets
    X_test.append(waveforms[test_indices,:])
    y_test.append(np.array([diagnosis_num]*len(test_indices),dtype=np.int32))

    # store indices for each diagnosis set to verify uniqueness between sets
    train_set.append(train_indices)
    valid_set.append(valid_indices)
    test_set.append(test_indices)

# concatenate, in order, all waveforms back into one array 
X_train = np.concatenate(X_train,axis=0)
X_valid = np.concatenate(X_valid,axis=0)
X_test = np.concatenate(X_test,axis=0)

# concatenate, in order, all diagnoses back into one array 
y_train = np.concatenate(y_train,axis=0)
y_valid = np.concatenate(y_valid,axis=0)
y_test = np.concatenate(y_test,axis=0)

# combine and store indices for all diagnoses train, validation, test sets to verify uniqueness of sets
train_set = np.concatenate(train_set,axis=0)
valid_set = np.concatenate(valid_set,axis=0)
test_set = np.concatenate(test_set,axis=0)

# check shape of each set
print(f'Training waveforms:{X_train.shape}, y_train:{y_train.shape}')
print(f'Validation waveforms:{X_valid.shape}, y_valid:{y_valid.shape}')
print(f'Test waveforms:{X_test.shape}, y_test:{y_test.shape}')

# make sure train, validation, test sets have no overlap/are unique
# get all unique indices across all sets and how many times each index appears (count)
uniques, count = np.unique(np.concatenate([train_set,test_set,valid_set],axis=0), return_counts=True)

# if each index appears exactly once and every sample is accounted for, the sets do not overlap
if sum(count==1) == len(diagnoses):
    print(f'\nSets are unique: {sum(count==1)} samples out of {len(diagnoses)} are unique')
else:
    print(f'\nSets are NOT unique: {sum(count==1)} samples out of {len(diagnoses)} are unique')    

Looks like you're doing that shuffle/split in the block that runs from np.random.permutation() down to the train/valid/test index slicing, most of which doesn't appear to be dataframe-specific. The actual shuffle (np.random.permutation) could be replaced with something like random.shuffle(), assuming you are starting with a list.
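
If you do want the group-aware split you asked about, note that scikit-learn's GroupShuffleSplit works on plain NumPy arrays, not just dataframes: you pass it a groups array (here, one participant ID per sample) and it returns index arrays. Below is a minimal sketch, assuming you can build a participants array from your file metadata alongside waveforms and diagnoses (the participants array is hypothetical, not something in your posted code). One split carves off 80% of participants for training, a second split divides the remaining participants evenly into validation and test.

# minimal sketch of a group-aware 80/10/10 split on plain NumPy arrays
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

diagnoses = np.asarray(diagnoses)        # labels, one per sample
# participants: hypothetical array of participant IDs, one per sample,
# built from your audio-file metadata
participants = np.asarray(participants)

# first split: 80% of participants go to train, 20% to a temporary "rest" pool
gss = GroupShuffleSplit(n_splits=1, train_size=0.8, random_state=69)
train_idx, rest_idx = next(gss.split(waveforms, diagnoses, groups=participants))

# second split: divide the remaining participants 50/50 into validation and test
gss2 = GroupShuffleSplit(n_splits=1, train_size=0.5, random_state=69)
valid_rel, test_rel = next(gss2.split(waveforms[rest_idx], diagnoses[rest_idx],
                                      groups=participants[rest_idx]))
valid_idx, test_idx = rest_idx[valid_rel], rest_idx[test_rel]

X_train, y_train = waveforms[train_idx], diagnoses[train_idx]
X_valid, y_valid = waveforms[valid_idx], diagnoses[valid_idx]
X_test,  y_test  = waveforms[test_idx],  diagnoses[test_idx]

One caveat: GroupShuffleSplit keeps each participant's files together, but it does not guarantee the diagnosis ratio is preserved in each split. If that balance matters as much as the grouping, StratifiedGroupKFold (available in scikit-learn 1.0+) tries to satisfy both constraints at once and may be worth a look.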