Use Weighted Focal Loss

When to use Focal Loss?

Focal Loss addresses class imbalance in tasks such as object detection. It applies a modulating term to the Cross Entropy loss in order to focus learning on hard negative examples. It is a dynamically scaled Cross Entropy loss, where the scaling factor decays to zero as confidence in the correct class increases. Intuitively, this scaling factor automatically down-weights the contribution of easy examples during training and rapidly focuses the model on hard examples. The strength of this scaling is controlled by gamma: the larger gamma is, the more the model focuses on hard, misclassified examples.
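For reference, the focal loss of Lin et al. scales the Cross Entropy term by the modulating factor $(1 - p_t)^{\gamma}$:

$$\mathrm{FL}(p_t) = -(1 - p_t)^{\gamma} \log(p_t)$$

where $p_t$ is the predicted probability of the true class; with $\gamma = 0$ this reduces to ordinary Cross Entropy.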

We employ Weighted Focal Loss, which additionally allows us to reduce either false positives or false negatives depending on the value of alpha:

A value of alpha > 1 decreases the false negative count and hence increases recall. Conversely, setting alpha < 1 decreases the false positive count and increases precision.
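One plausible form of this weighting (a sketch for intuition; the exact expression used by bokbokbok may differ in detail) multiplies the positive-class term of the focal loss by alpha:

$$\mathrm{WFL}(y, p) = -\alpha\, y\, (1 - p)^{\gamma} \log(p) - (1 - y)\, p^{\gamma} \log(1 - p)$$

so alpha > 1 penalises missed positives (false negatives) more heavily, while alpha < 1 penalises false positives more heavily.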

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from bokbokbok.loss_functions.classification import WeightedFocalLoss
from bokbokbok.eval_metrics.classification import WeightedFocalMetric
from bokbokbok.utils import clip_sigmoid

X, y = make_classification(n_samples=1000, 
                           n_features=10, 
                           random_state=41114)

X_train, X_valid, y_train, y_valid = train_test_split(X, 
                                                      y, 
                                                      test_size=0.25, 
                                                      random_state=41114)

alpha = 0.7  # Reduce False Positives
gamma = 2    # Focus more strongly on misclassified examples
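If the priority were instead to catch more positives, a value of alpha above 1 would shift the trade-off towards recall, for example (an illustrative alternative, not used in the rest of this example):

alpha = 1.5  # Reduce False Negatives, favouring recall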

Usage in LightGBM

import lightgbm as lgb

train = lgb.Dataset(X_train, y_train)
valid = lgb.Dataset(X_valid, y_valid, reference=train)
params = {
    'n_estimators': 300,
    'seed': 41114,
    'n_jobs': 8,
    'learning_rate': 0.1,
}

clf = lgb.train(params=params,
                train_set=train,
                valid_sets=[train, valid],
                valid_names=['train', 'valid'],
                fobj=WeightedFocalLoss(alpha=alpha, gamma=gamma),     # custom objective
                feval=WeightedFocalMetric(alpha=alpha, gamma=gamma),  # custom evaluation metric
                early_stopping_rounds=100)

roc_auc_score(y_valid, clip_sigmoid(clf.predict(X_valid)))
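Because training uses a custom objective, clf.predict returns raw margins, which clip_sigmoid maps back to probabilities. To see the precision/recall trade-off that alpha controls, the probabilities can be thresholded into hard labels (a minimal sketch; the 0.5 threshold is an assumption):

from sklearn.metrics import precision_score, recall_score

probas = clip_sigmoid(clf.predict(X_valid))   # map raw margins to probabilities
preds = (probas > 0.5).astype(int)            # assumed decision threshold
precision_score(y_valid, preds), recall_score(y_valid, preds)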

Usage in XGBoost

import xgboost as xgb

dtrain = xgb.DMatrix(X_train, y_train)
dvalid = xgb.DMatrix(X_valid, y_valid)

params = {
    'seed': 41114,
    'learning_rate': 0.1,
    'disable_default_eval_metric': 1
}

bst = xgb.train(params,
                dtrain=dtrain,
                num_boost_round=300,
                early_stopping_rounds=10,
                verbose_eval=10,
                obj=WeightedFocalLoss(alpha=alpha, gamma=gamma),     # custom objective
                maximize=False,
                feval=WeightedFocalMetric(alpha=alpha, gamma=gamma, XGBoost=True),  # custom evaluation metric
                evals=[(dtrain, 'dtrain'), (dvalid, 'dvalid')])

roc_auc_score(y_valid, clip_sigmoid(bst.predict(dvalid)))
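With early stopping enabled, the booster records the best iteration found on the validation set; predictions can be restricted to it explicitly (a sketch assuming an XGBoost version that supports iteration_range in Booster.predict):

best_preds = clip_sigmoid(
    bst.predict(dvalid, iteration_range=(0, bst.best_iteration + 1))
)
roc_auc_score(y_valid, best_preds)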