Handling Class Imbalance with AutoMM - Focal Loss


In this tutorial, we introduce how to use focal loss with the AutoMM package for balanced training. Focal loss was first introduced in the paper Focal Loss for Dense Object Detection (see the citation at the end of this tutorial) and can be used to balance hard versus easy samples as well as an uneven distribution of samples across classes.
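At its core, focal loss rescales the standard cross-entropy term by a modulating factor: FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class. Below is a minimal, purely illustrative PyTorch sketch of a multiclass focal loss written for this tutorial; it only shows where alpha and gamma enter the computation and is not AutoMM's internal implementation.

import torch
import torch.nn.functional as F

def focal_loss_sketch(logits, targets, alpha=None, gamma=2.0, reduction="mean"):
    # Illustrative only: FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)
    log_probs = F.log_softmax(logits, dim=-1)                      # log-probability of every class
    log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # log p_t of the true class
    pt = log_pt.exp()
    loss = -((1.0 - pt) ** gamma) * log_pt                         # down-weight easy (high p_t) samples
    if alpha is not None:                                          # optional per-class weights
        loss = loss * torch.tensor(alpha, dtype=loss.dtype)[targets]
    return loss.mean() if reduction == "mean" else loss.sum()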

Create Dataset

For this tutorial, we use the Shopee dataset for demonstration. The Shopee dataset contains 4 classes, and each class has 200 samples in the training set.

from autogluon.multimodal.utils.misc import shopee_dataset

download_dir = "./ag_automm_tutorial_imgcls_focalloss"
train_data, test_data = shopee_dataset(download_dir)
Downloading ./ag_automm_tutorial_imgcls_focalloss/file.zip from https://automl-mm-bench.s3.amazonaws.com/vision_datasets/shopee.zip...
100%|██████████| 84.0M/84.0M [00:02<00:00, 31.7MiB/s]
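Optionally, you can verify the (still balanced) class distribution of the freshly loaded training data with a quick pandas check; this step is not required for the rest of the tutorial:

# Optional sanity check: each of the 4 classes should have 200 training samples.
print(train_data["label"].value_counts())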

To demonstrate the effectiveness of focal loss on imbalanced training data, we artificially downsample the Shopee training data to form an imbalanced distribution.

import numpy as np
import pandas as pd

ds = 1  # fraction of samples to keep for the current class

imbalanced_train_data = []
for lb in range(4):
    class_data = train_data[train_data.label == lb]
    sample_index = np.random.choice(np.arange(len(class_data)), size=int(len(class_data) * ds), replace=False)
    ds /= 3  # keep only 1/3 of the previous class's fraction for the next class
    imbalanced_train_data.append(class_data.iloc[sample_index])
imbalanced_train_data = pd.concat(imbalanced_train_data)
print(imbalanced_train_data)

weights = []
for lb in range(4):
    class_data = imbalanced_train_data[imbalanced_train_data.label == lb]
    weights.append(1 / (class_data.shape[0] / imbalanced_train_data.shape[0]))
    print(f"class {lb}: num samples {len(class_data)}")
weights = [float(w) for w in np.array(weights) / np.sum(weights)]  # normalize; plain Python floats serialize cleanly as hyperparameters
print(weights)
                                                 image  label
110  /home/ci/autogluon/docs/tutorials/multimodal/a...      0
48   /home/ci/autogluon/docs/tutorials/multimodal/a...      0
35   /home/ci/autogluon/docs/tutorials/multimodal/a...      0
53   /home/ci/autogluon/docs/tutorials/multimodal/a...      0
120  /home/ci/autogluon/docs/tutorials/multimodal/a...      0
..                                                 ...    ...
786  /home/ci/autogluon/docs/tutorials/multimodal/a...      3
755  /home/ci/autogluon/docs/tutorials/multimodal/a...      3
707  /home/ci/autogluon/docs/tutorials/multimodal/a...      3
687  /home/ci/autogluon/docs/tutorials/multimodal/a...      3
726  /home/ci/autogluon/docs/tutorials/multimodal/a...      3

[295 rows x 2 columns]
class 0: num samples 200
class 1: num samples 66
class 2: num samples 22
class 3: num samples 7
[0.0239850482815907, 0.07268196448966878, 0.21804589346900635, 0.6852870937597342]

Create and Train a MultiModalPredictor

Train with Focal Loss

We specify that the model should use focal loss by setting "optim.loss_func" to "focal_loss". There are also three other optional parameters you can set.

optim.focal_loss.alpha - a list of floats that gives the per-class loss weights and can be used to balance an uneven sample distribution across classes. Note that the length of the list must match the total number of classes in the training dataset. A good way to compute the alpha for each class is to use the inverse of its sample percentage.

optim.focal_loss.gamma - a float that controls how much to focus on hard samples. A larger value means more focus on hard samples (see the short numeric sketch after these parameter descriptions).

optim.focal_loss.reduction - how to aggregate the loss values. Can only take "mean" or "sum" for now.
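To build intuition for gamma, here is a tiny, purely illustrative calculation of the modulating factor (1 - p_t)^gamma for an easy example (p_t = 0.9) and a hard example (p_t = 0.2). Larger gamma values shrink the easy example's contribution to the loss far more than the hard example's:

# Toy numbers, not part of AutoMM: how gamma down-weights easy samples.
for gamma in [0.0, 1.0, 2.0, 5.0]:
    easy = (1 - 0.9) ** gamma  # confidently classified sample
    hard = (1 - 0.2) ** gamma  # poorly classified sample
    print(f"gamma={gamma}: easy factor={easy:.5f}, hard factor={hard:.5f}")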

import uuid
from autogluon.multimodal import MultiModalPredictor

model_path = f"./tmp/{uuid.uuid4().hex}-automm_shopee_focal"

predictor = MultiModalPredictor(label="label", problem_type="multiclass", path=model_path)

predictor.fit(
    hyperparameters={
        "model.mmdet_image.checkpoint_name": "swin_tiny_patch4_window7_224",
        "env.num_gpus": 1,
        "optim.loss_func": "focal_loss",
        "optim.focal_loss.alpha": weights,  # shopee dataset has 4 classes.
        "optim.focal_loss.gamma": 1.0,
        "optim.focal_loss.reduction": "sum",
        "optim.max_epochs": 10,
    },
    train_data=imbalanced_train_data,
) 

predictor.evaluate(test_data, metrics=["acc"])
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       28.44 GB / 30.95 GB (91.9%)
Disk Space Avail:   166.39 GB / 255.99 GB (65.0%)
===================================================

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/611cea89096f43f08a5e09e70a2fc8e2-automm_shopee_focal
    ```
Seed set to 0

Train without Focal Loss

import uuid
from autogluon.multimodal import MultiModalPredictor

model_path = f"./tmp/{uuid.uuid4().hex}-automm_shopee_non_focal"

predictor2 = MultiModalPredictor(label="label", problem_type="multiclass", path=model_path)

predictor2.fit(
    hyperparameters={
        "model.mmdet_image.checkpoint_name": "swin_tiny_patch4_window7_224",
        "env.num_gpus": 1,
        "optim.max_epochs": 10,
    },
    train_data=imbalanced_train_data,
)

predictor2.evaluate(test_data, metrics=["acc"])
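Beyond the overall accuracy reported by evaluate, it can be informative to compare the two predictors per class, since focal loss mainly helps the under-represented classes. Assuming scikit-learn is installed (it is not otherwise needed by this tutorial), one possible check is:

from sklearn.metrics import classification_report

# Compare per-class precision/recall/F1 of the two predictors on the test set.
for name, p in [("with focal loss", predictor), ("without focal loss", predictor2)]:
    preds = p.predict(test_data.drop(columns=["label"]))
    print(name)
    print(classification_report(test_data["label"], preds))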

As we can see, the model trained with focal loss achieves better performance than the model trained without it. When your data is imbalanced, try focal loss to see whether it improves the performance!

Citation

@misc{https://doi.org/10.48550/arxiv.1708.02002,
  doi = {10.48550/ARXIV.1708.02002},
  url = {https://arxiv.org/abs/1708.02002},
  author = {Lin, Tsung-Yi and Goyal, Priya and Girshick, Ross and He, Kaiming and Dollár, Piotr},
  keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences},
  title = {Focal Loss for Dense Object Detection},
  publisher = {arXiv},
  year = {2017},
  copyright = {arXiv.org perpetual, non-exclusive license}
}