Devashish Prasad



Handling class-imbalanced training data

by Devashish Prasad
September 8, 2021



First, I prepare a common validation set from my dataset so that I can try out various methods and compare their performance on the same data. To start with, I use the following methods to handle imbalanced datasets:
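The post does not say how the common validation set is built. One reasonable choice, assumed here, is a stratified split: each class keeps roughly the same share in the validation set as in the full dataset, so rare classes are not missing from it by chance. The sketch below uses only NumPy. The helper name stratified_split is hypothetical. In scikit-learn, train_test_split(X, y, test_size=0.2, stratify=y) gives a similar split in one call.

```python
import numpy as np

def stratified_split(y, val_fraction=0.2, seed=0):
    """Return (train_idx, val_idx) index arrays such that each class keeps
    roughly the same share in the validation set as in the full dataset."""
    rng = np.random.default_rng(seed)
    y = np.asarray(y)
    train_idx, val_idx = [], []
    for cls in np.unique(y):
        idx = np.flatnonzero(y == cls)
        rng.shuffle(idx)
        # Give every class with 2+ samples at least one validation example,
        # so even the rarest classes are represented when comparing methods.
        n_val = max(1, round(val_fraction * len(idx))) if len(idx) > 1 else 0
        val_idx.extend(idx[:n_val])
        train_idx.extend(idx[n_val:])
    return np.array(sorted(train_idx), dtype=int), np.array(sorted(val_idx), dtype=int)

# Hypothetical labels: 90 / 8 / 2 instances of classes 0 / 1 / 2.
y = np.array([0] * 90 + [1] * 8 + [2] * 2)
train_idx, val_idx = stratified_split(y, val_fraction=0.2)
print(np.bincount(y[val_idx], minlength=3))  # validation counts per class: 18, 2, 1
```

Every method below can then be trained on train_idx and scored on the same val_idx. This keeps the comparison fair.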

1. Using weighted cross-entropy loss

We can assign per-class weights to the cross-entropy loss so that it penalizes mistakes on smaller classes more heavily and mistakes on larger classes less. Many frameworks make this very easy.
In scikit-learn, look for the class_weight parameter, for example in the random forest classifier.
Here is how we can use this in PyTorch
Here is how we can use this in Keras
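To show what the weighting actually computes, here is a framework-free NumPy sketch. It is not the author's code. The class weights use the same formula as scikit-learn's class_weight="balanced": n_samples / (n_classes * count of each class). The loss is the weighted mean that torch.nn.CrossEntropyLoss computes when it is given a weight tensor with the default "mean" reduction. The function names are hypothetical.

```python
import numpy as np

def balanced_class_weights(y, n_classes):
    # Same formula as scikit-learn's class_weight="balanced":
    # n_samples / (n_classes * count_of_each_class).
    counts = np.bincount(y, minlength=n_classes)
    # Assumption: a class absent from y is treated as if it had one sample,
    # which avoids division by zero.
    return len(y) / (n_classes * np.maximum(counts, 1))

def log_softmax(logits):
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def weighted_cross_entropy(logits, y, class_weights):
    nll = -log_softmax(logits)[np.arange(len(y)), y]  # per-sample loss
    w = class_weights[y]                              # weight of each sample's true class
    # Weighted mean: sum(w_i * loss_i) / sum(w_i).
    return float((w * nll).sum() / w.sum())

y = np.array([0, 0, 0, 0, 0, 0, 1, 1, 2, 2])
weights = balanced_class_weights(y, n_classes=3)  # 10/18, 10/6, 10/6: class 0 gets the smallest weight
rng = np.random.default_rng(0)
logits = rng.normal(size=(len(y), 3))             # stand-in for model outputs
print(weights, weighted_cross_entropy(logits, y, weights))
```

In PyTorch, the same weights would be passed as torch.nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float32)). In Keras, a dict such as class_weight={0: w0, 1: w1, 2: w2} can be passed to Model.fit, which applies the weights per sample. In scikit-learn, class_weight="balanced" computes these weights internally.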

2. Using focal loss

Focal loss was originally proposed for object detection, but it can also be used for other classification tasks. Read more about it here
Here is how you can use this in PyTorch for multi-class classification
Here is how you can use this in Keras
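Here is a minimal NumPy sketch of multi-class focal loss. It assumes the standard form -alpha_t * (1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class, averaged over samples. The (1 - p_t)^gamma factor shrinks the loss of easy examples the model already gets right. As a result, training focuses on hard examples, which in imbalanced data often come from the small classes. The function name and the default gamma=2.0 are assumptions for illustration, not taken from the post.

```python
import numpy as np

def focal_loss(logits, y, gamma=2.0, alpha=None):
    """Mean multi-class focal loss: -alpha_t * (1 - p_t)**gamma * log(p_t),
    where p_t is the predicted probability of the true class."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    log_pt = log_probs[np.arange(len(y)), y]  # log-probability of the true class
    pt = np.exp(log_pt)
    # (1 - p_t)**gamma is close to 0 for confident, correct predictions,
    # so easy examples contribute little to the loss.
    loss = -((1.0 - pt) ** gamma) * log_pt
    if alpha is not None:
        loss = np.asarray(alpha)[y] * loss    # optional per-class weights
    return float(loss.mean())

y = np.array([0, 0, 0, 1])
logits = np.array([[4.0, 0.0], [3.0, 0.0], [2.5, 0.5], [0.2, 0.0]])
ce = focal_loss(logits, y, gamma=0.0)  # gamma=0 and no alpha: plain cross-entropy
fl = focal_loss(logits, y, gamma=2.0)
print(ce, fl)  # fl < ce: the easy class-0 rows are strongly down-weighted
```

With gamma=0 and no alpha the function reduces to ordinary cross-entropy. This makes it easy to compare both losses on the same validation set.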

3. Oversampling and undersampling

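Oversampling adds copies (or synthetic examples) of minority-class instances, while undersampling drops some majority-class instances. There are many techniques in this family; check out imblearn (imbalanced-learn), a library dedicated to handling imbalanced datasets.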
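The two simplest variants are random oversampling and random undersampling. Random oversampling duplicates minority rows until every class matches the largest class. Random undersampling keeps a random subset of each class sized to the smallest class. Here is a NumPy sketch of both. The function names are hypothetical. imblearn provides library versions such as RandomOverSampler and RandomUnderSampler, used via fit_resample(X, y).

```python
import numpy as np

def random_oversample(X, y, seed=0):
    """Duplicate rows of smaller classes (sampling with replacement) until
    every class has as many rows as the largest class."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    target = counts.max()
    idx = []
    for cls, n in zip(classes, counts):
        cls_idx = np.flatnonzero(y == cls)
        extra = rng.choice(cls_idx, size=target - n, replace=True)
        idx.append(np.concatenate([cls_idx, extra]))
    idx = np.concatenate(idx)
    return X[idx], y[idx]  # rows are grouped by class; shuffle before mini-batch training

def random_undersample(X, y, seed=0):
    """Keep a random subset of each class, sized to the smallest class."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    target = counts.min()
    idx = np.concatenate([
        rng.choice(np.flatnonzero(y == cls), size=target, replace=False)
        for cls in classes
    ])
    return X[idx], y[idx]

X = np.arange(20).reshape(10, 2)  # toy features: row i is [2i, 2i + 1]
y = np.array([0] * 7 + [1] * 3)
X_over, y_over = random_oversample(X, y)
X_under, y_under = random_undersample(X, y)
print(np.bincount(y_over), np.bincount(y_under))  # class counts: 7/7 and 3/3
```

Resample only the training split. The common validation set should keep its natural class distribution so that the methods are compared on realistic data.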

4. Create a separate model for small classes

If you have some classes with very few instances, you can consider creating a separate classifier just for these small classes (called small_classifier, for example). Group these small classes together into a single class (called small_class, for example), so that your main classifier classifies small_class alongside all the other, larger classes in the dataset. Whenever the main classifier predicts small_class for an instance, it passes that instance to small_classifier, which predicts the instance's actual class. This technique can improve accuracy because the main classifier no longer has to deal with very small classes, while small_classifier focuses only on them.
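Here is a rough sketch of the routing logic under a few assumptions. The merged small_class is represented by a hypothetical label SMALL_CLASS = -1. Each model is represented only by a predict function, so any trained model can be plugged in, for example main_model.predict and small_model.predict from scikit-learn. The toy predict functions at the bottom are stand-ins for trained models, written only to make the example runnable.

```python
import numpy as np

SMALL_CLASS = -1  # hypothetical label for the merged small_class group

def merge_small_classes(y, small_classes):
    """Labels for training the main classifier: every small class becomes SMALL_CLASS."""
    y = np.asarray(y)
    return np.where(np.isin(y, list(small_classes)), SMALL_CLASS, y)

def two_stage_predict(X, main_predict, small_predict):
    """Predict with the main classifier; rows it assigns to SMALL_CLASS are
    passed to small_classifier, which predicts their actual class."""
    preds = np.asarray(main_predict(X)).copy()
    mask = preds == SMALL_CLASS
    if mask.any():
        preds[mask] = small_predict(X[mask])
    return preds

# Training labels (classes 2 and 3 are the hypothetical small classes).
y = np.array([0, 0, 1, 1, 2, 3])
main_labels = merge_small_classes(y, {2, 3})  # [0, 0, 1, 1, -1, -1]
# small_classifier would be trained only on rows where np.isin(y, [2, 3]).

def main_predict(X):
    # Toy stand-in: flags rows with a negative first feature as small_class.
    return np.where(X[:, 0] < 0, SMALL_CLASS, 0)

def small_predict(X):
    # Toy stand-in: splits small_class rows by the sign of the second feature.
    return np.where(X[:, 1] < 0, 2, 3)

X_new = np.array([[1.0, 1.0], [-1.0, -2.0], [-1.0, 2.0]])
print(two_stage_predict(X_new, main_predict, small_predict))  # [0 2 3]
```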


Thank you for reading! Please share your thoughts in the comments below.