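AutoMM for Image + Text + Tabular - Quick Start
AutoMM is a deep learning "model zoo" of model zoos. It automatically builds deep learning models suited to multimodal datasets. You only need to convert your data into the multimodal dataframe format, and AutoMM can then predict the values of one column from the features in the other columns, which may include images, text, and tabular data.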
import os
import numpy as np
import warnings
warnings.filterwarnings('ignore')
np.random.seed(123)
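Dataset
For demonstration, we use a simplified, subsampled version of the PetFinder dataset. The task is to predict each animal's adoption speed from its adoption profile. In this simplified version, adoption speed is grouped into two categories: 0 (slow) and 1 (fast).
First, let's download and prepare the dataset.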
download_dir = './ag_automm_tutorial'
zip_file = 'https://automl-mm-bench.s3.amazonaws.com/petfinder_for_tutorial.zip'
from autogluon.core.utils.loaders import load_zip
load_zip.unzip(zip_file, unzip_dir=download_dir)
Downloading ./ag_automm_tutorial/file.zip from https://automl-mm-bench.s3.amazonaws.com/petfinder_for_tutorial.zip...
100%|██████████| 18.8M/18.8M [00:00<00:00, 40.3MiB/s]
Next, we load the CSV files.
import pandas as pd
dataset_path = download_dir + '/petfinder_for_tutorial'
train_data = pd.read_csv(f'{dataset_path}/train.csv', index_col=0)
test_data = pd.read_csv(f'{dataset_path}/test.csv', index_col=0)
label_col = 'AdoptionSpeed'
We need to expand the image paths into absolute paths so the images can be loaded during training.
image_col = 'Images'
train_data[image_col] = train_data[image_col].apply(lambda ele: ele.split(';')[0]) # Use the first image for a quick tutorial
test_data[image_col] = test_data[image_col].apply(lambda ele: ele.split(';')[0])
def path_expander(path, base_folder):
    path_l = path.split(';')
    return ';'.join([os.path.abspath(os.path.join(base_folder, path)) for path in path_l])
train_data[image_col] = train_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
test_data[image_col] = test_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
train_data[image_col].iloc[0]
'/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/ag_automm_tutorial/petfinder_for_tutorial/images/7d7a39d71-1.jpg'
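The two steps above (keeping only the first image, then expanding it to an absolute path) can be tried in isolation. The following is a minimal sketch using only the standard library; the helper `first_image`, the sample cell value, and the `./data` base folder are hypothetical and chosen purely for illustration.

```python
import os


def first_image(paths: str) -> str:
    """Keep only the first entry of a ';'-separated list of image paths."""
    return paths.split(';')[0]


def path_expander(path: str, base_folder: str) -> str:
    """Make every ';'-separated relative path absolute under base_folder."""
    return ';'.join(
        os.path.abspath(os.path.join(base_folder, p)) for p in path.split(';')
    )


# Hypothetical cell value holding two relative image paths.
raw = 'images/pet-1.jpg;images/pet-2.jpg'
expanded = path_expander(first_image(raw), base_folder='./data')
print(expanded)
```

The printed path depends on the current working directory, which is why the tutorial output above shows a path under the CI machine's home directory.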
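Each animal's adoption profile includes pictures, a text description, and various tabular features such as age, breed, name, and color. Let's look at one example row and display its text description and a picture.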
example_row = train_data.iloc[0]
example_row
Type 2
Name Yumi Hamasaki
Age 4
Breed1 292
Breed2 265
Gender 2
Color1 1
Color2 5
Color3 7
MaturitySize 2
FurLength 2
Vaccinated 1
Dewormed 3
Sterilized 2
Health 1
Quantity 1
Fee 0
State 41326
RescuerID bcc4e1b9557a8b3aaf545ea8e6e86991
VideoAmt 0
Description I rescued Yumi Hamasaki at a food stall far aw...
PetID 7d7a39d71
PhotoAmt 3.0
AdoptionSpeed 0
Images /home/ci/autogluon/docs/tutorials/multimodal/m...
Name: 0, dtype: object
example_row['Description']
"I rescued Yumi Hamasaki at a food stall far away in Kelantan. At that time i was on my way back to KL, she was suffer from stomach problem and looking very2 sick.. I send her to vet & get the treatment + vaccinated and right now she's very2 healthy.. About yumi : - love to sleep with ppl - she will keep on meowing if she's hugry - very2 active, always seeking for people to accompany her playing - well trained (poo+pee in her own potty) - easy to bathing - I only feed her with these brands : IAMS, Kittenbites, Pro-formance Reason why i need someone to adopt Yumi: I just married and need to move to a new house where no pets are allowed :( As Yumi is very2 special to me, i will only give her to ppl that i think could take care of her just like i did (especially on her foods things).."
example_image = example_row[image_col]
from IPython.display import Image, display
pil_img = Image(filename=example_image)
display(pil_img)
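To make the three kinds of columns concrete, the fields of a row can be grouped by a rough guess at their modality. The sketch below is a deliberately naive heuristic, not AutoMM's actual detection logic; the function `guess_modality`, the word-count threshold, and the shortened image path are all assumptions made for illustration.

```python
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def guess_modality(value) -> str:
    """Very rough guess at a column's modality from a single example value."""
    if isinstance(value, (int, float)):
        return 'numerical'
    if isinstance(value, str):
        if value.lower().endswith(IMAGE_EXTENSIONS):
            return 'image_path'
        if len(value.split()) > 5:  # arbitrary threshold for free text
            return 'text'
    return 'categorical'


# A few fields from the example row; the image path is shortened here.
row = {
    'Age': 4,
    'Name': 'Yumi Hamasaki',
    'Description': 'I rescued Yumi Hamasaki at a food stall far away in Kelantan.',
    'Images': 'images/7d7a39d71-1.jpg',
}
modalities = {column: guess_modality(value) for column, value in row.items()}
print(modalities)
```

A heuristic this simple misclassifies integer-coded categories such as Breed1 or State as numerical, which is one reason a real system has to look at whole columns rather than single values.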

Training
Now we fit the predictor on the training data. Here we set a tight time budget for a quick demo.
from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor(label=label_col)
predictor.fit(
    train_data=train_data,
    time_limit=120, # seconds
)
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_205717"
=================== System Info ===================
AutoGluon Version: 1.3.1b20250508
Python Version: 3.11.9
Operating System: Linux
Platform Machine: x86_64
Platform Version: #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count: 8
Pytorch Version: 2.6.0+cu124
CUDA Version: 12.4
Memory Avail: 28.40 GB / 30.95 GB (91.8%)
Disk Space Avail: 185.23 GB / 255.99 GB (72.4%)
===================================================
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
2 unique label values: [np.int64(0), np.int64(1)]
If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
AutoMM starts to create your model. ✨✨✨
To track the learning progress, you can open a terminal and launch Tensorboard:
```shell
# Assume you have installed tensorboard
tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/AutogluonModels/ag-20250508_205717
```
INFO: Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:
| Name | Type | Params | Mode
------------------------------------------------------------------
0 | model | MultimodalFusionMLP | 207 M | train
1 | validation_metric | BinaryAUROC | 0 | train
2 | loss_func | CrossEntropyLoss | 0 | train
------------------------------------------------------------------
207 M Trainable params
0 Non-trainable params
207 M Total params
828.307 Total estimated model params size (MB)
1171 Modules in train mode
0 Modules in eval mode
INFO: Epoch 0, global step 1: 'val_roc_auc' reached 0.56250 (best 0.56250), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/AutogluonModels/ag-20250508_205717/epoch=0-step=1.ckpt' as top 3
INFO: Epoch 0, global step 4: 'val_roc_auc' reached 0.76083 (best 0.76083), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/AutogluonModels/ag-20250508_205717/epoch=0-step=4.ckpt' as top 3
INFO: Epoch 1, global step 5: 'val_roc_auc' reached 0.78056 (best 0.78056), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/AutogluonModels/ag-20250508_205717/epoch=1-step=5.ckpt' as top 3
INFO: Epoch 1, global step 8: 'val_roc_auc' reached 0.79000 (best 0.79000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/AutogluonModels/ag-20250508_205717/epoch=1-step=8.ckpt' as top 3
INFO: Epoch 2, global step 9: 'val_roc_auc' reached 0.78972 (best 0.79000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/AutogluonModels/ag-20250508_205717/epoch=2-step=9.ckpt' as top 3
INFO: Time limit reached. Elapsed time is 0:02:18. Signaling Trainer to stop.
Start to fuse 3 checkpoints via the greedy soup algorithm.
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉
To load the model, use the code below:
```python
from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/AutogluonModels/ag-20250508_205717")
```
If you are not satisfied with the model, try to increase the training time,
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7f913c137a50>
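Under the hood, AutoMM automatically infers the problem type (classification or regression), detects the data modalities, selects the relevant models from the multimodal model pool, and trains the selected models. If multiple backbones are present, AutoMM adds a late-fusion model (MLP or transformer) on top of them.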
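The idea of late fusion can be illustrated without any deep learning library: each backbone turns its modality into a feature vector, the vectors are concatenated, and a small MLP maps the fused vector to class probabilities. The sketch below is a toy illustration under that assumption, not AutoMM's implementation; the feature values, layer sizes, and random weights are all made up.

```python
import math
import random


def linear(x, weights, bias):
    """Dense layer: one output per weight row."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]


def softmax(logits):
    """Convert logits to probabilities in a numerically stable way."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def random_layer(rng, n_in, n_out):
    """Untrained layer with small random weights and zero bias."""
    weights = [[rng.uniform(-0.5, 0.5) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def late_fusion_predict(features_by_modality, hidden, output):
    """Concatenate per-modality features, then apply a one-hidden-layer MLP."""
    fused = [v for feats in features_by_modality.values() for v in feats]
    hidden_act = [max(0.0, h) for h in linear(fused, *hidden)]  # ReLU
    return softmax(linear(hidden_act, *output))


rng = random.Random(0)
# Hypothetical feature vectors, as if produced by three separate backbones.
features = {
    'image': [0.2, -0.1, 0.7],
    'text': [0.5, 0.3],
    'tabular': [1.0, 0.0, -0.4, 0.9],
}
hidden = random_layer(rng, n_in=9, n_out=4)
output = random_layer(rng, n_in=4, n_out=2)
probs = late_fusion_predict(features, hidden, output)
print(probs)
```

Because the weights are random and untrained, the resulting probabilities carry no meaning; the point is only the data flow from separate modalities to a single prediction.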
Evaluation
We can then evaluate the predictor on the test data.
scores = predictor.evaluate(test_data, metrics=["roc_auc"])
scores
{'roc_auc': np.float64(0.9004)}
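For readers unfamiliar with the metric, ROC AUC can be read as the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one, with ties counted as one half. The following is a minimal sketch of that pairwise definition in plain Python; it is not how AutoGluon computes the metric internally, and the four labels and scores are made up.

```python
def roc_auc(labels, scores):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError('ROC AUC needs at least one example of each class')
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a correct ranking
    return wins / (len(positives) * len(negatives))


# Hypothetical labels and predicted probabilities of class 1.
labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(roc_auc(labels, scores))  # 3 of the 4 pairs are ranked correctly: 0.75
```

The pairwise loop is quadratic in the number of examples, which is fine for illustration but not for large test sets.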
Prediction
Given a multimodal dataframe without the label column, we can predict the labels.
predictions = predictor.predict(test_data.drop(columns=label_col))
predictions[:5]
8 1
70 1
82 1
28 0
63 1
Name: AdoptionSpeed, dtype: int64
For classification tasks, we can get the probabilities of all classes.
probas = predictor.predict_proba(test_data.drop(columns=label_col))
probas[:5]
| | 0 | 1 |
|---|---|---|
| 8 | 0.372848 | 0.627152 |
| 70 | 0.197599 | 0.802401 |
| 82 | 0.018618 | 0.981382 |
| 28 | 0.796530 | 0.203470 |
| 63 | 0.165020 | 0.834980 |
Note that calling .predict_proba() on a regression task will raise an exception.
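The class probabilities line up with the labels returned by .predict(): in each of the five rows shown, the predicted label is the class with the larger probability. The sketch below reproduces that relationship in plain Python from the printed values. Treating prediction as an argmax over the probabilities is an assumption inferred from this output, not a statement about AutoMM's internals.

```python
# Probabilities copied from the output above: row index -> (P(class 0), P(class 1)).
probabilities = {
    8: (0.372848, 0.627152),
    70: (0.197599, 0.802401),
    82: (0.018618, 0.981382),
    28: (0.796530, 0.203470),
    63: (0.165020, 0.834980),
}


def most_probable_class(class_probs):
    """Return the index of the largest probability (argmax)."""
    return max(range(len(class_probs)), key=lambda i: class_probs[i])


labels = {idx: most_probable_class(p) for idx, p in probabilities.items()}
print(labels)
```

For a binary problem this is equivalent to thresholding the class 1 probability at 0.5; a different threshold can be applied to the probabilities directly when the costs of the two error types differ.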
Extract Embeddings
Extracting embeddings is also useful in many cases: it converts each sample (each row of the dataframe) into an embedding vector.
embeddings = predictor.extract_embedding(test_data.drop(columns=label_col))
embeddings.shape
(100, 128)
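One common use of such embeddings is comparing samples, for example with cosine similarity. The following is a minimal sketch in plain Python; the short 4-dimensional vectors are hypothetical stand-ins for rows of the 128-dimensional embedding matrix above.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError('vectors must have the same length')
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    if norm == 0.0:
        raise ValueError('cosine similarity is undefined for a zero vector')
    return dot / norm


# Hypothetical embeddings of three samples.
pet_a = [0.9, 0.1, 0.0, 0.3]
pet_b = [0.8, 0.2, 0.1, 0.4]
pet_c = [-0.5, 0.9, 0.7, -0.2]

print(cosine_similarity(pet_a, pet_b))  # similar direction, close to 1
print(cosine_similarity(pet_a, pet_c))  # dissimilar direction, much lower
```

With real embeddings, the same function could rank the test rows by similarity to a query row, for instance to find pets with similar profiles.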
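Save and Load
It is also convenient to save a predictor and reload it later.
Warning
MultiModalPredictor.load() implicitly uses the pickle module, which is known to be insecure. It is possible to construct malicious pickle data that executes arbitrary code during unpickling. Never load data that could have come from an untrusted source or that could have been tampered with. Only load data you trust.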
import uuid
model_path = f"./tmp/{uuid.uuid4().hex}-saved_model"
predictor.save(model_path)
loaded_predictor = MultiModalPredictor.load(model_path)
scores2 = loaded_predictor.evaluate(test_data, metrics=["roc_auc"])
scores2
Load pretrained checkpoint: /home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/tmp/67db6a334113439b8d1da942caa785ec-saved_model/model.ckpt
{'roc_auc': np.float64(0.9004000000000001)}
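One way to act on the pickle warning is to record a checksum of the saved predictor directory right after saving, and to compare it before loading. The sketch below uses only the standard library; the helper names `directory_digest` and `is_unmodified` are hypothetical and not part of AutoGluon. Such a check only detects changes relative to a digest you stored somewhere trustworthy; it does not make an untrusted model safe to load.

```python
import hashlib
import hmac
from pathlib import Path


def directory_digest(directory) -> str:
    """SHA-256 over the relative paths and contents of all files in a directory."""
    root = Path(directory)
    digest = hashlib.sha256()
    for path in sorted(p for p in root.rglob('*') if p.is_file()):
        digest.update(path.relative_to(root).as_posix().encode('utf-8'))
        digest.update(path.read_bytes())
    return digest.hexdigest()


def is_unmodified(directory, expected_digest: str) -> bool:
    """True if the directory still matches a previously recorded digest."""
    return hmac.compare_digest(directory_digest(directory), expected_digest)


# Intended usage (not executed here):
#   recorded = directory_digest(model_path)   # right after predictor.save(model_path)
#   if is_unmodified(model_path, recorded):   # right before MultiModalPredictor.load(model_path)
#       ...
```

Reading every file fully into memory is acceptable for a sketch; for multi-gigabyte checkpoints, hashing in chunks would be the more careful choice.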
Other Examples
See AutoMM Examples for more examples of AutoMM.
Customization
To learn how to customize AutoMM, see Customize AutoMM.