图像分类的 AutoMM - 快速入门

Open In Colab Open In SageMaker Studio Lab

在本快速入门中,我们将使用图像分类任务来说明如何使用 MultiModalPredictor。一旦数据准备成 Pandas DataFrame 格式,只需调用一次 MultiModalPredictor.fit(),即可为您完成模型训练。

创建图像数据集

为了演示目的,我们使用了 Kaggle 上的 Shopee-IET 数据集 的一个子集。该数据集中的每张图像都描绘了一件服装,相应的标签指定了其服装类别。我们使用的数据子集包含以下可能的标签:BabyPantsBabyShirtwomencasualshoeswomenchiffontop

我们可以通过自动下载 url 数据来加载数据集

import warnings
warnings.filterwarnings('ignore')
import pandas as pd

from autogluon.multimodal.utils.misc import shopee_dataset
download_dir = './ag_automm_tutorial_imgcls'
train_data_path, test_data_path = shopee_dataset(download_dir)
print(train_data_path)
Downloading ./ag_automm_tutorial_imgcls/file.zip from https://automl-mm-bench.s3.amazonaws.com/vision_datasets/shopee.zip...
                                                 image  label
0    /home/ci/autogluon/docs/tutorials/multimodal/i...      0
1    /home/ci/autogluon/docs/tutorials/multimodal/i...      0
2    /home/ci/autogluon/docs/tutorials/multimodal/i...      0
3    /home/ci/autogluon/docs/tutorials/multimodal/i...      0
4    /home/ci/autogluon/docs/tutorials/multimodal/i...      0
..                                                 ...    ...
795  /home/ci/autogluon/docs/tutorials/multimodal/i...      3
796  /home/ci/autogluon/docs/tutorials/multimodal/i...      3
797  /home/ci/autogluon/docs/tutorials/multimodal/i...      3
798  /home/ci/autogluon/docs/tutorials/multimodal/i...      3
799  /home/ci/autogluon/docs/tutorials/multimodal/i...      3

[800 rows x 2 columns]
  0%|          | 0.00/84.0M [00:00<?, ?iB/s]
 10%|▉         | 8.38M/84.0M [00:00<00:02, 30.2MiB/s]
 21%|██        | 17.4M/84.0M [00:00<00:01, 50.5MiB/s]
 32%|███▏      | 26.8M/84.0M [00:00<00:00, 64.6MiB/s]
 41%|████      | 34.6M/84.0M [00:00<00:01, 47.2MiB/s]
 50%|████▉     | 41.9M/84.0M [00:00<00:00, 46.1MiB/s]
 60%|█████▉    | 50.3M/84.0M [00:00<00:00, 54.5MiB/s]
 67%|██████▋   | 56.7M/84.0M [00:01<00:00, 47.9MiB/s]
 74%|███████▍  | 62.2M/84.0M [00:01<00:00, 47.3MiB/s]
 80%|████████  | 67.4M/84.0M [00:01<00:00, 44.5MiB/s]
 90%|████████▉ | 75.5M/84.0M [00:01<00:00, 50.8MiB/s]
 98%|█████████▊| 82.1M/84.0M [00:01<00:00, 54.5MiB/s]
100%|██████████| 84.0M/84.0M [00:01<00:00, 49.1MiB/s]

我们可以看到,这个训练 DataFrame 中有 800 行和 2 列。这两列是 imagelabel,其中 image 列包含图像的绝对路径。每一行代表一个不同的训练样本。

除了图像路径,MultiModalPredictor 在训练和推理期间也支持图像字节数组。我们可以通过将选项 is_bytearray 设置为 True 来加载包含字节数组的数据集。

import warnings
warnings.filterwarnings('ignore')

download_dir = './ag_automm_tutorial_imgcls'
train_data_byte, test_data_byte = shopee_dataset(download_dir, is_bytearray=True)

使用 AutoMM 拟合模型

现在,我们使用 AutoMM 拟合一个分类器,如下所示

