AutoMM for Image + Text + Tabular - Quick Start


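AutoMM is a deep learning "model zoo" of model zoos. It can automatically build deep learning models that are suitable for multimodal datasets. You only need to convert your data into the multimodal dataframe format, and AutoMM can predict the values of one column conditioned on the features from the other columns, including images, text, and tabular data.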

import os
import numpy as np
import warnings
warnings.filterwarnings('ignore')
np.random.seed(123)

Dataset

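For demonstration, we use a simplified and subsampled version of the PetFinder dataset. The task is to predict the animals' adoption rates based on their adoption profile information. In this simplified version, the adoption speed is grouped into two categories: 0 (slow) and 1 (fast).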

First, let's download and prepare the dataset.

download_dir = './ag_automm_tutorial'
zip_file = 'https://automl-mm-bench.s3.amazonaws.com/petfinder_for_tutorial.zip'
from autogluon.core.utils.loaders import load_zip
load_zip.unzip(zip_file, unzip_dir=download_dir)
Downloading ./ag_automm_tutorial/file.zip from https://automl-mm-bench.s3.amazonaws.com/petfinder_for_tutorial.zip...

Next, we will load the CSV files.

import pandas as pd
dataset_path = download_dir + '/petfinder_for_tutorial'
train_data = pd.read_csv(f'{dataset_path}/train.csv', index_col=0)
test_data = pd.read_csv(f'{dataset_path}/test.csv', index_col=0)
label_col = 'AdoptionSpeed'

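Each entry in the Images column can hold several ';'-separated image paths. For a quick tutorial we keep only the first image, then expand the relative path into an absolute one so that the image can be loaded during training.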

image_col = 'Images'
train_data[image_col] = train_data[image_col].apply(lambda ele: ele.split(';')[0]) # Use the first image for a quick tutorial
test_data[image_col] = test_data[image_col].apply(lambda ele: ele.split(';')[0])


def path_expander(path, base_folder):
    # Split a ';'-separated list of relative image paths and make each one absolute
    path_l = path.split(';')
    return ';'.join([os.path.abspath(os.path.join(base_folder, p)) for p in path_l])

train_data[image_col] = train_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
test_data[image_col] = test_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))

train_data[image_col].iloc[0]
'/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/ag_automm_tutorial/petfinder_for_tutorial/images/7d7a39d71-1.jpg'
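Because `path_expander` already handles several ';'-separated paths, skipping the first-image step would keep every photo path in the column rather than just one. The stdlib sketch below contrasts the two on a hypothetical raw value (three photos of one pet) and a hypothetical base folder; both names are illustrative, not files from the dataset.

```python
import os


def path_expander(path, base_folder):
    # Same logic as the tutorial helper: expand every ';'-separated relative path
    return ';'.join(os.path.abspath(os.path.join(base_folder, p)) for p in path.split(';'))


# Hypothetical raw Images value and dataset folder, for illustration only
raw = "images/7d7a39d71-1.jpg;images/7d7a39d71-2.jpg;images/7d7a39d71-3.jpg"
base = "/data/petfinder_for_tutorial"

first_only = path_expander(raw.split(';')[0], base)  # what the tutorial keeps
all_images = path_expander(raw, base)                # every photo path kept

print(first_only)
print(all_images.split(';'))
```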

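Each animal's adoption profile includes pictures, a text description, and various tabular features such as age, breed, name, and color. Let's look at an example row of data and display its text description and a picture.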

example_row = train_data.iloc[0]

example_row
Type                                                             2
Name                                                 Yumi Hamasaki
Age                                                              4
Breed1                                                         292
Breed2                                                         265
Gender                                                           2
Color1                                                           1
Color2                                                           5
Color3                                                           7
MaturitySize                                                     2
FurLength                                                        2
Vaccinated                                                       1
Dewormed                                                         3
Sterilized                                                       2
Health                                                           1
Quantity                                                         1
Fee                                                              0
State                                                        41326
RescuerID                         bcc4e1b9557a8b3aaf545ea8e6e86991
VideoAmt                                                         0
Description      I rescued Yumi Hamasaki at a food stall far aw...
PetID                                                    7d7a39d71
PhotoAmt                                                       3.0
AdoptionSpeed                                                    0
Images           /home/ci/autogluon/docs/tutorials/multimodal/m...
Name: 0, dtype: object
example_row['Description']
"I rescued Yumi Hamasaki at a food stall far away in Kelantan. At that time i was on my way back to KL, she was suffer from stomach problem and looking very2 sick.. I send her to vet & get the treatment + vaccinated and right now she's very2 healthy.. About yumi : - love to sleep with ppl - she will keep on meowing if she's hugry - very2 active, always seeking for people to accompany her playing - well trained (poo+pee in her own potty) - easy to bathing - I only feed her with these brands : IAMS, Kittenbites, Pro-formance Reason why i need someone to adopt Yumi: I just married and need to move to a new house where no pets are allowed :( As Yumi is very2 special to me, i will only give her to ppl that i think could take care of her just like i did (especially on her foods things).."
example_image = example_row[image_col]

from IPython.display import Image, display
pil_img = Image(filename=example_image)
display(pil_img)
[Image output: the photo of the example pet from the row above]

Training

Now let's fit the predictor with the training data. Here we set a tight time budget (time_limit=120 seconds) for a quick demo.

from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor(label=label_col)
predictor.fit(
    train_data=train_data,
    time_limit=120, # seconds
)
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_205717"
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       28.40 GB / 30.95 GB (91.8%)
Disk Space Avail:   185.23 GB / 255.99 GB (72.4%)
===================================================
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
	2 unique label values:  [np.int64(0), np.int64(1)]
	If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/AutogluonModels/ag-20250508_205717
    ```
INFO: Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name              | Type                | Params | Mode 
------------------------------------------------------------------
0 | model             | MultimodalFusionMLP | 207 M  | train
1 | validation_metric | BinaryAUROC         | 0      | train
2 | loss_func         | CrossEntropyLoss    | 0      | train
------------------------------------------------------------------
207 M     Trainable params
0         Non-trainable params
207 M     Total params
828.307   Total estimated model params size (MB)
1171      Modules in train mode
0         Modules in eval mode
INFO: Epoch 0, global step 1: 'val_roc_auc' reached 0.56250 (best 0.56250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/AutogluonModels/ag-20250508_205717/epoch=0-step=1.ckpt' as top 3
INFO: Epoch 0, global step 4: 'val_roc_auc' reached 0.76083 (best 0.76083), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/AutogluonModels/ag-20250508_205717/epoch=0-step=4.ckpt' as top 3
INFO: Epoch 1, global step 5: 'val_roc_auc' reached 0.78056 (best 0.78056), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/AutogluonModels/ag-20250508_205717/epoch=1-step=5.ckpt' as top 3
INFO: Epoch 1, global step 8: 'val_roc_auc' reached 0.79000 (best 0.79000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/AutogluonModels/ag-20250508_205717/epoch=1-step=8.ckpt' as top 3
INFO: Epoch 2, global step 9: 'val_roc_auc' reached 0.78972 (best 0.79000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/AutogluonModels/ag-20250508_205717/epoch=2-step=9.ckpt' as top 3
INFO: Time limit reached. Elapsed time is 0:02:18. Signaling Trainer to stop.
Start to fuse 3 checkpoints via the greedy soup algorithm.
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/AutogluonModels/ag-20250508_205717")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7f913c137a50>
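The log line "Start to fuse 3 checkpoints via the greedy soup algorithm" refers to averaging the weights of the saved top checkpoints. The sketch below shows the generic greedy-soup idea: rank checkpoints by validation score, then add each one to the average only if the averaged model does not score worse. It assumes checkpoints are plain lists of floats and uses a hypothetical `toy_evaluate` function (higher is better); it illustrates the technique and is not AutoMM's implementation.

```python
def average_weights(weight_lists):
    # Element-wise mean of several parameter vectors (a uniform "soup")
    n = len(weight_lists)
    return [sum(values) / n for values in zip(*weight_lists)]


def greedy_soup(checkpoints, evaluate):
    # 1. Rank checkpoints by their own validation score, best first
    ranked = sorted(checkpoints, key=lambda name: evaluate(checkpoints[name]), reverse=True)
    soup = [ranked[0]]
    best_score = evaluate(checkpoints[ranked[0]])
    # 2. Try each remaining checkpoint; keep it only if the averaged model is not worse
    for name in ranked[1:]:
        candidate = soup + [name]
        score = evaluate(average_weights([checkpoints[c] for c in candidate]))
        if score >= best_score:
            soup, best_score = candidate, score
    return soup, average_weights([checkpoints[c] for c in soup])


# Hypothetical toy data: three "checkpoints" with two parameters each
toy_checkpoints = {"a": [1.0, 0.0], "b": [0.85, 0.15], "c": [-1.0, 5.0]}


def toy_evaluate(weights):
    # Higher is better; the optimum is at weights == [0.9, 0.1]
    return -((weights[0] - 0.9) ** 2 + (weights[1] - 0.1) ** 2)


chosen, fused = greedy_soup(toy_checkpoints, toy_evaluate)
print(chosen)  # ['b', 'a']
```

Checkpoint "c" is rejected because averaging it in pulls the weights far from the optimum, which is the point of the greedy rule: only ingredients that help the validation score are kept.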

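Under the hood, AutoMM automatically infers the problem type (classification or regression), detects the data modalities, selects the related models from the multimodal model pools, and trains the selected models. If multiple backbones are available, AutoMM appends a late-fusion model (MLP or transformer) on top of them; in this run, the log above shows a MultimodalFusionMLP.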
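The log explains the inference step as "'binary' (because only two unique label-values observed)". A simplified heuristic in that spirit is sketched below. The function name and the `max_classes=20` cutoff are hypothetical, it works on plain Python values (not NumPy scalars), and AutoGluon's real rules are more involved; if the guess is wrong, pass `problem_type` explicitly as the log suggests.

```python
def infer_problem_type(labels, max_classes=20):
    # Hypothetical heuristic, not AutoGluon's actual logic
    unique_values = set(labels)
    if len(unique_values) == 2:
        return "binary"
    # Integers or strings with few distinct values look like class labels
    is_discrete = all(isinstance(v, (int, str)) for v in unique_values)
    if is_discrete and len(unique_values) <= max_classes:
        return "multiclass"
    # Anything else (non-integer floats, many distinct values) is treated as regression
    return "regression"


print(infer_problem_type([0, 1, 1, 0]))      # binary
print(infer_problem_type([0, 1, 2, 3, 4]))   # multiclass
print(infer_problem_type([0.5, 1.2, 3.3]))   # regression
```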

Evaluation

Then we can evaluate the predictor on the test data.

scores = predictor.evaluate(test_data, metrics=["roc_auc"])
scores
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
{'roc_auc': np.float64(0.9004)}
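`roc_auc` is the area under the ROC curve. It equals the probability that a randomly chosen positive example (label 1, fast adoption) gets a higher score than a randomly chosen negative one, with ties counting as half. The pure-Python sketch below computes that pairwise definition, assuming 0/1 labels and scores that are the predicted probability of class 1. It is O(n*m) and meant only to explain the metric; for real data, a library function such as scikit-learn's `roc_auc_score` is the practical choice.

```python
def roc_auc(labels, scores):
    # Split scores by true label
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("ROC AUC needs at least one positive and one negative example")
    # Count positive/negative pairs ranked correctly; ties count as half
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```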

Prediction

Given a multimodal dataframe without the label column, we can predict the labels.

predictions = predictor.predict(test_data.drop(columns=label_col))
predictions[:5]
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
8     1
70    1
82    1
28    0
63    1
Name: AdoptionSpeed, dtype: int64

For classification tasks, we can get the probabilities of all classes.

probas = predictor.predict_proba(test_data.drop(columns=label_col))
probas[:5]
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
           0         1
8   0.372848  0.627152
70  0.197599  0.802401
82  0.018618  0.981382
28  0.796530  0.203470
63  0.165020  0.834980
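The labels from `.predict()` line up with these probabilities: for each of the five rows, the predicted class is the column with the larger probability. The plain-Python sketch below reproduces this using the numbers printed above. Treating `.predict()` as an argmax over class probabilities is an assumption consistent with these rows, not a statement about AutoMM's internals.

```python
# Probabilities copied from the predict_proba output above: row index -> (P(class 0), P(class 1))
probas = {
    8: (0.372848, 0.627152),
    70: (0.197599, 0.802401),
    82: (0.018618, 0.981382),
    28: (0.796530, 0.203470),
    63: (0.165020, 0.834980),
}

# Pick the class with the highest probability for each row.
# For binary problems this matches "P(class 1) > 0.5", except for exact ties.
predictions = {idx: max(range(len(p)), key=p.__getitem__) for idx, p in probas.items()}
print(predictions)  # {8: 1, 70: 1, 82: 1, 28: 0, 63: 1}
```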

Note that calling .predict_proba() on a regression task will raise an exception.

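Extract Embeddings

Extracting embeddings is also useful in many cases: each sample (each row in the dataframe) is converted into an embedding vector. In the output below, the 100 test rows become 100 vectors of length 128.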

embeddings = predictor.extract_embedding(test_data.drop(columns=label_col))
embeddings.shape
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
(100, 128)
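A common use of these vectors is similarity search, for example finding the pet profile most similar to a given one. The sketch below finds the nearest neighbour of a row by cosine similarity. It assumes the embeddings have been converted to plain Python lists (e.g. with `embeddings.tolist()` if they are a NumPy array); the helper names and the tiny 3-dimensional vectors are hypothetical stand-ins for the real 128-dimensional ones.

```python
import math


def cosine_similarity(a, b):
    # Dot product divided by the product of the vector norms
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def nearest_neighbor(query_index, vectors):
    # Return (index, similarity) of the most similar row other than the query itself
    query = vectors[query_index]
    candidates = (
        (i, cosine_similarity(query, vec))
        for i, vec in enumerate(vectors)
        if i != query_index
    )
    return max(candidates, key=lambda pair: pair[1])


# Hypothetical toy embeddings (3 rows, 3 dimensions)
toy_embeddings = [
    [1.0, 0.0, 0.0],
    [0.9, 0.1, 0.0],
    [0.0, 1.0, 0.0],
]
print(nearest_neighbor(0, toy_embeddings))
```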

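Save and Load

It is also convenient to save a predictor and re-load it.

Warning

MultiModalPredictor.load() uses the pickle module implicitly, which is known to be insecure. It is possible to construct malicious pickle data that will execute arbitrary code during unpickling. Never load data that could have come from an untrusted source or that could have been tampered with. Only load data you trust.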
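To see why this warning matters, the sketch below uses only Python's standard `pickle` module (not AutoGluon). An object's `__reduce__` method tells the unpickler which callable to invoke at load time, so simply loading the bytes runs `eval` on an expression chosen by whoever created the data. The payload here is harmless (`sum(range(10))`), but it could be any code.

```python
import pickle


class Payload:
    """Any object can tell pickle which callable to invoke when it is loaded."""

    def __reduce__(self):
        # On unpickling, pickle will call eval("sum(range(10))")
        return (eval, ("sum(range(10))",))


blob = pickle.dumps(Payload())
result = pickle.loads(blob)  # eval runs here, during deserialization
print(result)  # 45
```

Note that `result` is not a `Payload` instance at all: the loader returned whatever the embedded call produced.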

import uuid

model_path = f"./tmp/{uuid.uuid4().hex}-saved_model"
predictor.save(model_path)
loaded_predictor = MultiModalPredictor.load(model_path)
scores2 = loaded_predictor.evaluate(test_data, metrics=["roc_auc"])
scores2
Load pretrained checkpoint: /home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/tmp/67db6a334113439b8d1da942caa785ec-saved_model/model.ckpt
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
{'roc_auc': np.float64(0.9004000000000001)}
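The reloaded predictor reports 0.9004000000000001 instead of 0.9004. A difference in the 16th significant digit is most likely floating-point rounding rather than a change in the model, so a programmatic save/load check should compare with a tolerance instead of exact equality. A minimal stdlib sketch using the two values printed above (the `rel_tol` value is an illustrative choice):

```python
import math

scores = {"roc_auc": 0.9004}
scores2 = {"roc_auc": 0.9004000000000001}

# Exact comparison fails because the two floats differ in the last bits
exact_match = scores["roc_auc"] == scores2["roc_auc"]
# A relative tolerance treats them as the same score
close_match = math.isclose(scores["roc_auc"], scores2["roc_auc"], rel_tol=1e-9)
print(exact_match, close_match)  # False True
```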

Other Examples

You may go to AutoMM Examples to explore other examples of AutoMM.

Customization

To learn how to customize AutoMM, please refer to Customize AutoMM.