AutoMM for Text + Tabular - Quick Start


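In many applications, text data may be mixed with numeric or categorical data. AutoGluon's MultiModalPredictor can train a single neural network that jointly operates on multiple feature types, including text, categorical, and numerical columns. The general idea is to embed the text, categorical, and numeric fields separately and then fuse these features across modalities. This tutorial demonstrates such an application.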
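To make "embed each field separately, then fuse" concrete, here is a toy NumPy sketch. It is not how AutoMM builds its network: the helper functions, the sizes, and the choice of fusion (concatenate the per-modality vectors, then apply a linear head) are illustrative assumptions, and all weights are random rather than trained.

```python
import numpy as np

# Toy late-fusion sketch with random, untrained weights.
# All sizes and helper names here are hypothetical.
rng = np.random.default_rng(0)
BATCH, EMBED_DIM = 4, 8
N_NUMERIC, N_CATEGORICAL, N_CATEGORIES = 3, 2, 10
SEQ_LEN, VOCAB = 6, 100


def embed_numeric(x, weight):
    # (batch, n_numeric) @ (n_numeric, dim) -> (batch, dim)
    return x @ weight


def embed_categorical(codes, table):
    # Look up one vector per category code (one table shared by all
    # categorical columns, for brevity) and sum over the columns
    return table[codes].sum(axis=1)


def embed_text(token_ids, table):
    # Crude stand-in for a text encoder: average of the token vectors
    return table[token_ids].mean(axis=1)


def fuse_and_predict(parts, head):
    # Fuse by concatenating the modality embeddings, then a linear head
    fused = np.concatenate(parts, axis=1)
    return fused @ head


numeric = rng.normal(size=(BATCH, N_NUMERIC))
codes = rng.integers(0, N_CATEGORIES, size=(BATCH, N_CATEGORICAL))
tokens = rng.integers(0, VOCAB, size=(BATCH, SEQ_LEN))

parts = [
    embed_numeric(numeric, rng.normal(size=(N_NUMERIC, EMBED_DIM))),
    embed_categorical(codes, rng.normal(size=(N_CATEGORIES, EMBED_DIM))),
    embed_text(tokens, rng.normal(size=(VOCAB, EMBED_DIM))),
]
preds = fuse_and_predict(parts, rng.normal(size=(3 * EMBED_DIM, 1)))
```

Each modality ends up as a vector of the same width per row, so the fusion step does not need to know where the features came from. In a real model those encoders and the head are learned jointly.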

import numpy as np
import pandas as pd
import warnings
import os

warnings.filterwarnings('ignore')
np.random.seed(123)
!python3 -m pip install openpyxl
Collecting openpyxl
  Downloading openpyxl-3.1.5-py2.py3-none-any.whl.metadata (2.5 kB)
Collecting et-xmlfile (from openpyxl)
  Downloading et_xmlfile-2.0.0-py3-none-any.whl.metadata (2.7 kB)
Downloading openpyxl-3.1.5-py2.py3-none-any.whl (250 kB)
Downloading et_xmlfile-2.0.0-py3-none-any.whl (18 kB)
Installing collected packages: et-xmlfile, openpyxl
Successfully installed et-xmlfile-2.0.0 openpyxl-3.1.5

Book Price Prediction Data

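For demonstration, we use the book price prediction dataset from the MachineHack Book Price Prediction Hackathon. Our goal is to predict a book's price from features such as its author, synopsis, and rating.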

!mkdir -p price_of_books
!wget https://automl-mm-bench.s3.amazonaws.com/machine_hack_competitions/predict_the_price_of_books/Data.zip -O price_of_books/Data.zip
!cd price_of_books && unzip -o Data.zip
!ls price_of_books/Participants_Data
--2025-05-08 21:04:40--  https://automl-mm-bench.s3.amazonaws.com/machine_hack_competitions/predict_the_price_of_books/Data.zip
Resolving automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)... 16.15.194.185, 52.216.58.65, 3.5.25.55, ...
Connecting to automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)|16.15.194.185|:443... connected.
HTTP request sent, awaiting response...
200 OK
Length: 3521673 (3.4M) [application/zip]
Saving to: ‘price_of_books/Data.zip’

2025-05-08 21:04:40 (63.6 MB/s) - ‘price_of_books/Data.zip’ saved [3521673/3521673]
Archive:  Data.zip
  inflating: Participants_Data/Data_Test.xlsx  
  inflating: Participants_Data/Data_Train.xlsx  
  inflating: Participants_Data/Sample_Submission.xlsx
Data_Test.xlsx	Data_Train.xlsx  Sample_Submission.xlsx
train_df = pd.read_excel(os.path.join('price_of_books', 'Participants_Data', 'Data_Train.xlsx'), engine='openpyxl')
train_df.head()
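Title Author Edition Reviews Ratings Synopsis Genre BookCategory Price
0 The Prisoner's Gold (The Hunters 3) Chris Kuzneski Paperback,– 10 Mar 2016 4.0 out of 5 stars 8 customer reviews THE HUNTERS return in their third brilliant novel... Action & Adventure (Books) Action & Adventure 220.00
1 Guru Dutt: A Tragedy in Three Acts Arun Khopkar Paperback,– 7 Nov 2012 3.9 out of 5 stars 14 customer reviews A layered portrait of a troubled genius for wh... Cinema & Broadcast (Books) Biographies, Diaries & True Accounts 202.93
2 Leviathan (Penguin Classics) Thomas Hobbes Paperback,– 25 Feb 1982 4.8 out of 5 stars 6 customer reviews "During the time men live without a common Powe..." International Relations Humour 299.00
3 A Pocket Full of Rye (Miss Marple) Agatha Christie Paperback,– 5 Oct 2017 4.1 out of 5 stars 13 customer reviews A handful of grain is found in the pocket of a... Contemporary Fiction (Books) Crime, Thriller & Mystery 180.00
4 LIFE 70 Years of Extraordinary Photography Editors of Life Hardcover,– 10 Oct 2006 5.0 out of 5 stars 1 customer review For seventy years, Life has thrilled... Photography Textbooks Arts, Film & Photography 965.62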

We do some basic preprocessing to convert the Reviews and Ratings columns of the data table to numeric values, and we transform Price to a log scale.

def preprocess(df):
    df = df.copy(deep=True)
    # "4.0 out of 5 stars" -> 4.0
    df['Reviews'] = pd.to_numeric(df['Reviews'].apply(lambda ele: ele[:-len(' out of 5 stars')]))
    # "1,234 customer reviews" -> 1234
    df['Ratings'] = pd.to_numeric(df['Ratings'].apply(lambda ele: ele.replace(',', '')[:-len(' customer reviews')]))
    # Train on log(price + 1); invert later with np.exp(x) - 1
    df['Price'] = np.log(df['Price'] + 1)
    return df
train_subsample_size = 1500  # subsample for faster demo, you can try setting to larger values
test_subsample_size = 5
train_df = preprocess(train_df)
train_data = train_df.iloc[100:].sample(train_subsample_size, random_state=123)
test_data = train_df.iloc[:100].sample(test_subsample_size, random_state=245)
train_data.head()
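Title Author Edition Reviews Ratings Synopsis Genre BookCategory Price
949 Furious Hours Casey Cep Paperback,– 1 Jun 2019 4.0 NaN ‘It’s been a long time since I picked up a book... True Accounts (Books) Biographies, Diaries & True Accounts 5.743003
5504 REST API Design Rulebook Mark Masse Paperback,– 7 Nov 2011 5.0 NaN In today’s market, where rival web services... Computing, Internet & Digital Media (Books) Computing, Internet & Digital Media 5.786897
5856 The Atlantropa Articles: A Novel Cody Franklin Paperback,– Import, 1 Nov 2018 4.5 2.0 #1 AMAZON BESTSELLER! Dystopian alternate history... Action & Adventure (Books) Romance 6.893656
4137 Hickory Dickory Dock (Poirot) Agatha Christie Paperback,– 5 Oct 2017 4.3 21.0 There’s more than petty theft going on in London... Action & Adventure (Books) Crime, Thriller & Mystery 5.192957
3205 The Stanley Kubrick Archives (Bibliotheca Univ... Alison Castle Hardcover,– 21 Aug 2016 4.6 3.0 In 1968, when Stanley Kubrick was asked to... Cinema & Broadcast (Books) Humour 6.889591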
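The slicing in `preprocess` assumes every Ratings value ends in the plural " customer reviews". A value such as "1 customer review" is one character shorter, so the slice leaves an empty string, which `pd.to_numeric` appears to turn into NaN; that may explain some of the NaN entries in the Ratings column above. A hypothetical alternative, sketched below, reads the leading number with a regular expression instead. `leading_number` and `preprocess_robust` are illustrative names, not part of the tutorial, and using them would change the training data and therefore the results shown here. `np.log1p(x)` computes the same `log(x + 1)` transform, and `np.expm1` inverts it.

```python
import re

import numpy as np
import pandas as pd

# Leading number with optional thousands separators and decimal part, e.g.
# "4.0 out of 5 stars", "1 customer review", "1,234 customer reviews"
_NUMBER = re.compile(r"^\s*([\d,]*\.?\d+)")


def leading_number(text):
    # Hypothetical helper: anything without a leading number becomes NaN
    match = _NUMBER.match(str(text))
    if match is None:
        return np.nan
    return float(match.group(1).replace(",", ""))


def preprocess_robust(df):
    df = df.copy(deep=True)
    df["Reviews"] = df["Reviews"].map(leading_number)
    df["Ratings"] = df["Ratings"].map(leading_number)
    # log1p(x) == log(x + 1); invert predictions with np.expm1
    df["Price"] = np.log1p(df["Price"])
    return df
```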

Training

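We can simply create a MultiModalPredictor and call predictor.fit() to train a model that operates on all of these feature types. Internally, the neural network is generated automatically based on the data type inferred for each feature column. To save time, we subsample the data and train for only three minutes.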
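The column-type inference itself is internal to AutoMM. To give a feel for the kind of decision involved, here is a deliberately simple, hypothetical heuristic: `guess_column_type` and its thresholds are made up for illustration and are not AutoGluon's actual rules. The toy frame borrows column names from the book data; the label column Price is left out because it is the prediction target.

```python
import pandas as pd
from pandas.api.types import is_numeric_dtype


def guess_column_type(series, max_categories=20, max_unique_ratio=0.5):
    """Hypothetical heuristic, NOT AutoGluon's actual inference rules."""
    values = series.dropna()
    if is_numeric_dtype(values):
        return "numerical"
    n_unique = values.nunique()
    # Few distinct values that repeat often: treat as categories
    if n_unique <= max_categories and n_unique <= max_unique_ratio * len(values):
        return "categorical"
    # Otherwise assume free-form text
    return "text"


demo = pd.DataFrame({
    "Reviews": [4.0, 3.9, 4.8, 4.1],
    "BookCategory": ["Humour", "Romance", "Humour", "Romance"],
    "Synopsis": [
        "THE HUNTERS return in their third brilliant novel",
        "A layered portrait of a troubled genius",
        "A handful of grain is found in a pocket",
        "For seventy years, Life has thrilled readers",
    ],
})
guessed = {col: guess_column_type(demo[col]) for col in demo.columns}
# {'Reviews': 'numerical', 'BookCategory': 'categorical', 'Synopsis': 'text'}
```

Each inferred type then decides which kind of encoder handles that column before the fusion step.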

from autogluon.multimodal import MultiModalPredictor
import uuid

time_limit = 3 * 60  # set to larger value in your applications
model_path = f"./tmp/{uuid.uuid4().hex}-automm_text_book_price_prediction"
predictor = MultiModalPredictor(label='Price', path=model_path)
predictor.fit(train_data, time_limit=time_limit)
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       28.39 GB / 30.95 GB (91.7%)
Disk Space Avail:   179.84 GB / 255.99 GB (70.3%)
===================================================
AutoGluon infers your prediction problem is: 'regression' (because dtype of label-column == float and many unique label-values observed).
	Label info (max, min, mean, stddev): (9.115699967822062, 3.6109179126442243, 6.02567, 0.7694)
	If 'regression' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/tmp/db542cc5848b4a0fb688e5b90f23c624-automm_text_book_price_prediction
    ```
Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                | Params | Mode 
------------------------------------------------------------------
0 | model             | MultimodalFusionMLP | 110 M  | train
1 | validation_metric | MeanSquaredError    | 0      | train
2 | loss_func         | MSELoss             | 0      | train
------------------------------------------------------------------
110 M     Trainable params
0         Non-trainable params
110 M     Total params
442.755   Total estimated model params size (MB)
309       Modules in train mode
0         Modules in eval mode
Epoch 0, global step 4: 'val_rmse' reached 1.20494 (best 1.20494), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/tmp/db542cc5848b4a0fb688e5b90f23c624-automm_text_book_price_prediction/epoch=0-step=4.ckpt' as top 3
Epoch 0, global step 10: 'val_rmse' reached 1.31910 (best 1.20494), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/tmp/db542cc5848b4a0fb688e5b90f23c624-automm_text_book_price_prediction/epoch=0-step=10.ckpt' as top 3
Epoch 1, global step 14: 'val_rmse' reached 0.97402 (best 0.97402), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/tmp/db542cc5848b4a0fb688e5b90f23c624-automm_text_book_price_prediction/epoch=1-step=14.ckpt' as top 3
Epoch 1, global step 20: 'val_rmse' reached 0.89857 (best 0.89857), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/tmp/db542cc5848b4a0fb688e5b90f23c624-automm_text_book_price_prediction/epoch=1-step=20.ckpt' as top 3
Epoch 2, global step 24: 'val_rmse' reached 0.92035 (best 0.89857), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/tmp/db542cc5848b4a0fb688e5b90f23c624-automm_text_book_price_prediction/epoch=2-step=24.ckpt' as top 3
Epoch 2, global step 30: 'val_rmse' reached 0.90502 (best 0.89857), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/tmp/db542cc5848b4a0fb688e5b90f23c624-automm_text_book_price_prediction/epoch=2-step=30.ckpt' as top 3
Time limit reached. Elapsed time is 0:03:00. Signaling Trainer to stop.
Epoch 3, global step 33: 'val_rmse' reached 0.87321 (best 0.87321), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/tmp/db542cc5848b4a0fb688e5b90f23c624-automm_text_book_price_prediction/epoch=3-step=33.ckpt' as top 3
Start to fuse 3 checkpoints via the greedy soup algorithm.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/multimodal_prediction/tmp/db542cc5848b4a0fb688e5b90f23c624-automm_text_book_price_prediction")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7f5698309a10>
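The log line "Start to fuse 3 checkpoints via the greedy soup algorithm" refers to a model soup, that is, averaging the weights of several saved checkpoints. Below is a minimal sketch of the greedy variant as it is generally described: visit checkpoints from best to worst validation score, and keep a checkpoint only if averaging it in does not hurt validation. It assumes weights are plain dicts of NumPy arrays and that a lower validation loss (such as val_rmse) is better; `average_weights` and `greedy_soup` are hypothetical helpers, not AutoMM's implementation.

```python
import numpy as np


def average_weights(checkpoints):
    """Element-wise mean of several {parameter_name: np.ndarray} dicts."""
    return {name: np.mean([ckpt[name] for ckpt in checkpoints], axis=0)
            for name in checkpoints[0]}


def greedy_soup(checkpoints, val_loss):
    """Greedy soup sketch; val_loss maps a weight dict to a score (lower is better)."""
    # Visit checkpoints from best to worst individual validation loss
    ranked = sorted(checkpoints, key=val_loss)
    soup = [ranked[0]]
    best = val_loss(ranked[0])
    for candidate in ranked[1:]:
        loss = val_loss(average_weights(soup + [candidate]))
        # Keep the candidate only if adding it to the average does not hurt
        if loss <= best:
            soup.append(candidate)
            best = loss
    return average_weights(soup), len(soup)


# Toy usage: three one-parameter "checkpoints"; the ideal weight is 1.0
checkpoints = [{"w": np.array([0.8])}, {"w": np.array([1.2])}, {"w": np.array([3.0])}]
souped, n_used = greedy_soup(checkpoints, lambda c: float(np.sum((c["w"] - 1.0) ** 2)))
# 0.8 and 1.2 average to 1.0 and are kept; adding 3.0 would hurt, so it is skipped
```

Because a candidate is only accepted when validation does not get worse, the soup is never worse on validation than the single best checkpoint.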

Prediction

With MultiModalPredictor we can easily obtain predictions, evaluate the model, and extract embeddings of the data.

predictions = predictor.predict(test_data)
print('Predictions:')
print('------------')
print(np.exp(predictions) - 1)  # undo the log(price + 1) transform
print()
print('True Value:')
print('------------')
print(np.exp(test_data['Price']) - 1)  # back to the original price scale
Predictions:
------------
1     369.945160
31    361.581726
19    691.173340
45    434.631622
82    505.455048
Name: Price, dtype: float32

True Value:
------------
1     202.93
31    799.00
19    352.00
45    395.10
82    409.00
Name: Price, dtype: float64
performance = predictor.evaluate(test_data)
print(performance)
{'root_mean_squared_error': np.float64(0.5461323272699485)}
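Note that this RMSE is measured on the Price column as stored in `test_data`, i.e. on log(price + 1), not in the original price units. Recomputing it from the five predictions and true values printed above reproduces the same number (about 0.546) in log space, while the error in price units is much larger. The `rmse` helper is written here for illustration.

```python
import numpy as np

# Values copied from the printed output above (original price scale)
predicted_price = np.array([369.945160, 361.581726, 691.173340, 434.631622, 505.455048])
true_price = np.array([202.93, 799.00, 352.00, 395.10, 409.00])


def rmse(a, b):
    return float(np.sqrt(np.mean((np.asarray(a) - np.asarray(b)) ** 2)))


# Compare in the space the model was trained in: log(price + 1)
log_rmse = rmse(np.log1p(predicted_price), np.log1p(true_price))  # about 0.546
# The same predictions measured in the original price units
price_rmse = rmse(predicted_price, true_price)
```

Roughly speaking, an RMSE of about 0.55 on a log scale means a typical prediction is off by a factor of around e^0.55 ≈ 1.7.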
embeddings = predictor.extract_embedding(test_data)
embeddings.shape
(5, 128)
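The embeddings have one row per test sample (here 5 rows of 128 dimensions), so they can be compared directly. A minimal sketch of one common use, finding each sample's most similar neighbour by cosine similarity, assuming `embeddings` is a 2-D NumPy array as the `.shape` output suggests. A random stand-in of the same shape is used so the block runs on its own; `cosine_similarity_matrix` is a hypothetical helper.

```python
import numpy as np


def cosine_similarity_matrix(vectors):
    # Normalise each row to unit length, then take all pairwise dot products
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    unit = vectors / np.clip(norms, 1e-12, None)
    return unit @ unit.T


# Random stand-in with the same shape as the embeddings above (5 x 128)
fake_embeddings = np.random.default_rng(0).normal(size=(5, 128))
sim = cosine_similarity_matrix(fake_embeddings)
# For each sample, the index of the most similar *other* sample
# (subtracting 2 on the diagonal rules out matching a row with itself)
nearest = np.argmax(sim - 2 * np.eye(len(sim)), axis=1)
```

With the real output, pass `embeddings` instead of `fake_embeddings`.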

Other Examples

You may go to AutoMM Examples to explore other examples of AutoMM.

Customization

To learn how to customize AutoMM, please refer to Customize AutoMM.