from autogluon.multimodal import MultiModalPredictor
import uuid
model_path = f"./tmp/{uuid.uuid4().hex}-automm_shopee"
predictor = MultiModalPredictor(label="label", path=model_path)
predictor.fit(
    train_data=train_data_path,
    time_limit=30, # seconds
)
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       28.36 GB / 30.95 GB (91.6%)
Disk Space Avail:   185.13 GB / 255.99 GB (72.3%)
===================================================
AutoGluon infers your prediction problem is: 'multiclass' (because dtype of label-column == int, but few unique label-values observed).
	4 unique label values:  [np.int64(0), np.int64(1), np.int64(2), np.int64(3)]
	If 'multiclass' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/image_prediction/tmp/d5d99ea28a0f473194328ad56a3b2732-automm_shopee
    ```
Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                            | Params | Mode 
------------------------------------------------------------------------------
0 | model             | TimmAutoModelForImagePrediction | 95.7 M | train
1 | validation_metric | MulticlassAccuracy              | 0      | train
2 | loss_func         | CrossEntropyLoss                | 0      | train
------------------------------------------------------------------------------
95.7 M    Trainable params
0         Non-trainable params
95.7 M    Total params
382.772   Total estimated model params size (MB)
863       Modules in train mode
0         Modules in eval mode
Epoch 0, global step 2: 'val_accuracy' reached 0.22500 (best 0.22500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/image_prediction/tmp/d5d99ea28a0f473194328ad56a3b2732-automm_shopee/epoch=0-step=2.ckpt' as top 3
Epoch 0, global step 5: 'val_accuracy' reached 0.82500 (best 0.82500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/image_prediction/tmp/d5d99ea28a0f473194328ad56a3b2732-automm_shopee/epoch=0-step=5.ckpt' as top 3
Time limit reached. Elapsed time is 0:00:33. Signaling Trainer to stop.
Start to fuse 2 checkpoints via the greedy soup algorithm.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/image_prediction/tmp/d5d99ea28a0f473194328ad56a3b2732-automm_shopee")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7fb7e3c4cb50>

label 是包含要预测的目标变量的列的名称,例如,在我们示例中它是“label”。path 指示应保存模型和中间输出的目录。为了演示目的,我们将训练时间限制设置为 30 秒,但您可以通过设置配置来控制训练时间。要自定义 AutoMM,请参阅 自定义 AutoMM

在测试数据集上评估

您可以在测试数据集上评估分类器,看看它的表现如何,测试的 top-1 准确率是

scores = predictor.evaluate(test_data_path, metrics=["accuracy"])
print('Top-1 test acc: %.3f' % scores["accuracy"])
Top-1 test acc: 0.800
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.

您也可以使用在包含图像路径的训练数据上训练的模型,在包含图像字节数组的测试数据上进行评估,反之亦然。

scores = predictor.evaluate(test_data_byte, metrics=["accuracy"])
print('Top-1 test acc: %.3f' % scores["accuracy"])
Top-1 test acc: 0.800
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.

对新图像进行预测

给定一个示例图像,我们首先可视化它,

image_path = test_data_path.iloc[0]['image']
from IPython.display import Image, display
pil_img = Image(filename=image_path)
display(pil_img)
../../../_images/cbf97c376c1390b8ec9915b951f92171265d7907bda59cdf411ce0068ae84ea2.jpg

我们可以轻松地使用最终模型来 predict 标签,

predictions = predictor.predict({'image': [image_path]})
print(predictions)
[0]
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.

如果需要所有类别的概率,您可以调用 predict_proba

proba = predictor.predict_proba({'image': [image_path]})
print(proba)
[[0.32413834 0.30989963 0.18693395 0.17902805]]
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.

predictor.evaluate 类似,我们也可以将图像字节数组解析到 .predict.predict_proba

image_byte = test_data_byte.iloc[0]['image']
predictions = predictor.predict({'image': [image_byte]})
print(predictions)

proba = predictor.predict_proba({'image': [image_byte]})
print(proba)
[0]
[[0.32413834 0.30989963 0.18693395 0.17902805]]
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.

提取嵌入

从模型学习到的整个图像中提取表示也非常有用。我们提供了 extract_embedding 函数,允许预测器返回 N 维图像特征,其中 N 取决于模型(通常是长度为 512 到 2048 的向量)。

feature = predictor.extract_embedding({'image': [image_path]})
print(feature[0].shape)
(768,)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.

从图像字节数组提取嵌入时,您应该会得到相同的结果

feature = predictor.extract_embedding({'image': [image_byte]})
print(feature[0].shape)
(768,)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.

保存和加载

训练好的预测器在 fit() 结束时会自动保存,您可以轻松地重新加载它。

警告

MultiModalPredictor.load() 隐式使用了 pickle 模块,该模块已知不安全。可以构造恶意 pickle 数据,在反序列化期间执行任意代码。切勿加载可能来自不受信任来源或可能已被篡改的数据。仅加载您信任的数据。

loaded_predictor = MultiModalPredictor.load(model_path)
load_proba = loaded_predictor.predict_proba({'image': [image_path]})
print(load_proba)
Load pretrained checkpoint: /home/ci/autogluon/docs/tutorials/multimodal/image_prediction/tmp/d5d99ea28a0f473194328ad56a3b2732-automm_shopee/model.ckpt
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
[[0.32413834 0.30989963 0.18693395 0.17902805]]

我们可以看到预测的类别概率仍然与上面相同,这意味着是同一个模型!

其他示例

您可以前往 AutoMM 示例 探索更多关于 AutoMM 的示例。

自定义

要了解如何自定义 AutoMM,请参阅 自定义 AutoMM