Hyperparameter Optimization in AutoMM¶
Hyperparameter optimization (HPO) is a method that helps solve the challenge of tuning the hyperparameters of machine learning models. Machine learning algorithms have many complex hyperparameters that generate enormous search spaces, and the search spaces in deep learning methods are even larger than those of traditional machine learning algorithms. Tuning within a huge search space is a tough challenge, but AutoMM provides various options for you to guide the fitting process based on your domain knowledge and constraints on computing resources.
Create an Image Dataset¶
In this tutorial, we again use a subset of the Shopee-IET dataset from Kaggle for demonstration. Each image contains a clothing item, and the corresponding label specifies its clothing category. Our subset of the data contains the following possible labels: `BabyPants`, `BabyShirt`, `womencasualshoes`, `womenchiffontop`.
We can load the dataset by automatically downloading the data from the URL:
import warnings
warnings.filterwarnings('ignore')
from datetime import datetime
from autogluon.multimodal.utils.misc import shopee_dataset
download_dir = './ag_automm_tutorial_hpo'
train_data, test_data = shopee_dataset(download_dir)
train_data = train_data.sample(frac=0.5)
print(train_data)
Downloading ./ag_automm_tutorial_hpo/file.zip from https://automl-mm-bench.s3.amazonaws.com/vision_datasets/shopee.zip...
image label
406 /home/ci/autogluon/docs/tutorials/multimodal/a... 2
557 /home/ci/autogluon/docs/tutorials/multimodal/a... 2
654 /home/ci/autogluon/docs/tutorials/multimodal/a... 3
96 /home/ci/autogluon/docs/tutorials/multimodal/a... 0
413 /home/ci/autogluon/docs/tutorials/multimodal/a... 2
.. ... ...
146 /home/ci/autogluon/docs/tutorials/multimodal/a... 0
732 /home/ci/autogluon/docs/tutorials/multimodal/a... 3
502 /home/ci/autogluon/docs/tutorials/multimodal/a... 2
434 /home/ci/autogluon/docs/tutorials/multimodal/a... 2
238 /home/ci/autogluon/docs/tutorials/multimodal/a... 1
[400 rows x 2 columns]
100%|██████████| 84.0M/84.0M [00:00<00:00, 89.7MiB/s]
There are 400 data points in total in this dataset. The `image` column stores the path to the actual image, while the `label` column indicates the label class.
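To get a quick feel for the class balance after sampling, you can inspect the label counts with standard pandas (a small illustrative check; `train_data` is a plain DataFrame):

```python
# Count how many sampled images fall into each of the four label classes.
print(train_data["label"].value_counts())
```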
The Regular Model Fit¶
Recall that if we use the default settings predefined by AutoGluon, we can fit the model using `MultiModalPredictor` with three lines of code:
from autogluon.multimodal import MultiModalPredictor
predictor_regular = MultiModalPredictor(label="label")
start_time = datetime.now()
predictor_regular.fit(
train_data=train_data,
    hyperparameters={"model.timm_image.checkpoint_name": "ghostnet_100"}
)
end_time = datetime.now()
elapsed_seconds = (end_time - start_time).total_seconds()
elapsed_min = divmod(elapsed_seconds, 60)
print("Total fitting time: ", f"{int(elapsed_min[0])}m{int(elapsed_min[1])}s")
Total fitting time: 1m0s
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_212209"
=================== System Info ===================
AutoGluon Version: 1.3.1b20250508
Python Version: 3.11.9
Operating System: Linux
Platform Machine: x86_64
Platform Version: #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count: 8
Pytorch Version: 2.6.0+cu124
CUDA Version: 12.4
Memory Avail: 28.44 GB / 30.95 GB (91.9%)
Disk Space Avail: 166.20 GB / 255.99 GB (64.9%)
===================================================
AutoGluon infers your prediction problem is: 'multiclass' (because dtype of label-column == int, but few unique label-values observed).
4 unique label values: [np.int64(2), np.int64(3), np.int64(0), np.int64(1)]
If 'multiclass' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
AutoMM starts to create your model. ✨✨✨
To track the learning progress, you can open a terminal and launch Tensorboard:
```shell
# Assume you have installed tensorboard
tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209
```
Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params | Mode
------------------------------------------------------------------------------
0 | model | TimmAutoModelForImagePrediction | 3.9 M | train
1 | validation_metric | MulticlassAccuracy | 0 | train
2 | loss_func | CrossEntropyLoss | 0 | train
------------------------------------------------------------------------------
3.9 M Trainable params
0 Non-trainable params
3.9 M Total params
15.627 Total estimated model params size (MB)
418 Modules in train mode
0 Modules in eval mode
Epoch 0, global step 1: 'val_accuracy' reached 0.31250 (best 0.31250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=0-step=1.ckpt' as top 3
Epoch 0, global step 3: 'val_accuracy' reached 0.33750 (best 0.33750), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=0-step=3.ckpt' as top 3
Epoch 1, global step 4: 'val_accuracy' reached 0.35000 (best 0.35000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=1-step=4.ckpt' as top 3
Epoch 1, global step 6: 'val_accuracy' reached 0.46250 (best 0.46250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=1-step=6.ckpt' as top 3
Epoch 2, global step 7: 'val_accuracy' reached 0.52500 (best 0.52500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=2-step=7.ckpt' as top 3
Epoch 2, global step 9: 'val_accuracy' reached 0.60000 (best 0.60000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=2-step=9.ckpt' as top 3
Epoch 3, global step 10: 'val_accuracy' reached 0.62500 (best 0.62500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=3-step=10.ckpt' as top 3
Epoch 3, global step 12: 'val_accuracy' reached 0.66250 (best 0.66250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=3-step=12.ckpt' as top 3
Epoch 4, global step 13: 'val_accuracy' reached 0.71250 (best 0.71250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=4-step=13.ckpt' as top 3
Epoch 4, global step 15: 'val_accuracy' reached 0.67500 (best 0.71250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=4-step=15.ckpt' as top 3
Epoch 5, global step 16: 'val_accuracy' reached 0.68750 (best 0.71250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=5-step=16.ckpt' as top 3
Epoch 5, global step 18: 'val_accuracy' reached 0.71250 (best 0.71250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=5-step=18.ckpt' as top 3
Epoch 6, global step 19: 'val_accuracy' reached 0.72500 (best 0.72500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=6-step=19.ckpt' as top 3
Epoch 6, global step 21: 'val_accuracy' reached 0.72500 (best 0.72500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=6-step=21.ckpt' as top 3
Epoch 7, global step 22: 'val_accuracy' reached 0.75000 (best 0.75000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=7-step=22.ckpt' as top 3
Epoch 7, global step 24: 'val_accuracy' reached 0.73750 (best 0.75000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=7-step=24.ckpt' as top 3
Epoch 8, global step 25: 'val_accuracy' reached 0.75000 (best 0.75000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=8-step=25.ckpt' as top 3
Epoch 8, global step 27: 'val_accuracy' reached 0.75000 (best 0.75000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=8-step=27.ckpt' as top 3
Epoch 9, global step 28: 'val_accuracy' reached 0.76250 (best 0.76250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=9-step=28.ckpt' as top 3
Epoch 9, global step 30: 'val_accuracy' was not in top 3
Epoch 10, global step 31: 'val_accuracy' was not in top 3
Epoch 10, global step 33: 'val_accuracy' was not in top 3
Epoch 11, global step 34: 'val_accuracy' reached 0.77500 (best 0.77500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=11-step=34.ckpt' as top 3
Epoch 11, global step 36: 'val_accuracy' was not in top 3
Epoch 12, global step 37: 'val_accuracy' reached 0.77500 (best 0.77500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=12-step=37.ckpt' as top 3
Epoch 12, global step 39: 'val_accuracy' was not in top 3
Epoch 13, global step 40: 'val_accuracy' was not in top 3
Epoch 13, global step 42: 'val_accuracy' was not in top 3
Epoch 14, global step 43: 'val_accuracy' reached 0.77500 (best 0.77500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209/epoch=14-step=43.ckpt' as top 3
Epoch 14, global step 45: 'val_accuracy' was not in top 3
Epoch 15, global step 46: 'val_accuracy' was not in top 3
Epoch 15, global step 48: 'val_accuracy' was not in top 3
Epoch 16, global step 49: 'val_accuracy' was not in top 3
Start to fuse 3 checkpoints via the greedy soup algorithm.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉
To load the model, use the code below:
```python
from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212209")
```
If you are not satisfied with the model, try to increase the training time,
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
Let's check out the test accuracy of the fitted model:
scores = predictor_regular.evaluate(test_data, metrics=["accuracy"])
print('Top-1 test acc: %.3f' % scores["accuracy"])
Top-1 test acc: 0.738
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
Use HPO During Model Fit¶
If you would like more control over the fitting process, you can specify various options for hyperparameter optimization (HPO) in `MultiModalPredictor` by adding more options to the `hyperparameters` and `hyperparameter_tune_kwargs` arguments.
There are a few options available in MultiModalPredictor. We use the Ray Tune (`tune`) library in the backend, so you need to pass in either a Tune search space or an AutoGluon search space, which will be converted to a Tune search space.
Define the search space of various `hyperparameters` values for training the neural network:
from ray import tune

hyperparameters = {
    "optim.lr": tune.uniform(0.00005, 0.005),
    "optim.optim_type": tune.choice(["adamw", "sgd"]),
    "optim.max_epochs": tune.choice([10, 20]),  # epoch counts should be integers, not strings
    "model.timm_image.checkpoint_name": tune.choice(["swin_base_patch4_window7_224", "convnext_base_in22ft1k"]),
}
This is an example, not an exhaustive list. You can find the complete supported list in Customize AutoMM.
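For reference, here is a sketch of a roughly equivalent search space expressed with AutoGluon's own space API (assuming `autogluon.common.space` is available in your installed AutoGluon version); such spaces are converted to Tune search spaces internally:

```python
from autogluon.common import space

# A sketch of the same search space using AutoGluon's space API instead of Ray Tune's.
hyperparameters = {
    "optim.lr": space.Real(0.00005, 0.005),
    "optim.optim_type": space.Categorical("adamw", "sgd"),
    "optim.max_epochs": space.Categorical(10, 20),
    "model.timm_image.checkpoint_name": space.Categorical(
        "swin_base_patch4_window7_224", "convnext_base_in22ft1k"
    ),
}
```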
Define the search strategy for HPO with `hyperparameter_tune_kwargs`. For the scheduler, you can pass in a string or initialize a `ray.tune.schedulers.TrialScheduler` object (see the sketch after this list).
- a. Specify how to search through your chosen hyperparameter space (supports `random` and `bayes`):
"searcher": "bayes"
- b. Specify how to schedule jobs to train a network under a particular hyperparameter configuration (supports `FIFO` and `ASHA`):
"scheduler": "ASHA"
- c. The number of HPO trials you would like to carry out:
"num_trials": 20
- d. The number of checkpoints to keep on disk per trial; see Ray documentation for more details. Must be >= 1 (default is 3):
"num_to_keep": 3
Let's try HPO with combinations of different learning rates and backbone models:
from ray import tune
predictor_hpo = MultiModalPredictor(label="label")
hyperparameters = {
"optim.lr": tune.uniform(0.00005, 0.001),
"model.timm_image.checkpoint_name": tune.choice(["ghostnet_100",
"mobilenetv3_large_100"])
}
hyperparameter_tune_kwargs = {
"searcher": "bayes", # random
"scheduler": "ASHA",
"num_trials": 2,
"num_to_keep": 3,
}
start_time_hpo = datetime.now()
predictor_hpo.fit(
train_data=train_data,
hyperparameters=hyperparameters,
hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
)
end_time_hpo = datetime.now()
elapsed_seconds_hpo = (end_time_hpo - start_time_hpo).total_seconds()
elapsed_min_hpo = divmod(elapsed_seconds_hpo, 60)
print("Total fitting time: ", f"{int(elapsed_min_hpo[0])}m{int(elapsed_min_hpo[1])}s")
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_212310"
=================== System Info ===================
AutoGluon Version: 1.3.1b20250508
Python Version: 3.11.9
Operating System: Linux
Platform Machine: x86_64
Platform Version: #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count: 8
Pytorch Version: 2.6.0+cu124
CUDA Version: 12.4
Memory Avail: 27.51 GB / 30.95 GB (88.9%)
Disk Space Avail: 166.20 GB / 255.99 GB (64.9%)
===================================================
AutoGluon infers your prediction problem is: 'multiclass' (because dtype of label-column == int, but few unique label-values observed).
4 unique label values: [np.int64(2), np.int64(3), np.int64(0), np.int64(1)]
If 'multiclass' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
/home/ci/opt/venv/lib/python3.11/site-packages/ray/tune/impl/tuner_internal.py:144: RayDeprecationWarning: The `RunConfig` class should be imported from `ray.tune` when passing it to the Tuner. Please update your imports. See this issue for more context and migration options: https://github.com/ray-project/ray/issues/49454. Disable these warnings by setting the environment variable: RAY_TRAIN_ENABLE_V2_MIGRATION_WARNINGS=0
_log_deprecation_warning(
Removing non-optimal trials and only keep the best one.
Start to fuse 3 checkpoints via the greedy soup algorithm.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉
To load the model, use the code below:
```python
from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_212310")
```
If you are not satisfied with the model, try to increase the training time,
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
Tune Status

|              |                     |
|--------------|---------------------|
| Current time | 2025-05-08 21:24:51 |
| Running for  | 00:01:34.82         |
| Memory       | 5.2/30.9 GiB        |

System Info

Using AsyncHyperBand: num_stopped=0
Bracket: Iter 4096.000: None | Iter 1024.000: None | Iter 256.000: None | Iter 64.000: None | Iter 16.000: 0.800000011920929 | Iter 4.000: 0.6156249791383743 | Iter 1.000: 0.26249998807907104
Logical resource usage: 8.0/8 CPUs, 1.0/1 GPUs (0.0/1.0 accelerator_type:T4)
Trial Status

| Trial name | Status     | Loc            | Model name           | model.timm_image.checkpoint_name | optim.lr    | Iter | Total time (s) | val_accuracy |
|------------|------------|----------------|----------------------|----------------------------------|-------------|------|----------------|--------------|
| 13925b61   | TERMINATED | 10.0.1.51:5412 | ('timm_image', _9d40 | mobilenetv3_lar_abf0             | 6.48954e-05 | 38   | 50.5585        | 0.6625       |
| 0960fed9   | TERMINATED | 10.0.1.51:5617 | ('timm_image', _00c0 | mobilenetv3_lar_abf0             | 0.000833069 | 18   | 24.3515        | 0.85         |

Trial Progress

| Trial name | should_checkpoint | val_accuracy |
|------------|-------------------|--------------|
| 0960fed9   | True              | 0.85         |
| 13925b61   | True              | 0.6625       |
Total fitting time: 1m43s
Let's check out the test accuracy of the model fitted with HPO:
scores_hpo = predictor_hpo.evaluate(test_data, metrics=["accuracy"])
print('Top-1 test acc: %.3f' % scores_hpo["accuracy"])
Top-1 test acc: 0.875
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
From the training log, you should be able to see the current best trial, as shown below:
Current best trial: 47aef96a with val_accuracy=0.862500011920929 and parameters={'optim.lr': 0.0007195214018085505, 'model.timm_image.checkpoint_name': 'ghostnet_100'}
In our simple 2-trial HPO run, by searching over different learning rates and models, we achieved a better test accuracy than the out-of-the-box solution in the previous section. HPO helps select the hyperparameter combination with the highest validation accuracy.
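As a quick follow-up, here is a minimal sketch of using the tuned predictor for inference; `predict` is the standard `MultiModalPredictor` API, and dropping the label column simply simulates unlabeled data:

```python
# Predict labels for new, unlabeled images with the HPO-tuned predictor.
predictions = predictor_hpo.predict(test_data.drop(columns=["label"]))
print(predictions[:5])
```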
Other Examples¶
You may go to AutoMM Examples to explore other examples about AutoMM.
Customization¶
To learn how to customize AutoMM, please refer to Customize AutoMM.