When you work on real-world machine learning problems, it is very common to have an imbalanced dataset.
In this article we compare different methods for dealing with it and show the results obtained with Ai4Prod. We use images taken from CIFAR10.
Code https://github.com/ai4prod/ai4prod_python
Dataset Train/val https://drive.google.com/file/d/1_aoSVW6OyPp1fT9VVakGX8cQNSRxqdEU/view?usp=sharing
Dataset Test https://drive.google.com/file/d/1bM1pVVT1UrFY2OnUhkb-zoxqE8uEfBGB/view?usp=sharing
Introduction
If you have never heard of imbalanced datasets, here is a quick introduction. As the name suggests, a dataset is imbalanced when, in a classification problem, the classes have very different numbers of examples.
For example, the dataset used in this experiment has 3 classes, with 300, 1200 and 900 examples respectively.
As you can see, the second class is 4 times bigger than the first, so this can be considered an imbalanced dataset. There is no specific rule for deciding when a dataset is imbalanced. As a rule of thumb, if your classes do not all have more or less the same number of examples, the dataset is imbalanced.
The Problem
Why is it important to know whether you have an imbalanced dataset?
Let's do some simple math. Suppose you start training with the imbalanced dataset described above, so you have:
- Class 1: 300 examples
- Class 2: 1200 examples
- Class 3: 900 examples
The output below reports, for each class, the number of examples fed to the neural network after 5 epochs. It was obtained using the PyTorch ImageFolder and DataLoader.
[tensor(1500), tensor(6000), tensor(4500)]
As you can see, after 5 epochs the total number of examples provided to the neural network differs a lot from class to class.
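The totals above follow directly from the class sizes. Here is a minimal, framework-free sketch of the same count, assuming a loader that visits every example exactly once per epoch; the `count_seen` helper and the integer class labels are illustrative names, not part of Ai4Prod.

```python
from collections import Counter

# Class sizes from the article: class index -> number of images.
class_sizes = {0: 300, 1: 1200, 2: 900}
labels = [c for c, n in class_sizes.items() for _ in range(n)]

def count_seen(labels, epochs):
    """Count how many examples of each class are seen over several epochs."""
    seen = Counter()
    for _ in range(epochs):
        # Without a sampler, one epoch visits every example exactly once.
        # Shuffling changes the order but not the per-class counts.
        seen.update(labels)
    return [seen[c] for c in sorted(seen)]

counts = count_seen(labels, epochs=5)
print(counts)  # [1500, 6000, 4500]

# Share of all examples seen that belongs to each class, in percent.
total = sum(counts)
print([round(100 * c / total, 1) for c in counts])  # [12.5, 50.0, 37.5]
```

Half of everything the network sees comes from the second class, and only one example in eight comes from the first.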
To see why this is a problem, think about how neural networks are optimized. Usually the objective is to reduce a function called the loss by updating the parameters of the network. The value of the loss is computed from the data fed to the network, so if you feed many more examples of one class than of the others, you are in fact optimizing mostly for the majority class.
So, to avoid this behavior, you need to feed the neural network more or less the same number of examples for each class during training.
Experiment with Ai4Prod
Here we describe the experiment we ran and report our results on the dataset that you can download above.
To address this problem in the Ai4Prod_python repository we adopted two approaches:
- Oversampling
- Replacing Cross Entropy Loss with Focal Loss
N.B. These two approaches are alternatives: if you apply one, you do not need to apply the other.
Oversampling
Oversampling is a technique that acts on the DataLoader. In practice, if you have an imbalanced dataset, WeightedRandomSampler balances how many examples are drawn from each class, based on the number of examples per class.
sampler = WeightedRandomSampler(self.sampleWeight, num_samples=len(self.sampleWeight), replacement=True)
Now, if you count again how many examples are provided to the model over 5 epochs, you get:
[tensor(4043), tensor(3915), tensor(4042)]
You can see that each class now contributes more or less the same number of samples.
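The sampler needs one weight per sample. The article does not show how self.sampleWeight is built, so the sketch below assumes a common choice: the inverse of the class size. To stay framework-free it uses `random.choices` from the standard library as a stand-in for the sampler's weighted draw with replacement; the variable names are illustrative.

```python
import random
from collections import Counter

# Labels with the class sizes used in the article: 300, 1200 and 900.
labels = [0] * 300 + [1] * 1200 + [2] * 900

# Assumption: each sample is weighted by the inverse of its class size,
# so every class carries the same total weight (1.0 each).
class_counts = Counter(labels)
sample_weights = [1.0 / class_counts[label] for label in labels]

# Stand-in for the sampler: draw len(labels) indices with replacement,
# with probability proportional to the weights. In PyTorch the same list
# of weights would be passed to WeightedRandomSampler instead.
rng = random.Random(0)
indices = rng.choices(range(len(labels)), weights=sample_weights, k=len(labels))

# Count how many drawn samples belong to each class in one epoch.
drawn_per_class = Counter(labels[i] for i in indices)
print(sorted(drawn_per_class.items()))
```

Because sampling is done with replacement, images of the smallest class are repeated several times per epoch, while some images of the larger classes are skipped.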
In Ai4Prod_python it is very simple to use WeightedRandomSampler: you only need to use the Lightning datamodule ImageFolderForClassImbalancedDataset, and you no longer have to worry about class imbalance.
Example
Focal Loss
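The Focal Loss is a modified Cross Entropy Loss that down-weights the examples the model already classifies with confidence, so that hard, misclassified examples weigh more in the total loss. On an imbalanced dataset these hard examples tend to come from the classes with fewer samples, so in practice those classes are penalized more than the class with the most samples. I will not go into the details. You can find a good explanation in the Focal Loss reference at the end of the article.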
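To give an idea of the effect, here is a minimal sketch of the focal loss for a single example, written in plain Python rather than as the Ai4Prod implementation. It assumes the usual form, the cross entropy term scaled by (1 - p) raised to gamma, where p is the probability the model assigns to the true class. The values gamma=2 and alpha=1 are illustrative assumptions; the values used by Ai4Prod are not stated in this article.

```python
import math

def cross_entropy(p_true):
    """Cross entropy for one example, given the probability of its true class."""
    return -math.log(p_true)

def focal_loss(p_true, gamma=2.0, alpha=1.0):
    """Focal loss for one example, given the probability of its true class."""
    # -log(p_true) is the usual cross entropy term; (1 - p_true) ** gamma
    # shrinks it when the example is already classified with confidence.
    return -alpha * (1.0 - p_true) ** gamma * math.log(p_true)

# Compare an easy example (p=0.9), an uncertain one and a hard one (p=0.1).
for p in (0.9, 0.5, 0.1):
    ce = cross_entropy(p)
    fl = focal_loss(p)
    print(f"p_true={p}: cross entropy={ce:.4f}, focal={fl:.4f}, ratio={fl / ce:.2f}")
```

With gamma=0 the scaling factor is 1 and the focal loss reduces to the plain cross entropy. Larger values of gamma suppress easy examples more strongly.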
With Ai4Prod_python you can use the Focal Loss by simply adding imbalanced=True to the Lightning module.
model = ImagenetTransferLearning(num_classes=3, pytorch_model=model, from_scratch=True, imbalanced=True)
With imbalanced=True the Focal Loss is used automatically.
Results
In this section we show how the accuracy of each class changes, using confusion matrices.
1. Without considering the imbalanced dataset
2. Oversampling
3. Focal Loss
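For readers who want to get per-class accuracy out of a confusion matrix themselves, here is a minimal sketch. It assumes rows are true classes and columns are predicted classes. The numbers are made up for illustration and are not the results of this experiment; `per_class_accuracy` is an illustrative helper name.

```python
def per_class_accuracy(confusion):
    """Fraction of each true class that was predicted correctly (row-wise)."""
    accuracies = []
    for i, row in enumerate(confusion):
        total = sum(row)
        # The diagonal entry holds the correctly classified examples.
        accuracies.append(row[i] / total if total else 0.0)
    return accuracies

# Made-up counts for illustration only, NOT the results of this experiment.
# Rows are true classes, columns are predicted classes.
confusion = [
    [60, 25, 15],
    [5, 90, 5],
    [10, 10, 80],
]
print(per_class_accuracy(confusion))  # [0.6, 0.9, 0.8]
```

Looking at each row separately is what makes a weak minority class visible, since overall accuracy can stay high while one class is mostly misclassified.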
Conclusion
In this experiment Oversampling gives the best result on our test dataset, at least for the minority class, which in our case is Airplane.
Of course, if you are still not satisfied with the results for one of your classes after applying an imbalance technique, the only way forward is to acquire more data for that specific class.
There are some ways to artificially augment images for specific classes only, but we will describe them in another article.
Reference
Focal Loss https://towardsdatascience.com/class-imbalance-d90f985c681e