Using class weights to handle class imbalance


Compiled by: VK
Source: Analytics Vidhya


  • Learn how class weighting works and how to implement it with sklearn in logistic regression or any other algorithm

  • Learn how to overcome the problem of imbalanced class data by adjusting class weights, without using any sampling method


In a machine learning classification problem, we are given some inputs (independent variables) and have to predict a discrete target. The distribution of the target classes can be highly skewed. Because of this imbalance, the algorithm tends to be biased toward the majority class and performs poorly on the minority class.

This difference in class frequencies affects the overall predictive performance of the model.

It is not difficult to obtain good accuracy on such problems, but good accuracy does not mean the model is good. We need to check whether the model's performance has any business significance or value. That's why it's necessary to understand the problem and the data, so that you can choose the right metric and optimize for it.


  • What is class imbalance?

  • Why deal with class imbalance?

  • What are class weights?

  • Class weight in logistic regression

  • Python implementation

    • Simple logistic regression
    • Weighted logistic regression (class_weight='balanced')
    • Weighted logistic regression (manual weighting)
  • Tips to further improve scores

What is class imbalance?

Class imbalance is a common problem in machine learning classification. It means that the frequencies of the target classes are highly unequal: one class occurs far more often than the others. In other words, the target is skewed toward the majority class.

Suppose we have a binary classification problem in which the majority class has 10,000 samples while the minority class has only 100. The ratio is 100:1, that is, for every 100 majority-class samples there is only one minority-class sample. This is what we call class imbalance. Typical domains where such data appear include fraud detection, churn prediction, medical diagnosis, and email classification.

To understand class imbalance concretely, we will work with a dataset from the medical domain. Here, we have to predict whether a person will have a heart stroke based on the given attributes (independent variables). To skip data cleaning and preprocessing, we use a cleaned version of the data.

The code below plots the distribution of the target variable.

import matplotlib.pyplot as plt
import seaborn as sns

#Draw a bar chart of the target (share of each class)
g = sns.barplot(x=data['stroke'], y=data['stroke'], palette='Set1', estimator=lambda x: len(x) / len(data))

#Annotate each bar with its percentage
for p in g.patches:
    width, height = p.get_width(), p.get_height()
    x, y = p.get_xy()
    g.annotate(f'{height:.2%}', (x + width / 2, y + height * 1.02), ha='center')

#Set labels
plt.xlabel('Heart Stroke', fontsize=14)
plt.ylabel('Percentage', fontsize=14)
plt.title('Percentage of patients with/without a heart stroke', fontsize=16)

Here:

0: the patient does not have a heart stroke.

1: the patient has a heart stroke.

As the distribution shows, only about 2% of patients have a heart stroke, so this is a classic class imbalance problem.
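As a numeric cross-check of the chart, the class shares can be computed with pandas' value_counts(normalize=True). The sketch below does not load the real dataset; as an assumption, it rebuilds a hypothetical stroke column from the class counts quoted in the class-weight section of this article (42,617 patients without a stroke and 783 with one).

```python
import pandas as pd

# Hypothetical stand-in for data['stroke'], rebuilt from the class counts
# quoted in this article (42617 without a stroke, 783 with one)
stroke = pd.Series([0] * 42617 + [1] * 783, name="stroke")

# Share of each class in the target
proportions = stroke.value_counts(normalize=True).sort_index()
print(proportions.round(3))

# Majority-to-minority ratio
counts = stroke.value_counts()
ratio = counts[0] / counts[1]
print(f"majority:minority is about {ratio:.0f}:1")
```

The minority share comes out at about 1.8%, and the majority-to-minority ratio at roughly 54:1.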

Why deal with class imbalance?

So far, we have built an intuition for class imbalance. But why do we need to address it, and what problems arise when modeling with such data?

Many machine learning algorithms implicitly assume that the classes are roughly evenly represented. The main problem with class imbalance is that the algorithm becomes more inclined to predict the majority class (in our case, no heart stroke), because it does not have enough data to learn the patterns of the minority class (heart stroke).

Let's take a real-life example to understand this better.

Suppose you have moved from your hometown to a new city and have lived there for a month. In your hometown, you are very familiar with all the places, such as your home, the routes, important shops, and tourist attractions, because you spent your whole childhood there.

But in the new city, you don't know much about where each place is, so the chance of getting lost by taking a wrong route is high. Here, your hometown is the majority class and the new city is the minority class.

The same thing happens with class imbalance. The model does not have enough information about the minority class, which is why the minority class has a high misclassification error.

Note: to evaluate the model, we will use the F1 score rather than accuracy.

The reason is that if we build a naive model that predicts 0 (no heart stroke) for every new observation, we will still get very high accuracy, because almost all observations belong to the majority class.

Such a model is very accurate, but it has no value for our problem. That's why we use the F1 score, which is the harmonic mean of precision and recall. In general, the metric should be chosen based on the business problem and the types of errors we want to reduce; for imbalanced classification, the F1 score is a common choice because it focuses on the minority class.

Here is the formula for the F1 score:

$$F_1 = \frac{2 \times \text{precision} \times \text{recall}}{\text{precision} + \text{recall}}$$
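To make the formula concrete, here is a minimal pure-Python sketch that computes precision, recall and F1 from raw labels. The toy labels are made up for illustration, with a 2% positive rate similar to our data. When no positives are predicted, precision is 0/0; the sketch treats that as 0, which matches scikit-learn's f1_score default (it also emits a warning in that case).

```python
def f1_from_labels(y_true, y_pred, positive=1):
    """F1 = 2 * precision * recall / (precision + recall) for the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    # Treat 0/0 as 0, which happens when no positives are predicted
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# Toy labels: 98 negatives followed by 2 positives
y_true = [0] * 98 + [1] * 2

# "Always predict the majority class" model
always_zero = [0] * 100
accuracy = sum(t == p for t, p in zip(y_true, always_zero)) / len(y_true)
print(accuracy)                              # 0.98
print(f1_from_labels(y_true, always_zero))   # 0.0

# A model with one false alarm, one hit and one miss
some_hits = [0] * 97 + [1] + [1, 0]
print(f1_from_labels(y_true, some_hits))     # 0.5
```

Both models have 0.98 accuracy, yet the always-0 model scores an F1 of 0 while the second model, which finds one of the two positives, reaches 0.5. This is exactly the difference that accuracy hides.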

Let's confirm this with a baseline "mode model" that always predicts the most frequent class of the training target, and check its scores:

#Baseline mode model: predict the training-set mode (majority class) for every test row
#(assumes y_train is a pandas Series)
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
pred_test = [y_train.mode()[0]] * len(y_test)

#Print F1 and accuracy scores
print('The accuracy for mode model is:', accuracy_score(y_test, pred_test))
print('The f1 score for the mode model is:', f1_score(y_test, pred_test))

#Draw confusion matrix (conf_matrix is defined in the logistic regression section below)
conf_matrix(y_test, pred_test)

The accuracy of the model is 0.9819508448540707

The F1 score of the model is 0.0

Here, the model's accuracy on the test data is 0.98, which looks like a good score. On the other hand, the F1 score is zero, which indicates that the model performs poorly on the minority class. We can confirm this by looking at the confusion matrix.

The model predicts 0 (no heart stroke) for every patient. According to this model, no matter what symptoms a patient has, they will never have a heart stroke. Does it make sense to use this model?
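scikit-learn ships a ready-made baseline equivalent to this mode model: DummyClassifier(strategy="most_frequent"). The sketch below runs it on synthetic data from make_classification with about 2% positives, used as a stand-in for the stroke dataset (which is not reproduced here), so the exact numbers will differ from those above.

```python
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the stroke data: roughly 98% class 0, 2% class 1
X, y = make_classification(n_samples=5000, n_features=10, weights=[0.98],
                           flip_y=0, random_state=42)
x_train, x_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42)

# Baseline equivalent to the "mode model": always predicts the most frequent class
baseline = DummyClassifier(strategy="most_frequent").fit(x_train, y_train)
pred_test = baseline.predict(x_test)

acc = accuracy_score(y_test, pred_test)
# zero_division=0 silences the warning raised when no positives are predicted
f1 = f1_score(y_test, pred_test, zero_division=0)
print(f"accuracy={acc:.3f}, f1={f1:.1f}")
```

Using a DummyClassifier as the reference point makes it easy to check that any real model actually beats the trivial majority-class strategy.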

Now that we understand what class imbalance is and how it affects model performance, let's look at what class weights are and how they help improve performance.

What are class weights?

Most machine learning algorithms do not handle skewed class data well out of the box. However, we can modify the training algorithm to take the skewed class distribution into account by giving different weights to the majority and minority classes. During training, these weights change how much each class contributes to the loss. The overall aim is to penalize misclassification of the minority class more heavily by giving it a higher class weight, while reducing the weight of the majority class.

To illustrate this more clearly, let's return to the city example from before.

Think of it this way: during your last month in the new city, instead of going out only when you needed to, you spent the whole month exploring it. You spent more time learning the city's routes and locations. That extra study time helps you understand the new city better and reduces the chance of getting lost. This is how class weights work.

During training, we give the minority class a larger weight in the algorithm's cost function, so that errors on the minority class are penalized more heavily and the algorithm focuses on reducing them.

Note: there is a limit to how far you should raise the minority class weight and lower the majority class weight. If the minority class gets a very high weight, the algorithm is likely to become biased toward the minority class and make more errors on the majority class.

Most sklearn classifiers, and even some boosting libraries such as LightGBM and CatBoost, have a built-in class weight parameter (class_weight in sklearn and in LightGBM's LGBMClassifier, class_weights in CatBoost). It helps us optimize the score for the minority class, as described above.

By default, class_weight is None, meaning both classes get equal weight. Alternatively, we can set it to 'balanced' or pass a dictionary containing manually chosen weights for the two classes.

When class_weight='balanced', the model automatically assigns class weights that are inversely proportional to the class frequencies.

More precisely, the weights are computed as follows:

$$w_j = \frac{n\_samples}{n\_classes \times n\_samples_j}$$

where:

  • w_j is the weight of class j

  • n_samples is the total number of samples (rows) in the dataset

  • n_classes is the number of unique classes in the target

  • n_samples_j is the number of rows in class j

For our heart stroke data:

n_samples = 43400, n_classes = 2 (0 and 1), n_samples_0 = 42617, n_samples_1 = 783

Weight of class 0:

w0 = 43400/(2*42617) ≈ 0.509

Weight of class 1:

w1 = 43400/(2*783) ≈ 27.714
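scikit-learn exposes the same 'balanced' rule through sklearn.utils.class_weight.compute_class_weight, so the hand calculation can be checked in a few lines. As before, the target is a hypothetical array rebuilt from the quoted class counts rather than the real data.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Hypothetical target rebuilt from the counts above: 42617 zeros and 783 ones
y = np.array([0] * 42617 + [1] * 783)

# The rule sklearn applies for class_weight='balanced'
weights = compute_class_weight(class_weight="balanced", classes=np.array([0, 1]), y=y)
print(f"w0 = {weights[0]:.3f}, w1 = {weights[1]:.3f}")

# The hand-written formula n_samples / (n_classes * n_samples_j) agrees
manual = len(y) / (2 * np.bincount(y))
assert np.allclose(weights, manual)
```

Both approaches give w0 ≈ 0.509 and w1 ≈ 27.71 for these counts.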

I hope this makes it clear that class_weight='balanced' gives a higher weight to the minority class and a lower weight to the majority class.

While 'balanced' produces good results in most cases, for extreme class imbalance we can sometimes try designing the weights ourselves. Later, we'll see how to find the best class weights in Python.

Class weight in logistic regression

We can modify any machine learning algorithm by adding class weights to its cost function, but here we will focus on logistic regression.

For logistic regression, we use log loss as the cost function. We don't use mean squared error, because the prediction function is a sigmoid curve rather than a straight line.

Plugging the sigmoid into the squared-error cost yields a non-convex function that can have many local minima. Log loss, on the other hand, is convex, so there is a single minimum to converge to.

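Log loss formula:

$$\text{Log loss} = -\frac{1}{N}\sum_{i=1}^{N}\left[y_i \log(\hat{y}_i) + (1 - y_i)\log(1 - \hat{y}_i)\right]$$

where:

  • N is the number of observations

  • y_i is the actual value of the target class

  • ŷ_i is the predicted probability of the target class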

Let's build a toy table containing the actual class, the predicted probability, and the cost calculated with the log loss formula:

In this table, we have 10 observations: nine from class 0 and one from class 1. The next column gives the predicted probability for each observation. Finally, using the log loss formula, we get the cost penalty.

After adding the class weights to the cost function, the modified log loss is:

$$\text{Log loss} = -\frac{1}{N}\sum_{i=1}^{N}\left[w_1\, y_i \log(\hat{y}_i) + w_0 (1 - y_i)\log(1 - \hat{y}_i)\right]$$

where:

w0 is the class weight of class 0

w1 is the class weight of class 1

Now, we'll add the weights to see how they affect the cost penalties.

For the weight values, we use the class_weight='balanced' formula:

w0 = 10/(2*9) ≈ 0.556

w1 = 10/(2*1) = 5

Calculate the cost of the first value in the table (actual class 0, predicted probability 0.32):

Cost = -(5*(0*log(0.32)) + 0.556*(1-0)*log(1-0.32))

= -(0 + 0.556*log(0.68))

= -(0.556*(-0.3857))

≈ 0.214

Similarly, we can calculate the weighted cost of every observation in the table.
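The weighted cost can be reproduced in a few lines of numpy. This is a sketch under two stated conventions: w1 multiplies the class-1 term y*log(p) and w0 multiplies the class-0 term (1 - y)*log(1 - p), and the weights follow the 'balanced' rule for a table of nine class-0 rows and one class-1 row, so w0 = 10/(2*9) and w1 = 10/(2*1). The class-1 row with probability 0.32 is hypothetical, added only to show the effect of the larger weight.

```python
import numpy as np

def weighted_log_loss(y_true, y_prob, w0, w1):
    """Per-observation cost: -(w1 * y * log(p) + w0 * (1 - y) * log(1 - p))."""
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.asarray(y_prob, dtype=float)
    return -(w1 * y_true * np.log(y_prob) + w0 * (1 - y_true) * np.log(1 - y_prob))

# 'balanced' weights for the toy table: 9 rows of class 0, 1 row of class 1
n, n0, n1 = 10, 9, 1
w0 = n / (2 * n0)   # about 0.556
w1 = n / (2 * n1)   # 5.0

# First row of the table: actual class 0, predicted probability 0.32
cost = weighted_log_loss([0], [0.32], w0, w1)[0]
unweighted = weighted_log_loss([0], [0.32], 1.0, 1.0)[0]

# Hypothetical class-1 row with the same probability, to show the larger weight
minority_cost = weighted_log_loss([1], [0.32], w0, w1)[0]

print(round(cost, 3), round(unweighted, 3), round(minority_cost, 3))
```

The class-0 row costs about 0.214 instead of about 0.386 unweighted, while the hypothetical class-1 row costs about 5.70, five times its unweighted value.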

From this table, we can see that a smaller weight is applied to the cost of the majority class, which yields smaller error values and therefore smaller updates to the model coefficients. A larger weight is applied to the cost of the minority class, which yields larger errors and larger coefficient updates. In this way, we can shift the model's bias and reduce errors on the minority class.


A smaller weight leads to a smaller penalty and smaller updates to the model coefficients.

A larger weight leads to a larger penalty and larger updates to the model coefficients.
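The link between weight size and update size can be seen directly in the gradient. For the weighted log loss of logistic regression, each observation contributes w_i * (p_i - y_i) * x_i to the gradient, so a weight of 5 makes that row's gradient-descent step exactly five times larger. The sketch below is a toy numpy illustration with made-up coefficients and a single hypothetical row; it is not how sklearn's solvers (such as newton-cg) are implemented internally.

```python
import numpy as np

def sigmoid(z):
    """Logistic function mapping a linear score to a probability."""
    return 1.0 / (1.0 + np.exp(-z))

def weighted_gradient(beta, X, y, sample_w):
    """Gradient of the mean weighted log loss w.r.t. the coefficients beta.

    Each row contributes sample_w[i] * (sigmoid(X[i] @ beta) - y[i]) * X[i].
    """
    p = sigmoid(X @ beta)
    return X.T @ (sample_w * (p - y)) / len(y)

# One hypothetical minority-class row (intercept column + one feature)
X = np.array([[1.0, 2.0]])
y = np.array([1.0])
beta = np.array([-1.0, 0.0])   # made-up coefficients: p = sigmoid(-1), about 0.27

g_plain = weighted_gradient(beta, X, y, np.array([1.0]))     # class weight 1
g_weighted = weighted_gradient(beta, X, y, np.array([5.0]))  # class weight 5

step = 0.1  # gradient-descent learning rate
print("update with weight 1:", -step * g_plain)
print("update with weight 5:", -step * g_weighted)
```

Both updates point in the same direction (pushing the predicted probability of this class-1 row up), but the weighted one is five times as large.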

Python implementation

Here, we use the same heart stroke data for prediction. First, we train a simple logistic regression, then a weighted logistic regression with class_weight='balanced'. Finally, we use grid search to find the best class weights. The metric we optimize is the F1 score.

1. Simple logistic regression

Here, we train the model with sklearn's default logistic regression. By default, the algorithm gives equal weight to both classes.

#Import and train the model
from sklearn.linear_model import LogisticRegression
lr = LogisticRegression(solver='newton-cg').fit(x_train, y_train)

#Test data prediction
pred_test = lr.predict(x_test)

#Calculate and print the F1 score
f1_test = f1_score(y_test, pred_test)
print('The f1 score for the testing data:', f1_test)

#Function to create confusion matrix
def conf_matrix(y_test, pred_test):    
    #Create confusion matrix
    con_mat = confusion_matrix(y_test, pred_test)
    con_mat = pd.DataFrame(con_mat, range(2), range(2))
    sns.heatmap(con_mat, annot=True, annot_kws={"size": 16}, fmt='g', cmap='Blues', cbar=False)
#Call function
conf_matrix(y_test, pred_test)

Test data F1 score: 0.0

In the simple logistic regression model, the F1 score is 0. The confusion matrix confirms that the model predicts "no heart stroke" for every observation. This model is no better than the mode model we created earlier. Let's try giving more weight to the minority class to see if that helps.

2. Logistic regression (class_weight='balanced')

We add the class_weight parameter to the logistic regression and set it to 'balanced'.

#Import and train the model
from sklearn.linear_model import LogisticRegression
lr = LogisticRegression(solver='newton-cg', class_weight='balanced').fit(x_train, y_train)

#Test data prediction
pred_test = lr.predict(x_test)

#Calculate and print the F1 score
f1_test = f1_score(y_test, pred_test)
print('The f1 score for the testing data:', f1_test)

#Draw confusion matrix
conf_matrix(y_test, pred_test)

Test data F1 score: 0.10098851188885921

By adding a single class_weight parameter to the logistic regression, we raised the F1 score from 0 to about 0.10. The confusion matrix shows that although misclassification of class 0 (no heart stroke) increases, the model now captures class 1 (heart stroke) much better.
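To see the effect of class_weight='balanced' without the stroke data, the sketch below fits the same newton-cg logistic regression twice on a synthetic imbalanced dataset (make_classification with about 2% positives and weak class separation; an assumption, not the article's data) and compares how many positives each model predicts, along with recall and F1 on a held-out split. The exact numbers depend on the synthetic data; the point is the direction of the change.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, recall_score
from sklearn.model_selection import train_test_split

# Synthetic imbalanced data standing in for the stroke dataset (about 2% positives)
X, y = make_classification(n_samples=20000, n_features=10, weights=[0.98],
                           class_sep=0.5, flip_y=0, random_state=0)
x_train, x_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=0)

results = {}
for cw in (None, "balanced"):
    model = LogisticRegression(solver="newton-cg", class_weight=cw).fit(x_train, y_train)
    pred = model.predict(x_test)
    results[cw] = {
        "predicted_positives": int(pred.sum()),
        "recall": recall_score(y_test, pred, zero_division=0),
        "f1": f1_score(y_test, pred, zero_division=0),
    }
    print(cw, results[cw])
```

With balanced weights the model predicts far more positives, which raises recall on the minority class at the cost of more false alarms on the majority class, the same trade-off seen in the confusion matrix above.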

Can we improve the score further by tuning the class weights?

3. Logistic regression (manually set class weights)

Finally, we use grid search to find the weights that give the highest score. We search for weights between 0 and 1: if the minority class gets weight n, the majority class gets 1 - n.

Here, the individual weights are not large, but the ratio between the minority and majority weights can be very high.

For example:

w1 = 0.95

w0 = 1 – 0.95 = 0.05

w1:w0 = 19:1

So the minority class is weighted 19 times more than the majority class.

import numpy as np
import pandas as pd
from sklearn.model_selection import GridSearchCV, StratifiedKFold
lr = LogisticRegression(solver='newton-cg')

#Set the range of class weights
weights = np.linspace(0.0,0.99,200)

#Create dictionary grid for grid search (class 0 gets x, class 1 gets 1 - x)
param_grid = {'class_weight': [{0:x, 1:1.0-x} for x in weights]}

#Fit the training data with 5-fold stratified grid search, scoring on F1
gridsearch = GridSearchCV(estimator=lr,
                          param_grid=param_grid,
                          cv=StratifiedKFold(n_splits=5),
                          n_jobs=-1,
                          scoring='f1',
                          verbose=2).fit(x_train, y_train)

#Plot the mean cross-validated F1 score against the class 1 weight
weigh_data = pd.DataFrame({'score': gridsearch.cv_results_['mean_test_score'], 'weight': (1 - weights)})
sns.lineplot(x=weigh_data['weight'], y=weigh_data['score'])
plt.xlabel('Weight for class 1')
plt.ylabel('F1 score')
plt.xticks([round(i/10,1) for i in range(0,11,1)])
plt.title('Scoring for different class weights', fontsize=24)

From the plot, we can see that the F1 score peaks when the minority class weight is about 0.93.

Through grid search, we get the best class weights: 0.06467 for class 0 (majority class) and 0.93532 for class 1 (minority class).
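The winning weights can be read straight from the fitted search object: best_params_ holds the best parameter setting, best_score_ its mean cross-validated score, and with the default refit=True, best_estimator_ is already retrained on the full training data with those weights. The self-contained sketch below repeats the search on synthetic data (make_classification, an assumption standing in for the stroke data) with a coarser grid of 25 weights, so the chosen weights will not match the values found above.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, StratifiedKFold

# Synthetic imbalanced data (about 2% positives)
X, y = make_classification(n_samples=5000, n_features=10, weights=[0.98],
                           class_sep=0.5, flip_y=0, random_state=0)

# Coarse grid: class 0 gets x, class 1 gets 1 - x (same parametrization as above)
weights = np.linspace(0.02, 0.98, 25)
param_grid = {"class_weight": [{0: x, 1: 1.0 - x} for x in weights]}

search = GridSearchCV(LogisticRegression(solver="newton-cg"), param_grid,
                      scoring="f1", cv=StratifiedKFold(n_splits=5))
search.fit(X, y)

best = search.best_params_["class_weight"]
print("best weights:", best)
print("best mean CV F1:", round(search.best_score_, 3))

# refit=True (the default) has already retrained this model on all of X
best_model = search.best_estimator_
```

Reading the weights from best_params_ avoids copying long floats by hand, and best_estimator_ can be used directly for test-set predictions.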

Now that we have found the best class weights with stratified cross-validation and grid search, let's see how the model performs on the test data.

#Import and train the model
from sklearn.linear_model import LogisticRegression
lr = LogisticRegression(solver='newton-cg', class_weight={0: 0.06467336683417085, 1: 0.9353266331658292}).fit(x_train, y_train)

#Test data prediction
pred_test = lr.predict(x_test)

#Calculate and print the F1 score
f1_test = f1_score(y_test, pred_test)
print('The f1 score for the testing data:', f1_test)

#Draw confusion matrix
conf_matrix(y_test, pred_test)

F1 score: 0.15714644

By tuning the weights manually, we further improved the F1 score by about 0.06 (from about 0.10 to 0.157). The confusion matrix also shows that, compared with the previous model, we predict class 0 better, at the cost of more misclassifications in class 1. Which trade-off is acceptable depends on the business problem and the type of error you want to reduce. Here, our focus was to improve the F1 score, and we did so by tuning the class weights.

Tips to further improve scores

Feature engineering: for simplicity, we only used the given independent variables. You can try creating new features.

Threshold tuning: by default, sklearn classifiers such as logistic regression predict class 1 when its predicted probability exceeds 0.5. You can try different threshold values and find the best one with grid search or randomized search.

Advanced algorithms: in this explanation we only used logistic regression. You can try various bagging and boosting algorithms, and finally try blending several algorithms.
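The threshold-tuning tip can be sketched with predict_proba: score a grid of cut-offs on a validation split and keep the one with the best F1. The data here are synthetic (make_classification) and the grid of 19 thresholds is an arbitrary choice; in practice the threshold should be picked on validation data or via cross-validation, never on the final test set.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

# Synthetic imbalanced data (about 2% positives)
X, y = make_classification(n_samples=20000, n_features=10, weights=[0.98],
                           class_sep=0.5, flip_y=0, random_state=1)
# Validation split used only for choosing the threshold
x_train, x_val, y_train, y_val = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=1)

model = LogisticRegression(solver="newton-cg", class_weight="balanced").fit(x_train, y_train)
proba = model.predict_proba(x_val)[:, 1]   # predicted probability of class 1

thresholds = [i / 20 for i in range(1, 20)]  # 0.05, 0.10, ..., 0.95 (includes 0.5)
scores = [f1_score(y_val, (proba >= t).astype(int), zero_division=0) for t in thresholds]

best_idx = int(np.argmax(scores))
print(f"F1 at 0.5: {scores[thresholds.index(0.5)]:.3f}")
print(f"best threshold {thresholds[best_idx]:.2f} -> F1 {scores[best_idx]:.3f}")
```

Because the grid includes 0.5, the selected threshold can only match or beat the default cut-off on the validation data; whether that gain carries over should then be checked once on the test set.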


I hope this article gives you an idea of how class weights help deal with class imbalance and how easy they are to implement in Python.

Although we only discussed applying class weights to logistic regression, the idea is the same for other algorithms: only the cost function each algorithm minimizes changes, so that errors on the minority class are penalized more.
