Hyperparameter Optimization in AutoMM


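Hyperparameter optimization (HPO) is a method that helps with the challenge of tuning the hyperparameters of machine learning models. ML algorithms have many complex hyperparameters that together produce a huge search space, and the search space of deep learning methods is even larger than that of traditional ML algorithms. Tuning over such a large space is hard, but AutoMM provides several options that let you guide the fitting process based on your domain knowledge and your compute-resource constraints.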

Create Image Dataset

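For demonstration, this tutorial again uses a subset of the Shopee-IET dataset from Kaggle. Each image contains one clothing item, and the corresponding label specifies its clothing category. Our subset contains the following possible labels: BabyPants, BabyShirt, womencasualshoes, womenchiffontop.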

We can load the dataset by automatically downloading it from a URL:

import warnings
warnings.filterwarnings('ignore')
from datetime import datetime

from autogluon.multimodal.utils.misc import shopee_dataset
download_dir = './ag_automm_tutorial_hpo'
train_data, test_data = shopee_dataset(download_dir)
train_data = train_data.sample(frac=0.5)
print(train_data)
Downloading ./ag_automm_tutorial_hpo/file.zip from https://automl-mm-bench.s3.amazonaws.com/vision_datasets/shopee.zip...
                                                 image  label
406  /home/ci/autogluon/docs/tutorials/multimodal/a...      2
557  /home/ci/autogluon/docs/tutorials/multimodal/a...      2
654  /home/ci/autogluon/docs/tutorials/multimodal/a...      3
96   /home/ci/autogluon/docs/tutorials/multimodal/a...      0
413  /home/ci/autogluon/docs/tutorials/multimodal/a...      2
..                                                 ...    ...
146  /home/ci/autogluon/docs/tutorials/multimodal/a...      0
732  /home/ci/autogluon/docs/tutorials/multimodal/a...      3
502  /home/ci/autogluon/docs/tutorials/multimodal/a...      2
434  /home/ci/autogluon/docs/tutorials/multimodal/a...      2
238  /home/ci/autogluon/docs/tutorials/multimodal/a...      1

[400 rows x 2 columns]

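After keeping half of the rows with sample(frac=0.5), the training set has 400 data points in total. The image column stores the path to the actual image, and the label column holds the label class.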
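Because `train_data.sample(frac=0.5)` is called without a seed, a different half of the rows is kept on each run, so the scores below can vary slightly. pandas `DataFrame.sample` accepts a `random_state` argument for reproducibility. The stdlib sketch below shows the same idea of seeded fractional sampling; `sample_fraction` and the `records` list are made up for illustration and are not the real Shopee data.

```python
import random

def sample_fraction(rows, frac, seed=None):
    """Return a random subset with round(len(rows) * frac) rows.

    The same seed always yields the same subset, which is the role
    random_state plays for pandas DataFrame.sample.
    """
    k = round(len(rows) * frac)
    rng = random.Random(seed)
    return rng.sample(rows, k)

# Hypothetical (image path, label id) records, not the real dataset.
records = [(f"images/{i}.jpg", i % 4) for i in range(800)]

subset_a = sample_fraction(records, frac=0.5, seed=0)
subset_b = sample_fraction(records, frac=0.5, seed=0)
print(len(subset_a))         # 400
print(subset_a == subset_b)  # True: same seed, same rows
```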

Regular Model Fitting

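Recall that with the default settings predefined by AutoGluon, fitting a model with MultiModalPredictor takes only a few lines of code. Here we also select the lightweight ghostnet_100 image backbone and time the run: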

from autogluon.multimodal import MultiModalPredictor
predictor_regular = MultiModalPredictor(label="label")
start_time = datetime.now()
predictor_regular.fit(
    train_data=train_data,
    hyperparameters={"model.timm_image.checkpoint_name": "ghostnet_100"},  # lightweight backbone keeps the demo fast
)
end_time = datetime.now()
elapsed_seconds = (end_time - start_time).total_seconds()
elapsed_min = divmod(elapsed_seconds, 60)
print("Total fitting time: ", f"{int(elapsed_min[0])}m{int(elapsed_min[1])}s")
Total fitting time:  1m0s
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_212209"
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       28.44 GB / 30.95 GB (91.9%)
Disk Space Avail:   166.20 GB / 255.99 GB (64.9%)
===================================================
AutoGluon infers your prediction problem is: 'multiclass' (because dtype of label-column == int, but few unique label-values observed).
	4 unique label values:  [np.int64(2), np.int64(3), np.int64(0), np.int64(1)]
	If 'multiclass' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209
    ```
Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                            | Params | Mode 
------------------------------------------------------------------------------
0 | model             | TimmAutoModelForImagePrediction | 3.9 M  | train
1 | validation_metric | MulticlassAccuracy              | 0      | train
2 | loss_func         | CrossEntropyLoss                | 0      | train
------------------------------------------------------------------------------
3.9 M     Trainable params
0         Non-trainable params
3.9 M     Total params
15.627    Total estimated model params size (MB)
418       Modules in train mode
0         Modules in eval mode
Epoch 0, global step 1: 'val_accuracy' reached 0.31250 (best 0.31250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=0-step=1.ckpt' as top 3
Epoch 0, global step 3: 'val_accuracy' reached 0.33750 (best 0.33750), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=0-step=3.ckpt' as top 3
Epoch 1, global step 4: 'val_accuracy' reached 0.35000 (best 0.35000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=1-step=4.ckpt' as top 3
Epoch 1, global step 6: 'val_accuracy' reached 0.46250 (best 0.46250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=1-step=6.ckpt' as top 3
Epoch 2, global step 7: 'val_accuracy' reached 0.52500 (best 0.52500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=2-step=7.ckpt' as top 3
Epoch 2, global step 9: 'val_accuracy' reached 0.60000 (best 0.60000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=2-step=9.ckpt' as top 3
Epoch 3, global step 10: 'val_accuracy' reached 0.62500 (best 0.62500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=3-step=10.ckpt' as top 3
Epoch 3, global step 12: 'val_accuracy' reached 0.66250 (best 0.66250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=3-step=12.ckpt' as top 3
Epoch 4, global step 13: 'val_accuracy' reached 0.71250 (best 0.71250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=4-step=13.ckpt' as top 3
Epoch 4, global step 15: 'val_accuracy' reached 0.67500 (best 0.71250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=4-step=15.ckpt' as top 3
Epoch 5, global step 16: 'val_accuracy' reached 0.68750 (best 0.71250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=5-step=16.ckpt' as top 3
Epoch 5, global step 18: 'val_accuracy' reached 0.71250 (best 0.71250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=5-step=18.ckpt' as top 3
Epoch 6, global step 19: 'val_accuracy' reached 0.72500 (best 0.72500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=6-step=19.ckpt' as top 3
Epoch 6, global step 21: 'val_accuracy' reached 0.72500 (best 0.72500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=6-step=21.ckpt' as top 3
Epoch 7, global step 22: 'val_accuracy' reached 0.75000 (best 0.75000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=7-step=22.ckpt' as top 3
Epoch 7, global step 24: 'val_accuracy' reached 0.73750 (best 0.75000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=7-step=24.ckpt' as top 3
Epoch 8, global step 25: 'val_accuracy' reached 0.75000 (best 0.75000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=8-step=25.ckpt' as top 3
Epoch 8, global step 27: 'val_accuracy' reached 0.75000 (best 0.75000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=8-step=27.ckpt' as top 3
Epoch 9, global step 28: 'val_accuracy' reached 0.76250 (best 0.76250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=9-step=28.ckpt' as top 3
Epoch 9, global step 30: 'val_accuracy' was not in top 3
Epoch 10, global step 31: 'val_accuracy' was not in top 3
Epoch 10, global step 33: 'val_accuracy' was not in top 3
Epoch 11, global step 34: 'val_accuracy' reached 0.77500 (best 0.77500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=11-step=34.ckpt' as top 3
Epoch 11, global step 36: 'val_accuracy' was not in top 3
Epoch 12, global step 37: 'val_accuracy' reached 0.77500 (best 0.77500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=12-step=37.ckpt' as top 3
Epoch 12, global step 39: 'val_accuracy' was not in top 3
Epoch 13, global step 40: 'val_accuracy' was not in top 3
Epoch 13, global step 42: 'val_accuracy' was not in top 3
Epoch 14, global step 43: 'val_accuracy' reached 0.77500 (best 0.77500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=14-step=43.ckpt' as top 3
Epoch 14, global step 45: 'val_accuracy' was not in top 3
Epoch 15, global step 46: 'val_accuracy' was not in top 3
Epoch 15, global step 48: 'val_accuracy' was not in top 3
Epoch 16, global step 49: 'val_accuracy' was not in top 3
Start to fuse 3 checkpoints via the greedy soup algorithm.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
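The timing code above (and the same pattern in the HPO run later) converts elapsed seconds to minutes and seconds with `divmod`. A small helper keeps that formatting in one place; the name `format_elapsed` is our own and not part of AutoGluon. Note that `int()` truncates, so 43.9 leftover seconds print as 43s.

```python
def format_elapsed(seconds):
    """Format a duration in seconds as '<minutes>m<seconds>s', truncating fractions."""
    minutes, secs = divmod(seconds, 60)
    return f"{int(minutes)}m{int(secs)}s"

print(format_elapsed(60.4))   # 1m0s
print(format_elapsed(103.9))  # 1m43s
```

In the tutorial's cells it would be called as `format_elapsed((end_time - start_time).total_seconds())`.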

Let's check the test accuracy of the fitted model:

scores = predictor_regular.evaluate(test_data, metrics=["accuracy"])
print('Top-1 test acc: %.3f' % scores["accuracy"])
Top-1 test acc: 0.738
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.

Use HPO During Model Fitting

If you want more control over the fitting process, you can specify various hyperparameter optimization (HPO) options for MultiModalPredictor by adding them to hyperparameters and hyperparameter_tune_kwargs.

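MultiModalPredictor offers several options here. It uses the Ray Tune (tune) library as its HPO backend, so you need to pass in either a Tune search space or an AutoGluon search space, which will be converted into a Tune search space.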
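Conceptually, a search space maps each hyperparameter name to a distribution: `tune.uniform(low, high)` draws a float from a range, and `tune.choice([...])` picks one of the listed values. The stdlib sketch below mimics that behavior with plain functions so you can see what one sampled configuration looks like. `uniform`, `choice`, and `sample_config` are our own stand-ins, not Ray Tune objects, and Ray's real searchers (especially `bayes`) choose points more cleverly than this purely random draw.

```python
import random

# Stand-ins for tune.uniform / tune.choice: each returns a sampler function.
def uniform(low, high):
    return lambda rng: rng.uniform(low, high)

def choice(options):
    return lambda rng: rng.choice(options)

search_space = {
    "optim.lr": uniform(0.00005, 0.001),
    "model.timm_image.checkpoint_name": choice(["ghostnet_100", "mobilenetv3_large_100"]),
}

def sample_config(space, rng):
    """Draw one concrete configuration (one trial's hyperparameters)."""
    return {name: sampler(rng) for name, sampler in space.items()}

rng = random.Random(42)
config = sample_config(search_space, rng)
print(config)
```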

  1. Define the search space of various hyperparameter values for neural network training:

    from ray import tune

    hyperparameters = {
        "optim.lr": tune.uniform(0.00005, 0.005),
        "optim.optim_type": tune.choice(["adamw", "sgd"]),
        "optim.max_epochs": tune.choice([10, 20]),  # integers, not strings
        "model.timm_image.checkpoint_name": tune.choice(["swin_base_patch4_window7_224", "convnext_base_in22ft1k"]),
    }
    

    This is only an example, not an exhaustive list. You can find the full list of supported options in Customize AutoMM.

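  2. Define the HPO search strategy with hyperparameter_tune_kwargs. You can pass in a string or initialize a ray.tune.schedulers.TrialScheduler object.

    a. Specify how to search through your chosen hyperparameter space (supports `random` and `bayes`):
    "searcher": "bayes"
    
    b. Specify how to schedule jobs that train a network under a particular hyperparameter configuration (supports `FIFO` and `ASHA`):
    "scheduler": "ASHA"
    
    c. The number of HPO trials you want to run:
    "num_trials": 20
    
    d. The number of checkpoints to keep on disk per trial; see the Ray documentation for details. Must be >= 1 (default: 3):
    "num_to_keep": 3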
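The first search space in the list above is already large: its three categorical entries alone give 2 × 2 × 2 = 8 combinations, and each of them is paired with a continuous learning-rate range. The sketch below enumerates only the discrete part with `itertools.product`; `optim.lr` is left out because it is continuous, and max_epochs is written as integers.

```python
import itertools

# Categorical entries from the first example search space (optim.lr omitted).
categorical = {
    "optim.optim_type": ["adamw", "sgd"],
    "optim.max_epochs": [10, 20],
    "model.timm_image.checkpoint_name": ["swin_base_patch4_window7_224", "convnext_base_in22ft1k"],
}

# Every combination of one value per hyperparameter.
combos = [dict(zip(categorical, values)) for values in itertools.product(*categorical.values())]
print(len(combos))  # 8
print(combos[0])
```

With `num_trials` set to 2, as in the run below, only a small fraction of such a space is ever tried, which is why the choice of searcher matters.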
    

Let's try HPO over combinations of different learning rates and backbone models:

from ray import tune

predictor_hpo = MultiModalPredictor(label="label")

hyperparameters = {
    "optim.lr": tune.uniform(0.00005, 0.001),
    "model.timm_image.checkpoint_name": tune.choice(["ghostnet_100", "mobilenetv3_large_100"]),
}
hyperparameter_tune_kwargs = {
    "searcher": "bayes",  # or "random"
    "scheduler": "ASHA",
    "num_trials": 2,
    "num_to_keep": 3,
}
start_time_hpo = datetime.now()
predictor_hpo.fit(
    train_data=train_data,
    hyperparameters=hyperparameters,
    hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
)
end_time_hpo = datetime.now()
elapsed_seconds_hpo = (end_time_hpo - start_time_hpo).total_seconds()
elapsed_min_hpo = divmod(elapsed_seconds_hpo, 60)
print("Total fitting time: ", f"{int(elapsed_min_hpo[0])}m{int(elapsed_min_hpo[1])}s")
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_212310"
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       27.51 GB / 30.95 GB (88.9%)
Disk Space Avail:   166.20 GB / 255.99 GB (64.9%)
===================================================
AutoGluon infers your prediction problem is: 'multiclass' (because dtype of label-column == int, but few unique label-values observed).
	4 unique label values:  [np.int64(2), np.int64(3), np.int64(0), np.int64(1)]
	If 'multiclass' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
/home/ci/opt/venv/lib/python3.11/site-packages/ray/tune/impl/tuner_internal.py:144: RayDeprecationWarning: The `RunConfig` class should be imported from `ray.tune` when passing it to the Tuner. Please update your imports. See this issue for more context and migration options: https://github.com/ray-project/ray/issues/49454. Disable these warnings by setting the environment variable: RAY_TRAIN_ENABLE_V2_MIGRATION_WARNINGS=0
  _log_deprecation_warning(
Removing non-optimal trials and only keep the best one.
Start to fuse 3 checkpoints via the greedy soup algorithm.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212310")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).

Tune Status

Current time: 2025-05-08 21:24:51
Running for: 00:01:34.82
Memory: 5.2/30.9 GiB

System Info

Using AsyncHyperBand: num_stopped=0
Bracket: Iter 4096.000: None | Iter 1024.000: None | Iter 256.000: None | Iter 64.000: None | Iter 16.000: 0.800000011920929 | Iter 4.000: 0.6156249791383743 | Iter 1.000: 0.26249998807907104
Logical resource usage: 8.0/8 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:T4)

Trial Status

Trial name  Status      Loc             model_names           model.timm_image.checkpoint_name  optim.lr     Iter  Total time (s)  val_accuracy
13925b61    TERMINATED  10.0.1.51:5412  ('timm_image', _9d40  mobilenetv3_lar_abf0              6.48954e-05  38    50.5585         0.6625
0960fed9    TERMINATED  10.0.1.51:5617  ('timm_image', _00c0  mobilenetv3_lar_abf0              0.000833069  18    24.3515         0.85

Trial Progress

Trial name  should_checkpoint  val_accuracy
0960fed9    True               0.85
13925b61    True               0.6625
Total fitting time:  1m43s
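The `Bracket` line in the status output comes from Ray Tune's AsyncHyperBand (ASHA) scheduler: each `Iter` value is a rung milestone, and the number next to it is the cutoff score a trial must reach at that rung to keep training (`None` means no trial has reached that rung yet). The milestones 1, 4, 16, 64, ... grow by a factor of 4. Below is a simplified sketch of the asynchronous promotion rule for a metric to maximize. The class `SimpleAsha` is our own; grace period 1 and reduction factor 4 are assumptions chosen to reproduce the printed milestones, and Ray's real scheduler computes an interpolated percentile cutoff and handles trials that skip past a milestone, which this sketch does not.

```python
import math

class SimpleAsha:
    """Simplified ASHA stopping rule for a metric to maximize.

    Milestones are grace_period * reduction_factor**k, up to max_t. When a
    trial reports at a milestone, it keeps training only if its score is at
    least the k-th best score earlier trials recorded there, where
    k = ceil(n / reduction_factor) for n earlier scores.
    """

    def __init__(self, max_t=4096, grace_period=1, reduction_factor=4):
        self.rf = reduction_factor
        self.milestones = []
        m = grace_period
        while m <= max_t:
            self.milestones.append(m)
            m *= reduction_factor
        self.recorded = {m: [] for m in self.milestones}

    def cutoff(self, milestone):
        """Score needed to stay in the top 1/rf at this rung (None if empty)."""
        scores = sorted(self.recorded[milestone], reverse=True)
        if not scores:
            return None
        keep = max(1, math.ceil(len(scores) / self.rf))
        return scores[keep - 1]

    def on_result(self, iteration, score):
        """Return 'CONTINUE' or 'STOP' for a trial reporting at `iteration`."""
        if iteration not in self.recorded:
            return "CONTINUE"  # not a rung milestone, nothing to compare
        cut = self.cutoff(iteration)  # based on earlier trials only
        self.recorded[iteration].append(score)
        if cut is not None and score < cut:
            return "STOP"
        return "CONTINUE"


asha = SimpleAsha()
print(asha.milestones)  # [1, 4, 16, 64, 256, 1024, 4096]

# Three hypothetical trials reach iteration 4 one after another.
decisions = [asha.on_result(4, s) for s in (0.6, 0.5, 0.7)]
print(decisions)  # ['CONTINUE', 'STOP', 'CONTINUE']
```

Because each decision only looks at trials that already reported, early trials are rarely stopped; this asynchronous design lets ASHA run trials in parallel without waiting for a whole rung to finish.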

Let's check the test accuracy of the model fitted with HPO:

scores_hpo = predictor_hpo.evaluate(test_data, metrics=["accuracy"])
print('Top-1 test acc: %.3f' % scores_hpo["accuracy"])
Top-1 test acc: 0.875
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.

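From the training log, you should be able to see the current best trial, as shown below. This example line is from a different run than the one above, so its trial ID, score, and parameters do not match the trial table: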

Current best trial: 47aef96a with val_accuracy=0.862500011920929 and parameters={'optim.lr': 0.0007195214018085505, 'model.timm_image.checkpoint_name': 'ghostnet_100'}

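With a simple 2-trial HPO run that searched over learning rates and backbone models, we obtained a better test accuracy (0.875) than the out-of-the-box solution from the previous section (0.738), at the cost of a longer fit (1m43s vs. 1m0s). HPO helps select the hyperparameter combination with the highest validation accuracy.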
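Selecting the best trial amounts to taking the maximum validation score across trials. The sketch below applies that rule to the two trials in the Ray Tune trial table printed above (values copied from that run, learning rates as truncated by Ray) and computes the test-accuracy gain over the regular fit. It illustrates the selection criterion only and is not AutoGluon's internal code.

```python
# Trial results copied from the Ray Tune trial table above.
trials = [
    {"trial": "0960fed9", "optim.lr": 0.000833069, "val_accuracy": 0.85},
    {"trial": "13925b61", "optim.lr": 6.48954e-05, "val_accuracy": 0.6625},
]

# The best trial is the one with the highest validation accuracy.
best = max(trials, key=lambda t: t["val_accuracy"])
print(best["trial"], best["val_accuracy"])  # 0960fed9 0.85

# Test accuracies reported earlier in this tutorial.
regular_acc, hpo_acc = 0.738, 0.875
print(f"absolute gain: {hpo_acc - regular_acc:.3f}")  # absolute gain: 0.137
```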

Other Examples

You may go to AutoMM Examples to explore other examples about AutoMM.

Customization

To learn how to customize AutoMM, please refer to Customize AutoMM.