Handling Class Imbalance with AutoMM - Focal Loss¶
In this tutorial, we introduce how to use focal loss from the AutoMM package for balanced training. Focal loss was first proposed in this paper (see the citation below); it down-weights well-classified (easy) samples so that training focuses on hard ones, and its per-class weights help counter an uneven sample distribution across classes. This tutorial demonstrates how to use focal loss.
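Concretely, for a sample whose predicted probability for its true class is p_t, the paper defines focal loss as

FL(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)

where \alpha_t is the weight assigned to the true class and \gamma \geq 0 controls how strongly well-classified samples (large p_t) are down-weighted; \gamma = 0 recovers the standard alpha-weighted cross-entropy loss.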
Create Dataset¶
In this tutorial, we use the Shopee dataset for demonstration. The Shopee dataset contains 4 classes, with 200 samples per class in the training set.
from autogluon.multimodal.utils.misc import shopee_dataset
download_dir = "./ag_automm_tutorial_imgcls_focalloss"
train_data, test_data = shopee_dataset(download_dir)
Downloading ./ag_automm_tutorial_imgcls_focalloss/file.zip from https://automl-mm-bench.s3.amazonaws.com/vision_datasets/shopee.zip...
100%|██████████| 84.0M/84.0M [00:02<00:00, 31.7MiB/s]
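Before downsampling, you can verify that the training split is balanced; this quick check is not part of the original tutorial:

print(train_data["label"].value_counts())  # expect 200 samples for each of the 4 classes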
To demonstrate the effectiveness of focal loss on imbalanced training data, we artificially downsample the Shopee training data to form an imbalanced distribution.
import numpy as np
import pandas as pd

ds = 1
imbalanced_train_data = []
for lb in range(4):
    class_data = train_data[train_data.label == lb]
    sample_index = np.random.choice(np.arange(len(class_data)), size=int(len(class_data) * ds), replace=False)
    ds /= 3  # keep only 1/3 as many samples for each subsequent class
    imbalanced_train_data.append(class_data.iloc[sample_index])
imbalanced_train_data = pd.concat(imbalanced_train_data)
print(imbalanced_train_data)
weights = []
for lb in range(4):
    class_data = imbalanced_train_data[imbalanced_train_data.label == lb]
    weights.append(1 / (class_data.shape[0] / imbalanced_train_data.shape[0]))  # inverse of the class's sample fraction
    print(f"class {lb}: num samples {len(class_data)}")
weights = (np.array(weights) / np.sum(weights)).tolist()  # normalize to sum to 1, as plain Python floats
print(weights)
image label
110 /home/ci/autogluon/docs/tutorials/multimodal/a... 0
48 /home/ci/autogluon/docs/tutorials/multimodal/a... 0
35 /home/ci/autogluon/docs/tutorials/multimodal/a... 0
53 /home/ci/autogluon/docs/tutorials/multimodal/a... 0
120 /home/ci/autogluon/docs/tutorials/multimodal/a... 0
.. ... ...
786 /home/ci/autogluon/docs/tutorials/multimodal/a... 3
755 /home/ci/autogluon/docs/tutorials/multimodal/a... 3
707 /home/ci/autogluon/docs/tutorials/multimodal/a... 3
687 /home/ci/autogluon/docs/tutorials/multimodal/a... 3
726 /home/ci/autogluon/docs/tutorials/multimodal/a... 3
[295 rows x 2 columns]
class 0: num samples 200
class 1: num samples 66
class 2: num samples 22
class 3: num samples 7
[0.0239850482815907, 0.07268196448966878, 0.21804589346900635, 0.6852870937597342]
Create and Train MultiModalPredictor¶
Training with Focal Loss¶
We specify that the model should use focal loss by setting "optim.loss_func" to "focal_loss". There are three other optional parameters you can set (a small numeric sketch of the formula follows this list):

optim.focal_loss.alpha - a list of floats giving the per-class loss weights, which can be used to balance an uneven sample distribution across classes. Note that the length of the list must match the total number of classes in the training dataset. A good way to compute alpha for each class is to use the inverse of its sample percentage, as we did above.

optim.focal_loss.gamma - a float that controls how much to focus on hard samples. A larger value means more focus on the hard samples.

optim.focal_loss.reduction - how to aggregate the loss values. Can only take "mean" or "sum" for now.
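To build intuition for alpha and gamma, here is a minimal sketch (independent of the AutoMM API, with made-up example values) that evaluates the focal loss formula directly in PyTorch:

import torch

alpha = torch.tensor([0.1, 0.9])    # per-class weights; the rare class gets the larger weight
gamma = 1.0
probs = torch.tensor([0.95, 0.60])  # predicted probability of the true class for two samples
targets = torch.tensor([0, 1])      # true class indices

# FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)
focal = -alpha[targets] * (1 - probs) ** gamma * torch.log(probs)
print(focal)  # the confident (easy) prediction contributes far less loss than the hard one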
import uuid
from autogluon.multimodal import MultiModalPredictor
model_path = f"./tmp/{uuid.uuid4().hex}-automm_shopee_focal"
predictor = MultiModalPredictor(label="label", problem_type="multiclass", path=model_path)
predictor.fit(
    hyperparameters={
        "model.timm_image.checkpoint_name": "swin_tiny_patch4_window7_224",
        "env.num_gpus": 1,
        "optim.loss_func": "focal_loss",
        "optim.focal_loss.alpha": weights,  # one weight per class; the Shopee dataset has 4 classes
        "optim.focal_loss.gamma": 1.0,
        "optim.focal_loss.reduction": "sum",
        "optim.max_epochs": 10,
    },
    train_data=imbalanced_train_data,
)

predictor.evaluate(test_data, metrics=["acc"])
=================== System Info ===================
AutoGluon Version: 1.3.1b20250508
Python Version: 3.11.9
Operating System: Linux
Platform Machine: x86_64
Platform Version: #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count: 8
Pytorch Version: 2.6.0+cu124
CUDA Version: 12.4
Memory Avail: 28.44 GB / 30.95 GB (91.9%)
Disk Space Avail: 166.39 GB / 255.99 GB (65.0%)
===================================================
AutoMM starts to create your model. ✨✨✨
To track the learning progress, you can open a terminal and launch Tensorboard:
```shell
# Assume you have installed tensorboard
tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/611cea89096f43f08a5e09e70a2fc8e2-automm_shopee_focal
```
Seed set to 0
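Besides overall accuracy, it is worth checking per-class metrics, since focal loss mainly helps the under-represented classes. A minimal sketch, assuming scikit-learn is installed (not part of the original tutorial):

from sklearn.metrics import classification_report

# Per-class precision/recall on the test set; the rare classes should benefit most from focal loss.
predictions = predictor.predict(test_data)
print(classification_report(test_data["label"], predictions))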
Training without Focal Loss¶
import uuid
from autogluon.multimodal import MultiModalPredictor
model_path = f"./tmp/{uuid.uuid4().hex}-automm_shopee_non_focal"
predictor2 = MultiModalPredictor(label="label", problem_type="multiclass", path=model_path)
predictor2.fit(
    hyperparameters={
        "model.timm_image.checkpoint_name": "swin_tiny_patch4_window7_224",
        "env.num_gpus": 1,
        "optim.max_epochs": 10,
    },
    train_data=imbalanced_train_data,
)

predictor2.evaluate(test_data, metrics=["acc"])
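To compare the two models directly, you can reuse predictor and predictor2 from the cells above; evaluate returns a dict keyed by metric name:

acc_focal = predictor.evaluate(test_data, metrics=["acc"])["acc"]
acc_plain = predictor2.evaluate(test_data, metrics=["acc"])["acc"]
print(f"focal loss acc: {acc_focal:.4f} vs. cross-entropy acc: {acc_plain:.4f}")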
As we can see, the model trained with focal loss achieves better performance than the one trained without it. When your data is imbalanced, try focal loss and see whether it improves performance!
Citation¶
@misc{https://doi.org/10.48550/arxiv.1708.02002,
  doi = {10.48550/ARXIV.1708.02002},
  url = {https://arxiv.org/abs/1708.02002},
  author = {Lin, Tsung-Yi and Goyal, Priya and Girshick, Ross and He, Kaiming and Dollár, Piotr},
  keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences},
  title = {Focal Loss for Dense Object Detection},
  publisher = {arXiv},
  year = {2017},
  copyright = {arXiv.org perpetual, non-exclusive license}
}