AutoMM for Image Classification - Quick Start
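In this quick start, we use an image classification task to illustrate how to use MultiModalPredictor. Once the data is prepared as a Pandas DataFrame, a single call to MultiModalPredictor.fit() takes care of model training for you.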
Create Image Dataset
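For demonstration purposes, we use a subset of the Shopee-IET dataset from Kaggle. Each image in this dataset depicts a clothing item, and the corresponding label specifies its clothing category. Our subset of the data contains the following possible labels: BabyPants, BabyShirt, womencasualshoes, and womenchiffontop.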
We can load the dataset by automatically downloading the data from its URL:
import warnings
warnings.filterwarnings('ignore')
import pandas as pd
from autogluon.multimodal.utils.misc import shopee_dataset
download_dir = './ag_automm_tutorial_imgcls'
train_data_path, test_data_path = shopee_dataset(download_dir)
print(train_data_path)
Downloading ./ag_automm_tutorial_imgcls/file.zip from https://automl-mm-bench.s3.amazonaws.com/vision_datasets/shopee.zip...
image label
0 /home/ci/autogluon/docs/tutorials/multimodal/i... 0
1 /home/ci/autogluon/docs/tutorials/multimodal/i... 0
2 /home/ci/autogluon/docs/tutorials/multimodal/i... 0
3 /home/ci/autogluon/docs/tutorials/multimodal/i... 0
4 /home/ci/autogluon/docs/tutorials/multimodal/i... 0
.. ... ...
795 /home/ci/autogluon/docs/tutorials/multimodal/i... 3
796 /home/ci/autogluon/docs/tutorials/multimodal/i... 3
797 /home/ci/autogluon/docs/tutorials/multimodal/i... 3
798 /home/ci/autogluon/docs/tutorials/multimodal/i... 3
799 /home/ci/autogluon/docs/tutorials/multimodal/i... 3
[800 rows x 2 columns]
100%|██████████| 84.0M/84.0M [00:01<00:00, 49.1MiB/s]
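We can see that there are 800 rows and 2 columns in this training DataFrame. The two columns are image and label, where the image column contains the absolute paths of the images. Each row represents a different training sample.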
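To train on your own images, you need a table with the same shape: an image column of absolute paths and a label column. The sketch below builds such rows from a hypothetical folder layout, root/<class_name>/<image files>. The helper name build_image_table, the extension filter, and the alphabetical class-to-integer mapping are all assumptions made for illustration. They are not necessarily how shopee_dataset assigns its labels.

```python
from pathlib import Path

def build_image_table(root, extensions=(".jpg", ".jpeg", ".png")):
    """Collect {"image": absolute path, "label": int} rows from root/<class_name>/ folders.

    Hypothetical helper: the folder layout and the sorted-name -> integer mapping
    are assumptions, not the mapping used by shopee_dataset.
    """
    root = Path(root)
    # One class per sub-directory; sorting makes the integer ids deterministic.
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    label_of = {name: idx for idx, name in enumerate(class_names)}
    rows = []
    for name in class_names:
        for img in sorted((root / name).iterdir()):
            # Keep only files with an accepted image extension (case-insensitive).
            if img.is_file() and img.suffix.lower() in extensions:
                rows.append({"image": str(img.resolve()), "label": label_of[name]})
    return rows, class_names
```

Wrapping the result with pd.DataFrame(rows) gives a DataFrame with image and label columns in the format used above. Keep class_names so that you can translate integer predictions back into class names.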
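In addition to image paths, MultiModalPredictor also supports image byte arrays during training and inference. We can load the dataset with byte arrays by setting the option is_bytearray to True: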
import warnings
warnings.filterwarnings('ignore')
download_dir = './ag_automm_tutorial_imgcls'
train_data_byte, test_data_byte = shopee_dataset(download_dir, is_bytearray=True)
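If your table already holds image paths, you can produce the byte-array form yourself. The minimal sketch below assumes that "byte array" means the raw, undecoded file contents (for example, the JPEG or PNG bytes). This mirrors the idea behind is_bytearray=True, but it is not the code that shopee_dataset uses. The helper name paths_to_bytes is hypothetical.

```python
from pathlib import Path

def paths_to_bytes(paths):
    """Return the raw encoded bytes of each image file, in the same order as `paths`.

    Hypothetical helper; assumes AutoMM accepts the undecoded file contents as bytes.
    """
    return [Path(p).read_bytes() for p in paths]
```

For example, df = train_data_path.copy() followed by df["image"] = paths_to_bytes(df["image"]) would replace the path column with byte strings, leaving the original DataFrame untouched.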
Fit a Model with AutoMM
Now, we fit a classifier with AutoMM as follows:
from autogluon.multimodal import MultiModalPredictor
import uuid
model_path = f"./tmp/{uuid.uuid4().hex}-automm_shopee"
predictor = MultiModalPredictor(label="label", path=model_path)
predictor.fit(
train_data=train_data_path,
time_limit=30, # seconds
)
=================== System Info ===================
AutoGluon Version: 1.3.1b20250508
Python Version: 3.11.9
Operating System: Linux
Platform Machine: x86_64
Platform Version: #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count: 8
Pytorch Version: 2.6.0+cu124
CUDA Version: 12.4
Memory Avail: 28.36 GB / 30.95 GB (91.6%)
Disk Space Avail: 185.13 GB / 255.99 GB (72.3%)
===================================================
AutoGluon infers your prediction problem is: 'multiclass' (because dtype of label-column == int, but few unique label-values observed).
4 unique label values: [np.int64(0), np.int64(1), np.int64(2), np.int64(3)]
If 'multiclass' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
AutoMM starts to create your model. ✨✨✨
To track the learning progress, you can open a terminal and launch Tensorboard:
```shell
# Assume you have installed tensorboard
tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/image_prediction/tmp/d5d99ea28a0f473194328ad56a3b2732-automm_shopee
```
Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params | Mode
------------------------------------------------------------------------------
0 | model | TimmAutoModelForImagePrediction | 95.7 M | train
1 | validation_metric | MulticlassAccuracy | 0 | train
2 | loss_func | CrossEntropyLoss | 0 | train
------------------------------------------------------------------------------
95.7 M Trainable params
0 Non-trainable params
95.7 M Total params
382.772 Total estimated model params size (MB)
863 Modules in train mode
0 Modules in eval mode
Epoch 0, global step 2: 'val_accuracy' reached 0.22500 (best 0.22500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/image_prediction/tmp/d5d99ea28a0f473194328ad56a3b2732-automm_shopee/epoch=0-step=2.ckpt' as top 3
Epoch 0, global step 5: 'val_accuracy' reached 0.82500 (best 0.82500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/image_prediction/tmp/d5d99ea28a0f473194328ad56a3b2732-automm_shopee/epoch=0-step=5.ckpt' as top 3
Time limit reached. Elapsed time is 0:00:33. Signaling Trainer to stop.
Start to fuse 2 checkpoints via the greedy soup algorithm.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉
To load the model, use the code below:
```python
from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/image_prediction/tmp/d5d99ea28a0f473194328ad56a3b2732-automm_shopee")
```
If you are not satisfied with the model, try to increase the training time,
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7fb7e3c4cb50>
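label is the name of the column that contains the target variable to predict; in our example, it is "label". path indicates the directory where models and intermediate outputs should be saved. For demonstration purposes, we set the training time limit to 30 seconds, but you can control the training time by setting configurations. To customize AutoMM, please refer to Customize AutoMM.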
Evaluate on Test Dataset
You can evaluate the classifier on the test dataset to see how it performs. The test top-1 accuracy is:
scores = predictor.evaluate(test_data_path, metrics=["accuracy"])
print('Top-1 test acc: %.3f' % scores["accuracy"])
Top-1 test acc: 0.800
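Top-1 accuracy is the fraction of test samples whose most probable class equals the true label. The plain-Python sketch below shows what the metric computes. It is not AutoGluon's implementation, and the function name top1_accuracy is hypothetical.

```python
def top1_accuracy(probabilities, labels):
    """Fraction of samples whose highest-probability class equals the true label."""
    if len(probabilities) != len(labels) or not labels:
        raise ValueError("need the same, non-zero number of probability rows and labels")
    correct = 0
    for row, label in zip(probabilities, labels):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax; first index wins ties
        correct += int(predicted == label)
    return correct / len(labels)

# Three samples: the first two are classified correctly, the third is not.
print(top1_accuracy([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], [1, 0, 0]))
```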
You can also take a model trained on data with image paths and evaluate it on test data with image byte arrays, and vice versa:
scores = predictor.evaluate(test_data_byte, metrics=["accuracy"])
print('Top-1 test acc: %.3f' % scores["accuracy"])
Top-1 test acc: 0.800
Predict on a New Image
Given an example image, let's visualize it first:
image_path = test_data_path.iloc[0]['image']
from IPython.display import Image, display
pil_img = Image(filename=image_path)
display(pil_img)

We can easily use the final model to predict the label:
predictions = predictor.predict({'image': [image_path]})
print(predictions)
[0]
If you need the probabilities of all classes, you can call predict_proba:
proba = predictor.predict_proba({'image': [image_path]})
print(proba)
[[0.32413834 0.30989963 0.18693395 0.17902805]]
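The two outputs above are consistent. For a multiclass problem, we expect predict to return the class with the highest probability in the corresponding predict_proba row. The snippet below checks that relationship using the numbers printed in this tutorial. It does not call AutoGluon; it only re-derives the label from the probabilities.

```python
# Probabilities printed by predict_proba above (one row per image, one column per class).
proba = [[0.32413834, 0.30989963, 0.18693395, 0.17902805]]

# Most probable class per row; expected to agree with predictor.predict's output of [0].
predicted = [max(range(len(row)), key=row.__getitem__) for row in proba]
print(predicted)

# The four class probabilities sum to 1 up to float rounding.
print(round(sum(proba[0]), 6))
```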
Similar to predictor.evaluate, we can also pass image byte arrays to .predict and .predict_proba:
image_byte = test_data_byte.iloc[0]['image']
predictions = predictor.predict({'image': [image_byte]})
print(predictions)
proba = predictor.predict_proba({'image': [image_byte]})
print(proba)
[0]
[[0.32413834 0.30989963 0.18693395 0.17902805]]
Extract Embeddings
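Extracting a representation of the whole image learned by the model is also very useful. We provide the extract_embedding function, which lets the predictor return an N-dimensional image feature, where N depends on the model (usually a vector of length 512 to 2048; the model here produces 768).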
feature = predictor.extract_embedding({'image': [image_path]})
print(feature[0].shape)
(768,)
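One common use of such embeddings, which this tutorial does not itself demonstrate, is similarity search: you rank a gallery of images by the cosine similarity of their embeddings to a query embedding. The sketch below works on plain lists or 1-D arrays such as feature[0]. The helper names cosine_similarity and most_similar are hypothetical, and cosine similarity is our choice of metric, not something AutoMM prescribes.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError("cannot compare a zero vector")
    return dot / (norm_a * norm_b)

def most_similar(query, gallery, k=1):
    """Indices of the k gallery embeddings most similar to `query`, best first."""
    scores = [cosine_similarity(query, g) for g in gallery]
    return sorted(range(len(gallery)), key=lambda i: scores[i], reverse=True)[:k]

# Toy 2-D "embeddings": index 1 points almost the same way as the query.
print(most_similar([1.0, 0.0], [[0.0, 1.0], [1.0, 0.1], [-1.0, 0.0]], k=2))
```

For real 768-dimensional embeddings over many images, a vectorized library would be much faster; the loops here only keep the sketch dependency-free.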
You should get the same result when extracting embeddings from the image byte array:
feature = predictor.extract_embedding({'image': [image_byte]})
print(feature[0].shape)
(768,)
Save and Load
The trained predictor is saved automatically at the end of fit(), and you can easily reload it.
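Warning
MultiModalPredictor.load() implicitly uses the pickle module, which is known to be insecure. It is possible to construct malicious pickle data that will execute arbitrary code during unpickling. Never load data that could have come from an untrusted source or that could have been tampered with. Only load data you trust.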
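One possible precaution, which is a hypothetical practice of ours and not an AutoGluon feature, is to record SHA-256 digests of the saved predictor directory right after fit(). Store that manifest somewhere you trust, and refuse to call MultiModalPredictor.load() if the digests no longer match. This detects later tampering only; it cannot make an untrusted source safe. The helper names hash_directory and matches_manifest are illustrative.

```python
import hashlib
from pathlib import Path

def hash_directory(path):
    """Map each file's path (relative to `path`) to its SHA-256 hex digest."""
    root = Path(path)
    digests = {}
    for file in sorted(root.rglob("*")):
        if file.is_file():
            h = hashlib.sha256()
            with file.open("rb") as f:
                # Stream in 1 MiB chunks so large checkpoints are not read into memory at once.
                for chunk in iter(lambda: f.read(1 << 20), b""):
                    h.update(chunk)
            digests[file.relative_to(root).as_posix()] = h.hexdigest()
    return digests

def matches_manifest(path, trusted_manifest):
    """True only if the directory holds exactly the recorded files with the recorded hashes."""
    return hash_directory(path) == trusted_manifest
```

If anything writes to the directory after you record the manifest, for example saving the predictor again, regenerate the manifest; otherwise a legitimate change will also fail the check.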
loaded_predictor = MultiModalPredictor.load(model_path)
load_proba = loaded_predictor.predict_proba({'image': [image_path]})
print(load_proba)
Load pretrained checkpoint: /home/ci/autogluon/docs/tutorials/multimodal/image_prediction/tmp/d5d99ea28a0f473194328ad56a3b2732-automm_shopee/model.ckpt
[[0.32413834 0.30989963 0.18693395 0.17902805]]
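We can see that the predicted class probabilities are the same as before, which indicates that the reloaded predictor is the same model!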
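To make this comparison programmatic rather than visual, you can check that the two probability matrices agree within a small tolerance. Exact equality can be too strict when inference runs with mixed precision on a GPU. The helper name same_predictions and the default abs_tol=1e-6 are assumptions chosen for illustration. The values are the ones printed in this tutorial before saving and after reloading.

```python
import math

def same_predictions(before, after, abs_tol=1e-6):
    """True if two probability matrices have the same shape and agree element-wise within abs_tol."""
    if len(before) != len(after):
        return False
    for row_a, row_b in zip(before, after):
        if len(row_a) != len(row_b):
            return False
        if not all(math.isclose(x, y, rel_tol=0.0, abs_tol=abs_tol) for x, y in zip(row_a, row_b)):
            return False
    return True

before = [[0.32413834, 0.30989963, 0.18693395, 0.17902805]]  # printed by predictor.predict_proba
after = [[0.32413834, 0.30989963, 0.18693395, 0.17902805]]   # printed by loaded_predictor.predict_proba
print(same_predictions(before, after))
```

Agreement on a single image is good evidence of a correct reload, not proof that every weight is identical; comparing on a larger batch strengthens the check.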
Other Examples
You may go to AutoMM Examples to explore other examples of AutoMM.
Customization
To learn how to customize AutoMM, please refer to Customize AutoMM.