ML practitioners wrangle data to create the best training datasets possible for their models. This generally means that the data is unbiased, well-structured, and will help to train a model that gives good predictive performance in the real world.
Whether you collect data yourself from the real world, use an existing dataset, or use an existing trained model as part of transfer learning, the underlying data can contain inherent biases or imbalances that result in a suboptimal, unbalanced dataset.
In a nutshell, an unbalanced dataset contains lots of data for one or more specific classifications (called majority classes) and little data for others (minority classes). This can be caused by data collection issues or by the inherent nature of the data itself. For example, consider a dataset comprising images of valid versus fraudulent insurance claim forms for use in image classification. Since the vast majority of claims are likely valid, most of the images would represent valid claims (the majority class) while very few would represent fraudulent claims (the minority class).
Training with unbalanced data like this can hurt the trained model's predictive performance through poor generalization (i.e., an inability to accurately classify real-world data), overfitting, or underfitting. This happens because the algorithm becomes biased toward the majority classes and never really learns the patterns needed to distinguish them from the minority classes. And since we often judge the performance of ML models by their accuracy, this leads to the accuracy paradox: the model's accuracy reflects the underlying (unbalanced) class distribution rather than a learnt differentiation pattern, so the model appears better than it really is.
Before working with a given dataset, it's a good idea to first determine if it is indeed unbalanced. Tools like scikit-learn, for example, can help you gain that initial insight into your data. Once you know you have a data imbalance, there are a few approaches you can take to mitigate it.
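As a quick sketch of what that initial check might look like, the snippet below counts class labels and computes an imbalance ratio using only the Python standard library. The labels here are hypothetical, standing in for the valid-versus-fraudulent claims example above:

```python
from collections import Counter

# Hypothetical labels for a claims dataset: 0 = valid, 1 = fraudulent
labels = [0] * 950 + [1] * 50

counts = Counter(labels)
total = sum(counts.values())
for cls, n in sorted(counts.items()):
    print(f"class {cls}: {n} samples ({n / total:.1%})")

# A large ratio between the biggest and smallest class signals imbalance
imbalance_ratio = max(counts.values()) / min(counts.values())
print(f"imbalance ratio: {imbalance_ratio:.1f}")  # 19.0
```

An imbalance ratio of 19:1, as here, is a strong hint that training on the raw data will bias the model toward the majority class.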
ML practitioners can sometimes overcome unbalanced data by creating more balanced datasets. Here are a few of the more common approaches:
- Create Synthetic Data: simply put, this involves generating new data samples for underrepresented/minority classes, often basing the new samples on existing data.
- Data Augmentation: similar to the create-synthetic-data technique, data augmentation creates new data samples by copying and modifying existing samples from underrepresented/minority classes. This approach was used in our Fabric Stain Classification use case where we applied random rotations and vertical/horizontal flips to generate new images.
- Remove Samples from Overrepresented Classes: if you have a fairly large dataset, you might be able to simply remove some of the overrepresented samples to even out the numbers. This approach was used in our Retinal OCT use case where the source dataset contained over 80,000 images, and we were able to reduce the number of samples to 4000 per classification while still having a sufficiently large dataset.
- Collect More Data: it may seem obvious, but sometimes you just need to get more data. This can involve manual collection processes, concatenating multiple datasets (e.g., data collected from sensors), or augmenting your existing data with open-source data sets online. Check out PerceptiLabs' Top 5 Open Source Datasets for Machine Learning for some suggested online sources.
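To make the resampling ideas above concrete, here is a minimal NumPy-only sketch that oversamples a hypothetical minority class (sampling with replacement) until it matches the majority class; the same indexing approach, applied in reverse, would undersample the majority class instead:

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical features/labels: 90 "valid" vs 10 "fraudulent" samples
X = rng.normal(size=(100, 4))
y = np.array([0] * 90 + [1] * 10)

minority_idx = np.flatnonzero(y == 1)
majority_idx = np.flatnonzero(y == 0)

# Oversample the minority class with replacement to match the majority count
upsampled_idx = rng.choice(minority_idx, size=len(majority_idx), replace=True)
balanced_idx = np.concatenate([majority_idx, upsampled_idx])

X_bal, y_bal = X[balanced_idx], y[balanced_idx]
print(np.bincount(y_bal))  # [90 90]
```

Note that simple duplication like this carries an overfitting risk; techniques that synthesize genuinely new samples (or augment copies with rotations and flips, as in the Fabric Stain use case) tend to generalize better.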
When your situation requires you to work with unbalanced data distributions, it can pay to look at the problem differently.
Take the insurance fraud claim detection example above. Instead of treating it as a binary classification problem (i.e., classifying valid versus fraudulent claims), you might instead consider it as an anomaly detection problem. With this in mind, you might be able to repartition the data in new ways, such as by different types of fraud, while at the same time, reducing the number of data samples for valid claims. Such a strategy could then lead to new distributions, new models, and possibly new insights into the data.
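One hedged sketch of that reframing, using scikit-learn's IsolationForest on made-up claim features: instead of training a classifier on scarce fraud labels, an anomaly detector learns what "normal" claims look like and flags whatever deviates. The cluster locations and contamination rate below are purely illustrative:

```python
import numpy as np
from sklearn.ensemble import IsolationForest

rng = np.random.default_rng(0)

# Hypothetical claim features: many "valid" claims clustered together,
# a few "fraudulent" claims lying far from that cluster
valid = rng.normal(loc=0.0, scale=1.0, size=(500, 3))
fraud = rng.normal(loc=6.0, scale=1.0, size=(10, 3))
X = np.vstack([valid, fraud])

# Learn the majority behavior; points that don't fit are flagged as anomalies
clf = IsolationForest(contamination=0.02, random_state=0).fit(X)
pred = clf.predict(X)  # 1 = inlier, -1 = anomaly

print(f"flagged {np.sum(pred == -1)} of {len(X)} claims as anomalous")
```

The appeal of this framing is that it needs few or no labeled fraud examples at all, sidestepping the imbalance rather than correcting it.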
Don't Get Confused
When training models with unbalanced datasets, it's best not to assess those models on accuracy alone. One handy tool you can use is PerceptiLabs' confusion matrix as seen in Figure 1:
A confusion matrix lets you see, at a glance, how predicted classifications compare to actual classifications on a per-class level. It's available in PerceptiLabs' model testing screen after you've fully trained your model. PerceptiLabs augments the standard confusion matrix with colors representing the number of samples in each cell of the comparison. For example, if one class has many samples and another has few, your test dataset is unbalanced. Samples that fall off the diagonal have been falsely classified as another class, making it easy to see whether any one class is classified better than the others. You can see the confusion matrix in action in our Detecting Defective Pills use case, where it displays false positives and false negatives.
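For readers who want to reproduce the underlying numbers outside PerceptiLabs, scikit-learn's confusion_matrix computes the same actual-versus-predicted grid. The hypothetical results below also illustrate the accuracy paradox discussed earlier: the model misses 4 of 5 frauds yet still scores 95% accuracy:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical test results: 0 = valid claim, 1 = fraudulent claim.
# 95 valid claims (1 misflagged), 5 frauds (only 1 caught).
y_true = np.array([0] * 95 + [1] * 5)
y_pred = np.array([0] * 94 + [1] + [0] * 4 + [1])

# Rows = actual class, columns = predicted class;
# off-diagonal entries are misclassifications
cm = confusion_matrix(y_true, y_pred)
print(cm)  # [[94  1]
           #  [ 4  1]]

accuracy = np.trace(cm) / cm.sum()
print(f"accuracy: {accuracy:.2f}")  # 0.95, despite missing 4 of 5 frauds
```

The diagonal dominates simply because valid claims dominate the test set, which is exactly why accuracy alone is misleading here.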
PerceptiLabs also displays a number of useful label metrics to the right of the matrix that can help to further explain the results:
- Categorical Accuracy: how often the predicted category matches the actual category, across all samples.
- Top K Categorical Accuracy: frequency of the correct category among the top K predicted categories.
- Precision: accuracy of positive predictions.
- Recall: percentage of actual positives that are correctly identified (i.e., not misclassified as negatives).
Precision and recall are often more telling than accuracy, and generally you want both to be as high as possible: together they tell you how many of a class's samples the model finds and how often its positive predictions are actually correct.
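To see why these two metrics expose what accuracy hides, consider a hypothetical fraud detector that catches only 1 of 5 frauds while raising 1 false alarm. Its accuracy is a flattering 95%, but precision and recall, computed here with scikit-learn, tell the real story:

```python
from sklearn.metrics import precision_score, recall_score

# Hypothetical fraud-detection results: 1 = fraudulent (positive class).
# The model misses 4 of 5 frauds and misflags 1 valid claim.
y_true = [0] * 95 + [1] * 5
y_pred = [0] * 94 + [1] + [0] * 4 + [1]

precision = precision_score(y_true, y_pred)  # 1 of 2 positive predictions correct
recall = recall_score(y_true, y_pred)        # 1 of 5 actual positives found

print(f"precision: {precision:.2f}, recall: {recall:.2f}")  # 0.50, 0.20
```

A recall of 0.20 makes it immediately clear that this model is failing at the task that matters, despite its high accuracy.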
You can use these tools to better assess your model and gain deeper insight into how it handles unbalanced data. From there you can make decisions on how to move forward with your model.
PerceptiLabs makes it easy to experiment with different models even when training with suboptimal, unbalanced datasets. Our visual interface allows you to easily drag and drop different Components into your model to modify it, and our rich statistics and test views give you a wealth of information about model performance in real time. And in many cases, simply updating hyperparameters (e.g., batch size, bias, etc.) in PerceptiLabs' UI or swapping different neural network Components will do the trick.
Another option is to employ ensemble learning by building two parallel models in PerceptiLabs' workspace and then merging them together into a single output to the model's Target. While ensemble learning doesn't directly help with unbalanced datasets, it can help make a model more robust and hopefully reduce overfitting.
And don't forget, PerceptiLabs' Model Hub makes it easy to create, train, and compare multiple models. This can be useful for comparing results, such as those provided by the confusion matrix and test metrics.