How to correctly split your dataset?

ai4prod
4 min read · Jun 16, 2022

An overview of different techniques you should know to improve your AI models

Why

Splitting your dataset into training and validation sets helps you understand whether your model has high variance (overfitting) or high bias (underfitting).

This is extremely important for understanding how well your model will generalize to new, unseen data. If your model is overfitting, it probably cannot generalize well to unseen data, so it will not be able to make good predictions.

Having a proper validation strategy is the first step toward making good predictions and, in turn, creating business value with your AI models.

Simple Train Val split

With a simple train/val split you divide your dataset into a training set and a validation set.

Usually you put 80% of the data into training and 20% into validation. Some libraries, such as scikit-learn, perform this split using random sampling.

simple train val split

Things to keep in mind

  1. You need to fix the random seed, otherwise you cannot compare different training runs, because each one uses a different dataset split. Fixing the seed makes the random sampling produce the same split every time.
  2. If you have an imbalanced dataset, a plain random split does not guarantee that the class ratio is maintained in both sets. See Stratified-kFold.
  3. If you have a small dataset, there is no guarantee that your validation split is representative of your training split.
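The points above can be sketched with scikit-learn's `train_test_split`. This is a minimal illustration on made-up toy data (the arrays `X` and `y` are assumptions, not from a real project). It shows an 80/20 split, a fixed seed via `random_state`, and the optional `stratify` argument, which keeps the class ratio in both sets.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data (assumption): 100 samples, 2 features, 90 zeros and 10 ones.
X = np.arange(200).reshape(100, 2)
y = np.array([0] * 90 + [1] * 10)

# 80/20 split; random_state fixes the seed so the split is reproducible.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Same seed -> same split, so different training runs stay comparable.
X_train2, X_val2, _, _ = train_test_split(X, y, test_size=0.2, random_state=42)
same_split = np.array_equal(X_val, X_val2)

# stratify=y keeps the 90/10 class ratio in both training and validation.
_, _, y_train_s, y_val_s = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
```

Without `stratify`, the number of minority samples that land in the validation set depends on the seed.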

K-Fold Cross Validation

The idea is to split the dataset into k partitions (folds). In the image below the dataset is divided into 5 partitions.

Each time, you choose one partition as your validation dataset and use the other partitions as your training dataset. You train your model once for each choice of validation partition.

In the end you will have K different models.
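As a rough sketch of this loop, the example below uses scikit-learn's `KFold` with 5 partitions. The synthetic dataset from `make_classification` and the choice of `LogisticRegression` are assumptions made only for illustration; any estimator could take its place.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

# Synthetic data (assumption), used only to make the sketch runnable.
X, y = make_classification(n_samples=200, n_features=10, random_state=0)

kf = KFold(n_splits=5, shuffle=True, random_state=0)
models, scores = [], []
for train_idx, val_idx in kf.split(X):
    # Train on 4 partitions, validate on the remaining one.
    model = LogisticRegression(max_iter=1000)
    model.fit(X[train_idx], y[train_idx])
    models.append(model)
    scores.append(model.score(X[val_idx], y[val_idx]))

# One model per fold; the mean validation score summarizes them.
mean_score = float(np.mean(scores))
```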

K is usually set to one of 3, 5, 7, 10, or 20.

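  • A higher K (e.g. 20) is used if you want to estimate model performance with low bias, because each model trains on almost all of the data. The cost is more training runs.
  • A low K (e.g. 3 or 5) is used if you want to build a model for variable selection. Each model trains on only part of the data, leaving a larger part out, and the resulting estimate has lower variance.

k-fold cross validation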

Advantages:

  1. By averaging the validation scores of the K models you can estimate the performance of your models on unseen data drawn from the same distribution.
  2. It is a widely used method for getting good models for production.
  3. You can create a prediction for every sample in your dataset. These are called OOF (out-of-fold) predictions. You can use them with different ensembling techniques (blending, stacking).
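One possible way to obtain out-of-fold predictions is scikit-learn's `cross_val_predict`, sketched below. The synthetic data and the `LogisticRegression` estimator are assumptions for illustration. Each sample receives its prediction from the model that did not see it during training.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_predict

# Synthetic data (assumption), used only to make the sketch runnable.
X, y = make_classification(n_samples=200, n_features=10, random_state=0)
cv = KFold(n_splits=5, shuffle=True, random_state=0)

# Out-of-fold probability of class 1 for every sample in the dataset.
oof_proba = cross_val_predict(
    LogisticRegression(max_iter=1000), X, y, cv=cv, method="predict_proba"
)[:, 1]

# Accuracy computed only from predictions on held-out samples.
oof_accuracy = float(np.mean((oof_proba >= 0.5) == y))
```

The `oof_proba` array has one entry per sample, so it can serve as an input feature for a second-level model in stacking.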

Problems:

  1. If you have an imbalanced dataset, plain K-fold does not keep the class ratio in each fold. Use Stratified-kFold instead.
  2. If you retrain a model on the whole dataset, you cannot compare its performance with the models you trained using K-fold, because those models were trained on k-1 folds, not on the entire dataset, and no held-out data remains to validate the retrained model.

Stratified-kFold

It is used to preserve the ratio between the different classes in every fold.

In practice, suppose you have an imbalanced dataset where class1 has 10 examples and class2 has 100 examples. Stratified-kFold creates k folds where each fold has the same class ratio as the original dataset.

Apart from this, the idea is the same as K-fold cross validation.
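The 10 versus 100 example can be reproduced with scikit-learn's `StratifiedKFold`. In this sketch the labels 1 and 2 and the dummy feature matrix are assumptions, and k is set to 5. Every validation fold ends up with 2 samples of class1 and 20 of class2, the same 1:10 ratio as the full dataset.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# class1 has 10 examples, class2 has 100 (encoded as labels 1 and 2).
y = np.array([1] * 10 + [2] * 100)
X = np.zeros((len(y), 1))  # features do not matter for the split itself

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
fold_counts = []
for train_idx, val_idx in skf.split(X, y):
    n_class1 = int(np.sum(y[val_idx] == 1))
    n_class2 = int(np.sum(y[val_idx] == 2))
    fold_counts.append((n_class1, n_class2))

# fold_counts holds (class1, class2) counts for each validation fold.
```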

Stratified-Kfold

If you use plain K-fold cross validation on an imbalanced dataset, you can end up with this setup:

K fold cross validation with imbalanced classes

The initial ratio between classes is not preserved in each split (fold).
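To make this concrete, the sketch below runs plain `KFold` on the same kind of imbalanced labels. It deliberately uses a worst case, which is an assumption for illustration: the data is sorted by class and not shuffled, so all minority samples fall into the first fold.

```python
import numpy as np
from sklearn.model_selection import KFold

# 10 minority samples (label 1) followed by 100 majority samples (label 2).
y = np.array([1] * 10 + [2] * 100)
X = np.zeros((len(y), 1))  # features do not matter for the split itself

# Plain K-fold without shuffling cuts the data into consecutive blocks.
kf = KFold(n_splits=5, shuffle=False)
minority_per_fold = [
    int(np.sum(y[val_idx] == 1)) for _, val_idx in kf.split(X)
]

# Four of the five validation folds contain no minority sample at all.
```

With `shuffle=True` the picture is less extreme, but the minority count per fold still varies with the seed.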

If your dataset is very large, K-fold cross validation with random sampling could still preserve the class ratio approximately, because each fold is a large random sample.

Advantages

  1. Preserves the class ratio in every fold, even with a small dataset.

Bootstrap and Subsampling

Bootstrap and subsampling are similar to K-fold cross validation, but they do not use fixed folds.

In practice, you draw a random sample from your dataset for training and use the remaining data as validation.

You repeat this n times.

Bootstrap = sampling with replacement. Subsampling = sampling without replacement.
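A single round of each technique might look like the NumPy sketch below. The dataset size of 100 and the 80-sample subsample are assumptions chosen for illustration. In practice the round would be repeated n times with fresh random draws.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 100  # assumed dataset size
indices = np.arange(n_samples)

# Bootstrap: draw n_samples indices WITH replacement (duplicates allowed).
# Rows that were never drawn form the validation set.
boot_train = rng.choice(indices, size=n_samples, replace=True)
boot_val = np.setdiff1d(indices, boot_train)

# Subsampling: draw a subset WITHOUT replacement; the rest is validation.
sub_train = rng.choice(indices, size=80, replace=False)
sub_val = np.setdiff1d(indices, sub_train)
```

Unlike K-fold, the validation sets of different rounds can overlap, and some samples may never be used for validation.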

In machine learning, the usual golden rule is to use k-fold cross validation.

When to use

Bootstrap and subsampling should be used only if you have a large standard error on your evaluation metric. This can happen due to outliers in your dataset.
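One simple way to check the standard error of a metric is to compute it from the per-fold scores of a cross validation run, as sketched here. The synthetic data, the `LogisticRegression` estimator, and the use of accuracy are assumptions; what counts as a "large" standard error depends on your problem.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

# Synthetic data (assumption), used only to make the sketch runnable.
X, y = make_classification(n_samples=200, n_features=10, random_state=0)

# One accuracy score per fold.
scores = cross_val_score(LogisticRegression(max_iter=1000), X, y, cv=5)

# Standard error of the mean score across folds.
std_error = float(scores.std(ddof=1) / np.sqrt(len(scores)))
```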
