Continuous Training with AutoMM


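Continuous training provides a way for machine learning models to improve their performance over time. It lets a model build on previously acquired knowledge, which can improve accuracy, facilitate knowledge transfer across tasks, and save computational resources. In this tutorial, we demonstrate three use cases of continuous training with AutoMM.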

Use Case 1: Expand Training with Additional Data or Training Time

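Sometimes a model that is underfitting can benefit from more training epochs or additional training time. With AutoMM, you can easily extend the training time of your model without starting from scratch.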

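It is also common to need to incorporate more data into the model. AutoMM allows you to continue training with data of the same problem type and, for a multiclass problem, the same set of classes. This flexibility makes it easy to improve and adapt your model as your data grows.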

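We use the Stanford Sentiment Treebank (SST) dataset as an example. It consists of movie reviews and their associated sentiment. Given a new movie review, the goal is to predict the sentiment reflected in the text (in this case a binary classification, where a review is labeled 1 if it conveys a positive opinion and 0 otherwise). Let's first load and look at the data, noting that the labels are stored in a column called label.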

from autogluon.core.utils.loaders import load_pd

train_data = load_pd.load("https://autogluon-text.s3-accelerate.amazonaws.com/glue/sst/train.parquet")
test_data = load_pd.load("https://autogluon-text.s3-accelerate.amazonaws.com/glue/sst/dev.parquet")
subsample_size = 1000  # subsample data for faster demo, try setting this to larger values
train_data_1 = train_data.sample(n=subsample_size, random_state=0)
train_data_1.head(10)
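       sentence                                            label
43787  very pleasing at its best moments                   1
16159  , american chai is enough to make you put away...   0
59015  too much like an infomercial for ram dass 's l...   0
5108   a stirring visual sequence                          1
67052  cool visual backmasking                             1
35938  hard ground                                         0
49879  the striking , quietly vulnerable personality ...   1
51591  pan nalin 's exposition is beautiful and myste...   1
56780  wonderfully loopy                                   1
28518  most beautiful , most illuminating                  1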

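Now let's train the model. To keep this tutorial fast, we call fit() on a subset of only 1,000 training examples and limit its runtime to about one minute. To get reasonable performance in your own applications, set a much longer time_limit (e.g., 1 hour), or do not specify time_limit at all (time_limit=None).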

from autogluon.multimodal import MultiModalPredictor
import uuid

model_path = f"./tmp/{uuid.uuid4().hex}-automm_sst"
predictor = MultiModalPredictor(label="label", eval_metric="acc", path=model_path)
predictor.fit(train_data_1, time_limit=60)
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       28.35 GB / 30.95 GB (91.6%)
Disk Space Avail:   185.24 GB / 255.99 GB (72.4%)
===================================================
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
	2 unique label values:  [np.int64(1), np.int64(0)]
	If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/7b64a345db624baea7da1c66fa581306-automm_sst
    ```
Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                         | Params | Mode 
---------------------------------------------------------------------------
0 | model             | HFAutoModelForTextPrediction | 108 M  | train
1 | validation_metric | MulticlassAccuracy           | 0      | train
2 | loss_func         | CrossEntropyLoss             | 0      | train
---------------------------------------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
435.573   Total estimated model params size (MB)
229       Modules in train mode
0         Modules in eval mode
Epoch 0, global step 3: 'val_accuracy' reached 0.55500 (best 0.55500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/7b64a345db624baea7da1c66fa581306-automm_sst/epoch=0-step=3.ckpt' as top 3
Epoch 0, global step 7: 'val_accuracy' reached 0.59500 (best 0.59500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/7b64a345db624baea7da1c66fa581306-automm_sst/epoch=0-step=7.ckpt' as top 3
Epoch 1, global step 10: 'val_accuracy' reached 0.63000 (best 0.63000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/7b64a345db624baea7da1c66fa581306-automm_sst/epoch=1-step=10.ckpt' as top 3
Epoch 1, global step 14: 'val_accuracy' reached 0.72000 (best 0.72000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/7b64a345db624baea7da1c66fa581306-automm_sst/epoch=1-step=14.ckpt' as top 3
Epoch 2, global step 17: 'val_accuracy' reached 0.86500 (best 0.86500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/7b64a345db624baea7da1c66fa581306-automm_sst/epoch=2-step=17.ckpt' as top 3
Epoch 2, global step 21: 'val_accuracy' reached 0.80500 (best 0.86500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/7b64a345db624baea7da1c66fa581306-automm_sst/epoch=2-step=21.ckpt' as top 3
Time limit reached. Elapsed time is 0:01:00. Signaling Trainer to stop.
Epoch 3, global step 23: 'val_accuracy' reached 0.86000 (best 0.86500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/7b64a345db624baea7da1c66fa581306-automm_sst/epoch=3-step=23.ckpt' as top 3
Start to fuse 3 checkpoints via the greedy soup algorithm.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/7b64a345db624baea7da1c66fa581306-automm_sst")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7ff2ac246d50>

After training, we can evaluate the predictor on separate test data formatted like the training data.

test_score = predictor.evaluate(test_data)
print(test_score)
{'accuracy': 0.8738532110091743}
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.

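If training completed successfully, a model.ckpt file is saved under model_path. If you think the model is still underfitting, you can continue training from this checkpoint by running .fit() again with the same data. If you have new data to add and do not want to train from scratch, you can also run .fit() with the new combined dataset. Note that, as the log below shows, the continued run is saved to a new path rather than overwriting model_path.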
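When adding data, it helps to confirm that the new rows do not overlap the rows already used. The following is a minimal sketch of the drop-then-sample pattern on a toy DataFrame; the toy frame and the names `toy`, `first`, and `second` are illustrative assumptions, not part of the tutorial's data.

```python
import pandas as pd

# Toy stand-in for the SST training frame (values are made up for illustration).
toy = pd.DataFrame(
    {
        "sentence": [f"review {i}" for i in range(10)],
        "label": [i % 2 for i in range(10)],
    }
)

# First subsample, analogous to train_data_1.
first = toy.sample(n=4, random_state=0)

# Drop the rows already used, then sample only from what is left.
second = toy.drop(first.index).sample(n=4, random_state=0)

# The two index sets should share no rows.
overlap = set(first.index) & set(second.index)
print(len(first), len(second), len(overlap))
```

Because the second sample is drawn from the frame with the first sample's index removed, the two subsets cannot share rows, even though both calls use the same random_state.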

predictor_2 = MultiModalPredictor.load(model_path)  # you can also use the `predictor` we assigned above
train_data_2 = train_data.drop(train_data_1.index).sample(n=subsample_size, random_state=0)
predictor_2.fit(train_data_2, time_limit=60)
Load pretrained checkpoint: /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/7b64a345db624baea7da1c66fa581306-automm_sst/model.ckpt
A new predictor save path is created. This is to prevent you to overwrite previous predictor saved here. You could check current save path at predictor._save_path. If you still want to use this path, set resume=True
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_205402"
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       24.88 GB / 30.95 GB (80.4%)
Disk Space Avail:   184.04 GB / 255.99 GB (71.9%)
===================================================

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205402
    ```
Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                         | Params | Mode 
---------------------------------------------------------------------------
0 | model             | HFAutoModelForTextPrediction | 108 M  | train
1 | validation_metric | MulticlassAccuracy           | 0      | train
2 | loss_func         | CrossEntropyLoss             | 0      | train
---------------------------------------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
435.573   Total estimated model params size (MB)
229       Modules in train mode
0         Modules in eval mode
Epoch 0, global step 3: 'val_accuracy' reached 0.85000 (best 0.85000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205402/epoch=0-step=3.ckpt' as top 3
Epoch 0, global step 7: 'val_accuracy' reached 0.84500 (best 0.85000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205402/epoch=0-step=7.ckpt' as top 3
Epoch 1, global step 10: 'val_accuracy' reached 0.84500 (best 0.85000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205402/epoch=1-step=10.ckpt' as top 3
Epoch 1, global step 14: 'val_accuracy' reached 0.88000 (best 0.88000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205402/epoch=1-step=14.ckpt' as top 3
Epoch 2, global step 17: 'val_accuracy' reached 0.90000 (best 0.90000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205402/epoch=2-step=17.ckpt' as top 3
Epoch 2, global step 21: 'val_accuracy' reached 0.90000 (best 0.90000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205402/epoch=2-step=21.ckpt' as top 3
Time limit reached. Elapsed time is 0:01:00. Signaling Trainer to stop.
Epoch 3, global step 23: 'val_accuracy' reached 0.94000 (best 0.94000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205402/epoch=3-step=23.ckpt' as top 3
Start to fuse 3 checkpoints via the greedy soup algorithm.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205402")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7ff14810ea10>
test_score_2 = predictor_2.evaluate(test_data)
print(test_score_2)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
{'accuracy': 0.8990825688073395}
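To put the two scores side by side, the short calculation below uses the accuracy values printed by the two evaluate() calls. It is only arithmetic on those two numbers; the variable names are illustrative.

```python
# Accuracy values copied from the two evaluate() outputs in this section.
acc_first = 0.8738532110091743
acc_continued = 0.8990825688073395

# Absolute gain, expressed in percentage points.
gain_points = (acc_continued - acc_first) * 100

# Share of the first model's errors that the continued model no longer makes.
error_reduction = (acc_continued - acc_first) / (1 - acc_first)

print(f"absolute gain: {gain_points:.2f} percentage points")
print(f"relative error reduction: {error_reduction:.1%}")
```

Keep in mind that this compares single runs trained on 1,000-row subsamples with a 60-second budget each, so the difference is indicative rather than a rigorous measurement.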

Use Case 2: Resume Training from the Last Checkpoint

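If your training process crashed for some reason, AutoMM allows you to resume training from where it stopped. In that case, a last.ckpt file, rather than model.ckpt, is saved under model_path. To resume training, call MultiModalPredictor.load() with the resume option.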

predictor_resume = MultiModalPredictor.load(path=model_path, resume=True)
# Call fit() on the resumed predictor, not on the original `predictor`
predictor_resume.fit(train_data, time_limit=60)
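Since the tutorial distinguishes a finished run (model.ckpt) from an interrupted one (last.ckpt), a small check of the predictor directory can tell you which case applies. The sketch below uses only the two file names mentioned in this tutorial; `checkpoint_status` is a hypothetical helper, not an AutoGluon function.

```python
import tempfile
from pathlib import Path


def checkpoint_status(model_dir):
    """Hypothetical helper: classify a predictor directory by its checkpoint files."""
    model_dir = Path(model_dir)
    if (model_dir / "model.ckpt").is_file():
        return "finished"  # training completed; call fit() again to continue
    if (model_dir / "last.ckpt").is_file():
        return "interrupted"  # load with resume=True to pick up where it stopped
    return "missing"  # no checkpoint found under this path


# Demonstration on a throwaway directory that mimics an interrupted run.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "last.ckpt").touch()
    print(checkpoint_status(tmp))
```

In practice you would pass model_path to the helper and choose between a plain load() and load(..., resume=True) based on the result.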

Use Case 3: Apply Pre-Trained Models to New Tasks

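Often, you will encounter a new task that is related to, but not exactly the same as, a task you have already trained a model for (e.g., training a more fine-grained sentiment analysis model, or adding more classes to your multiclass model). If you want to leverage what the model has already learned from the old data to help it learn the new task faster and more effectively, AutoMM supports dumping your trained model into model weights and using them as a foundation model.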

dump_model_path = f"./tmp/{uuid.uuid4().hex}-automm_sst"
predictor.dump_model(save_path=dump_model_path)
Model weights and tokenizer for hf_text are saved to ./tmp/bace81e0abe84bfe8590e1085c468948-automm_sst/hf_text.
'./tmp/bace81e0abe84bfe8590e1085c468948-automm_sst'

You can then load the weights of the trained model and continue training or fine-tuning it on new data.

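Here is an example that reuses the binary text model trained above for a regression task. We use the Semantic Textual Similarity Benchmark dataset for illustration only, so you may want to apply this feature to more closely related datasets. In this data, the column named score contains numerical values (which we would like to predict) that are human-annotated similarity scores for each given pair of sentences.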

sts_train_data = load_pd.load("https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/train.parquet")[
    ["sentence1", "sentence2", "score"]
]
sts_test_data = load_pd.load("https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/dev.parquet")[
    ["sentence1", "sentence2", "score"]
]
sts_train_data.head(10)
Loaded data from: https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/train.parquet | Columns = 4 / 4 | Rows = 5749 -> 5749
Loaded data from: https://autogluon-text.s3-accelerate.amazonaws.com/glue/sts/dev.parquet | Columns = 4 / 4 | Rows = 1500 -> 1500
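   sentence1                                      sentence2                                           score
0  A plane is taking off.                         An air plane is taking off.                         5.00
1  A man is playing a large flute.                A man is playing a flute.                           3.80
2  A man is spreading shredded cheese on a pizza. A man is spreading shredded cheese on an uncoo...   3.80
3  Three men are playing chess.                   Two men are playing chess.                          2.60
4  A man is playing the cello.                    A man seated is playing the cello.                  4.25
5  Some men are fighting.                         Two men are fighting.                               4.25
6  A man is smoking.                              A man is skating.                                   0.50
7  The man is playing the piano.                  The man is playing the guitar.                      1.60
8  A man is playing on a guitar and singing.      A woman is playing an acoustic guitar and sing...   2.20
9  A person is throwing a cat on to the ceiling.  A person throws a cat on the ceiling.               5.00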

To specify the custom model you created, use the hyperparameters option in .fit().

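hyperparameters = {
    # dump_model() saved the text backbone under the hf_text subdirectory (see the message above)
    "model.hf_text.checkpoint_name": f"{dump_model_path}/hf_text"
}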
sts_model_path = f"./tmp/{uuid.uuid4().hex}-automm_sts"
predictor_sts = MultiModalPredictor(label="score", path=sts_model_path)
predictor_sts.fit(
    sts_train_data, hyperparameters={"model.hf_text.checkpoint_name": f"{dump_model_path}/hf_text"}, time_limit=30
)
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       24.28 GB / 30.95 GB (78.5%)
Disk Space Avail:   183.22 GB / 255.99 GB (71.6%)
===================================================
AutoGluon infers your prediction problem is: 'regression' (because dtype of label-column == float and label-values can't be converted to int).
	Label info (max, min, mean, stddev): (5.0, 0.0, 2.701, 1.4644)
	If 'regression' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/2076ef015df640deb9a50cdf5287d53b-automm_sts
    ```
Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                         | Params | Mode 
---------------------------------------------------------------------------
0 | model             | HFAutoModelForTextPrediction | 108 M  | train
1 | validation_metric | MeanSquaredError             | 0      | train
2 | loss_func         | MSELoss                      | 0      | train
---------------------------------------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
435.570   Total estimated model params size (MB)
229       Modules in train mode
0         Modules in eval mode
Epoch 0, global step 20: 'val_rmse' reached 0.64862 (best 0.64862), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/2076ef015df640deb9a50cdf5287d53b-automm_sts/epoch=0-step=20.ckpt' as top 3
Time limit reached. Elapsed time is 0:00:30. Signaling Trainer to stop.
Epoch 0, global step 31: 'val_rmse' reached 0.56736 (best 0.56736), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/2076ef015df640deb9a50cdf5287d53b-automm_sts/epoch=0-step=31.ckpt' as top 3
Start to fuse 2 checkpoints via the greedy soup algorithm.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/2076ef015df640deb9a50cdf5287d53b-automm_sts")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7ff2893670d0>
test_score = predictor_sts.evaluate(sts_test_data, metrics=["rmse", "pearsonr", "spearmanr"])
print("RMSE = {:.2f}".format(test_score["rmse"]))
print("PEARSONR = {:.4f}".format(test_score["pearsonr"]))
print("SPEARMANR = {:.4f}".format(test_score["spearmanr"]))
RMSE = 0.92
PEARSONR = 0.8050
SPEARMANR = 0.8045
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
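For readers less familiar with the three metrics reported above, here is a plain-Python sketch of their standard definitions. It is not AutoGluon's implementation, and the scores in `y_true` and `y_pred` are made-up values on the 0 to 5 STS scale, not taken from the dataset.

```python
import math


def rmse(y_true, y_pred):
    """Root mean squared error: typical size of the prediction error."""
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def pearson(x, y):
    """Pearson correlation: strength of the linear relationship."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def ranks(values):
    """Average ranks (1-based); tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1
        for k in range(i, j + 1):
            result[order[k]] = avg
        i = j + 1
    return result


def spearman(x, y):
    """Spearman correlation: Pearson correlation computed on the ranks."""
    return pearson(ranks(x), ranks(y))


# Made-up similarity scores for illustration only.
y_true = [5.0, 3.8, 2.6, 0.5]
y_pred = [4.0, 4.2, 1.6, 1.5]

print(f"RMSE = {rmse(y_true, y_pred):.2f}")
print(f"PEARSONR = {pearson(y_true, y_pred):.4f}")
print(f"SPEARMANR = {spearman(y_true, y_pred):.4f}")
```

RMSE is expressed in the units of the label, so it depends on the 0 to 5 scale, while the two correlation coefficients only measure how well the predictions track the ordering and linear trend of the true scores.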

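We currently support dumping timm image models, MMDetection image models, HuggingFace text models, and any fusion models that comprise them. Similarly, we can also load a custom trained timm image model with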

{"model.timm_image.checkpoint_name": timm_image_model_path}

and a custom trained MMDetection model with

{"model.mmdet_image.checkpoint_name": mmdet_image_model_path}
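The three checkpoint keys follow the same naming pattern, so they can be assembled programmatically. The sketch below builds such a dictionary from explicit paths; `checkpoint_overrides` is a hypothetical helper and the `./dumped/...` paths are placeholders, so substitute the locations your own dump_model() call reports.

```python
# Model prefixes named in this tutorial.
SUPPORTED_PREFIXES = ("hf_text", "timm_image", "mmdet_image")


def checkpoint_overrides(paths_by_prefix):
    """Hypothetical helper: build a hyperparameters dict from {prefix: path}."""
    overrides = {}
    for prefix, path in paths_by_prefix.items():
        if prefix not in SUPPORTED_PREFIXES:
            raise ValueError(f"unsupported model prefix: {prefix}")
        overrides[f"model.{prefix}.checkpoint_name"] = path
    return overrides


# Placeholder paths for illustration only.
example = checkpoint_overrides(
    {"hf_text": "./dumped/hf_text", "timm_image": "./dumped/timm_image"}
)
print(example)
```

The resulting dictionary can be passed as the hyperparameters argument of .fit(), in the same way as the single-key dictionary used for the text model above.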

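This feature helps you apply the knowledge your model learned from previous training to a new task, saving both time and computational resources. This tutorial does not go into the details, but keep in mind that we have not yet addressed a major challenge in this use case: catastrophic forgetting.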
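To make the term concrete, here is a deliberately tiny illustration of catastrophic forgetting. It is unrelated to AutoMM's internals: a one-parameter model y = w * x is fitted to a toy task A, then training continues on a conflicting toy task B only, and the error on task A is measured before and after. All names and data are assumptions made for the sketch.

```python
def mse(w, data):
    """Mean squared error of the model y = w * x on (x, y) pairs."""
    return sum((w * x - y) ** 2 for x, y in data) / len(data)


def train(w, data, lr=0.05, steps=200):
    """Plain gradient descent on the one-parameter model y = w * x."""
    for _ in range(steps):
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w


task_a = [(x, 2.0 * x) for x in (1.0, 2.0, 3.0)]   # old task: y = 2x
task_b = [(x, -1.0 * x) for x in (1.0, 2.0, 3.0)]  # new task: y = -x

w = train(0.0, task_a)
loss_a_before = mse(w, task_a)  # close to zero after fitting task A

w = train(w, task_b)            # continue training on the new task only
loss_a_after = mse(w, task_a)   # much larger: the old task has been overwritten

print(f"task A loss before: {loss_a_before:.4f}, after: {loss_a_after:.4f}")
```

Real models have far more parameters and related tasks rarely conflict this sharply, but the same effect is why it is worth re-evaluating on the original task's test data after fine-tuning on a new one.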