Single GPU Billion-scale Model Training via Parameter-Efficient Finetuning


As pointed out in a recent paper from the Stanford Institute for Human-Centered Artificial Intelligence, AI is undergoing a paradigm shift with the rise of "foundation models", i.e., giant models trained on a diverse collection of datasets, generally in a self-supervised way. These foundation models, which are the key of AutoMM, can be easily adapted to downstream applications. However, as these foundation models grow in size, finetuning them becomes increasingly difficult. The following figure from the Microsoft Research Blog demonstrates this trend:

Scaling of foundation models

The goal of AutoMM is to help anyone solve machine learning problems with open-source foundation models, including these giant models. To finetune such large-scale models, we adopt the recently popularized parameter-efficient finetuning techniques. The idea is to either finetune only a small subset of the weights in the foundation model (e.g., BitFit), or add a tiny tunable structure on top of the fixed backbone (e.g., Prompt Tuning, LoRA, Adapter, MAM Adapter, IA^3). These techniques effectively reduce peak memory usage and training time while maintaining performance. A minimal illustration of the idea follows.
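To make the idea concrete, here is a minimal, hypothetical PyTorch sketch of the "IA^3 + BitFit" recipe used later in this tutorial: freeze the backbone, keep only the existing bias terms trainable (BitFit), and add tiny learned scaling vectors on top of the frozen linear layers (an IA^3-style modification). This is a simplified illustration rather than AutoMM's actual implementation; in particular, IA^3 proper only rescales the key/value projections and feed-forward activations, whereas this sketch rescales every linear layer.

# Simplified, hypothetical sketch of parameter-efficient finetuning in plain PyTorch:
# BitFit (train only bias terms) + IA^3-style learned scaling vectors on frozen Linears.
import torch
import torch.nn as nn

class IA3ScaledLinear(nn.Module):
    """Wraps a frozen nn.Linear and rescales its output with a small learned vector."""
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear
        self.ia3_scale = nn.Parameter(torch.ones(linear.out_features))  # the only new weights

    def forward(self, x):
        return self.linear(x) * self.ia3_scale

def wrap_linears(module: nn.Module):
    # Recursively replace nn.Linear children with their IA^3-scaled wrappers.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, IA3ScaledLinear(child))
        else:
            wrap_linears(child)

def apply_ia3_bias(model: nn.Module) -> nn.Module:
    for p in model.parameters():           # freeze the whole backbone
        p.requires_grad = False
    for name, p in model.named_parameters():
        if name.endswith("bias"):           # BitFit: existing biases stay trainable
            p.requires_grad = True
    wrap_linears(model)                     # new IA^3-style scales are trainable by default
    return model

# Toy usage: only a tiny fraction of parameters remains trainable.
backbone = apply_ia3_bias(nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 2)))
trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
total = sum(p.numel() for p in backbone.parameters())
print(f"trainable: {trainable} / {total} ({trainable / total:.1%})")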

In this tutorial, we introduce how to apply parameter-efficient finetuning in MultiModalPredictor. We first introduce how to adopt the "ia3_bias" algorithm for parameter-efficient finetuning. Afterwards, we show how to simply combine "ia3_bias" with gradient checkpointing to finetune the XL variant of Google's FLAN-T5 on a single NVIDIA T4 GPU.

Prepare Dataset

The Cross-Lingual Amazon Product Review Sentiment dataset contains Amazon product reviews in four languages. Here, we load the English, German, and Japanese folds of the dataset. In the label column, 0 means negative sentiment and 1 means positive sentiment. For demonstration purposes, we downsample the training data to 1,000 samples. We will train the model on the English data and directly evaluate its performance on the German and Japanese test sets.

!wget --quiet https://automl-mm-bench.s3.amazonaws.com/multilingual-datasets/amazon_review_sentiment_cross_lingual.zip -O amazon_review_sentiment_cross_lingual.zip
!unzip -q -o amazon_review_sentiment_cross_lingual.zip -d .
import os
import shutil
os.environ["TRANSFORMERS_CACHE"] = "cache"

def clear_cache():
    if os.path.exists("cache"):
        shutil.rmtree("cache")

clear_cache()
import pandas as pd
import warnings
warnings.filterwarnings("ignore")

train_en_df = pd.read_csv("amazon_review_sentiment_cross_lingual/en_train.tsv",
                          sep="\t",
                          header=None,
                          names=["label", "text"]) \
                .sample(1000, random_state=123).reset_index(drop=True)

test_en_df = pd.read_csv("amazon_review_sentiment_cross_lingual/en_test.tsv",
                          sep="\t",
                          header=None,
                          names=["label", "text"]) \
               .sample(200, random_state=123).reset_index(drop=True)
test_de_df = pd.read_csv("amazon_review_sentiment_cross_lingual/de_test.tsv",
                          sep="\t", header=None, names=["label", "text"]) \
               .sample(200, random_state=123).reset_index(drop=True)

test_jp_df = pd.read_csv('amazon_review_sentiment_cross_lingual/jp_test.tsv',
                          sep='\t', header=None, names=['label', 'text']) \
               .sample(200, random_state=123).reset_index(drop=True)
train_en_df.head(5)
   label                                               text
0      0  This is a film that literally sees little wro...
1      0  This music is pretty intelligent, but not ver...
2      0  One of the best pieces of rock ever recorded,...
3      0  Reading the posted reviews here, is like revi...
4      1  I've just finished page 341, the last page. I...
test_jp_df.head(5)
   label                                               text
0 1 原作はビクトル・ユーゴの長編小説だが、私が子供の頃読んだのは短縮版の「ああ無情」。それでもこ...
1 1 ほかの作品のレビューにみんな書いているのに、何故この作品について書いている人が一人しかいない...
2 0 一番の問題点は青島が出ていない事でしょう。 TV番組では『芸人が出ていればバラエティだから...
3 0 昔、 りんたろう監督によるアニメ「カムイの剣」があった。 「カムイの剣」…を観た人なら本作...
4 1 以前のアルバムを聴いていないのでなんとも言えないが、クラシックなメタルを聞いてきた耳には、と...

Finetuning Multilingual Model with IA3 + BitFit

In AutoMM, to enable parameter-efficient finetuning, just specify optim.peft to be "ia3_bias".

from autogluon.multimodal import MultiModalPredictor
import uuid

model_path = f"./tmp/{uuid.uuid4().hex}-multilingual_ia3"
predictor = MultiModalPredictor(label="label",
                                path=model_path)
predictor.fit(train_en_df,
              presets="multilingual",
              hyperparameters={
                  "optim.peft": "ia3_bias",
                  "optim.lr_decay": 0.9,
                  "optim.lr": 3e-03,
                  "optim.end_lr": 3e-03,
                  "optim.max_epochs": 2,
                  "optim.warmup_steps": 0,
                  "env.batch_size": 32,
              })
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       28.40 GB / 30.95 GB (91.8%)
Disk Space Avail:   181.97 GB / 255.99 GB (71.1%)
===================================================
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
	2 unique label values:  [np.int64(0), np.int64(1)]
	If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/dd8a69b4dd8f41d58155c12aaf6ec843-multilingual_ia3
    ```
Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                         | Params | Mode 
---------------------------------------------------------------------------
0 | model             | HFAutoModelForTextPrediction | 278 M  | train
1 | validation_metric | BinaryAUROC                  | 0      | train
2 | loss_func         | CrossEntropyLoss             | 0      | train
---------------------------------------------------------------------------
122 K     Trainable params
278 M     Non-trainable params
278 M     Total params
1,112.955 Total estimated model params size (MB)
241       Modules in train mode
0         Modules in eval mode
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 7
      4 model_path = f"./tmp/{uuid.uuid4().hex}-multilingual_ia3"
      5 predictor = MultiModalPredictor(label="label",
      6                                 path=model_path)
----> 7 predictor.fit(train_en_df,
      8               presets="multilingual",
      9               hyperparameters={
     10                   "optim.peft": "ia3_bias",
     11                   "optim.lr_decay": 0.9,
     12                   "optim.lr": 3e-03,
     13                   "optim.end_lr": 3e-03,
     14                   "optim.max_epochs": 2,
     15                   "optim.warmup_steps": 0,
     16                   "env.batch_size": 32,
     17               })

File ~/autogluon/multimodal/src/autogluon/multimodal/predictor.py:540, in MultiModalPredictor.fit(self, train_data, presets, tuning_data, max_num_tuning_data, id_mappings, time_limit, save_path, hyperparameters, column_types, holdout_frac, teacher_predictor, seed, standalone, hyperparameter_tune_kwargs, clean_ckpts, predictions, labels, predictors)
    537     assert isinstance(predictors, list)
    538     learners = [ele if isinstance(ele, str) else ele._learner for ele in predictors]
--> 540 self._learner.fit(
    541     train_data=train_data,
    542     presets=presets,
    543     tuning_data=tuning_data,
    544     max_num_tuning_data=max_num_tuning_data,
    545     time_limit=time_limit,
    546     save_path=save_path,
    547     hyperparameters=hyperparameters,
    548     column_types=column_types,
    549     holdout_frac=holdout_frac,
    550     teacher_learner=teacher_learner,
    551     seed=seed,
    552     standalone=standalone,
    553     hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
    554     clean_ckpts=clean_ckpts,
    555     id_mappings=id_mappings,
    556     predictions=predictions,
    557     labels=labels,
    558     learners=learners,
    559 )
    561 return self

File ~/autogluon/multimodal/src/autogluon/multimodal/learners/base.py:665, in BaseLearner.fit(self, train_data, presets, tuning_data, time_limit, save_path, hyperparameters, column_types, holdout_frac, teacher_learner, seed, standalone, hyperparameter_tune_kwargs, clean_ckpts, **kwargs)
    658 self.fit_sanity_check()
    659 self.prepare_fit_args(
    660     time_limit=time_limit,
    661     seed=seed,
    662     standalone=standalone,
    663     clean_ckpts=clean_ckpts,
    664 )
--> 665 fit_returns = self.execute_fit()
    666 self.on_fit_end(
    667     training_start=training_start,
    668     strategy=fit_returns.get("strategy", None),
   (...)
    671     clean_ckpts=clean_ckpts,
    672 )
    674 return self

File ~/autogluon/multimodal/src/autogluon/multimodal/learners/base.py:577, in BaseLearner.execute_fit(self)
    575     return dict()
    576 else:
--> 577     attributes = self.fit_per_run(**self._fit_args)
    578     self.update_attributes(**attributes)  # only update attributes for non-HPO mode
    579     return attributes

File ~/autogluon/multimodal/src/autogluon/multimodal/learners/base.py:1358, in BaseLearner.fit_per_run(self, max_time, save_path, ckpt_path, resume, enable_progress_bar, seed, hyperparameters, advanced_hyperparameters, config, df_preprocessor, data_processors, model, standalone, clean_ckpts)
   1339 config = self.post_update_config_per_run(
   1340     config=config,
   1341     num_gpus=num_gpus,
   1342     precision=precision,
   1343     strategy=strategy,
   1344 )
   1345 trainer = self.init_trainer_per_run(
   1346     num_gpus=num_gpus,
   1347     config=config,
   (...)
   1355     enable_progress_bar=enable_progress_bar,
   1356 )
-> 1358 self.run_trainer(
   1359     trainer=trainer,
   1360     litmodule=litmodule,
   1361     datamodule=datamodule,
   1362     ckpt_path=ckpt_path,
   1363     resume=resume,
   1364 )
   1365 self.on_fit_per_run_end(
   1366     save_path=save_path,
   1367     standalone=standalone,
   (...)
   1372     model=model,
   1373 )
   1375 best_score = (
   1376     trainer.callback_metrics[f"val_{self._validation_metric_name}"].item()
   1377     if f"val_{self._validation_metric_name}" in trainer.callback_metrics
   1378     else self._best_score
   1379 )  # https://github.com/autogluon/autogluon/issues/4428

File ~/autogluon/multimodal/src/autogluon/multimodal/learners/base.py:1211, in BaseLearner.run_trainer(self, trainer, litmodule, datamodule, ckpt_path, resume, pred_writer, is_train)
   1209     warnings.filterwarnings("ignore", filter)
   1210 if is_train:
-> 1211     trainer.fit(
   1212         litmodule,
   1213         datamodule=datamodule,
   1214         ckpt_path=ckpt_path if resume else None,  # this is to resume training that was broken accidentally
   1215     )
   1216 else:
   1217     blacklist_msgs = []

File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:561, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    559 self.training = True
    560 self.should_stop = False
--> 561 call._call_and_handle_interrupt(
    562     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    563 )

File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:48, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     46     if trainer.strategy.launcher is not None:
     47         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 48     return trainer_fn(*args, **kwargs)
     50 except _TunerExitException:
     51     _call_teardown_hook(trainer)

File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:599, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    592     download_model_from_registry(ckpt_path, self)
    593 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    594     self.state.fn,
    595     ckpt_path,
    596     model_provided=True,
    597     model_connected=self.lightning_module is not None,
    598 )
--> 599 self._run(model, ckpt_path=ckpt_path)
    601 assert self.state.stopped
    602 self.training = False

File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:1012, in Trainer._run(self, model, ckpt_path)
   1007 self._signal_connector.register_signal_handlers()
   1009 # ----------------------------
   1010 # RUN THE TRAINER
   1011 # ----------------------------
-> 1012 results = self._run_stage()
   1014 # ----------------------------
   1015 # POST-Training CLEAN UP
   1016 # ----------------------------
   1017 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:1054, in Trainer._run_stage(self)
   1052 if self.training:
   1053     with isolate_rng():
-> 1054         self._run_sanity_check()
   1055     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
   1056         self.fit_loop.run()

File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:1083, in Trainer._run_sanity_check(self)
   1080 call._call_callback_hooks(self, "on_sanity_check_start")
   1082 # run eval step
-> 1083 val_loop.run()
   1085 call._call_callback_hooks(self, "on_sanity_check_end")
   1087 # reset logger connector

File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/loops/utilities.py:179, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
    177     context_manager = torch.no_grad
    178 with context_manager():
--> 179     return loop_run(self, *args, **kwargs)

File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py:145, in _EvaluationLoop.run(self)
    143     self.batch_progress.is_last_batch = data_fetcher.done
    144     # run step hooks
--> 145     self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
    146 except StopIteration:
    147     # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
    148     break

File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py:437, in _EvaluationLoop._evaluation_step(self, batch, batch_idx, dataloader_idx, dataloader_iter)
    431 hook_name = "test_step" if trainer.testing else "validation_step"
    432 step_args = (
    433     self._build_step_args_from_hook_kwargs(hook_kwargs, hook_name)
    434     if not using_dataloader_iter
    435     else (dataloader_iter,)
    436 )
--> 437 output = call._call_strategy_hook(trainer, hook_name, *step_args)
    439 self.batch_progress.increment_processed()
    441 if using_dataloader_iter:
    442     # update the hook kwargs now that the step method might have consumed the iterator

File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:328, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    325     return None
    327 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 328     output = fn(*args, **kwargs)
    330 # restore current_fx when nested context
    331 pl_module._current_fx_name = prev_fx_name

File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py:412, in Strategy.validation_step(self, *args, **kwargs)
    410 if self.model != self.lightning_module:
    411     return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
--> 412 return self.lightning_module.validation_step(*args, **kwargs)

File ~/autogluon/multimodal/src/autogluon/multimodal/optim/lit_module.py:381, in LitModule.validation_step(self, batch, batch_idx)
    365 def validation_step(self, batch, batch_idx):
    366     """
    367     Per validation step. This function is registered by LightningModule.
    368     Refer to https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#validation-loop
   (...)
    379         Index of mini-batch.
    380     """
--> 381     output, loss = self._shared_step(batch)
    382     if self.model_postprocess_fn:
    383         output = self.model_postprocess_fn(output)

File ~/autogluon/multimodal/src/autogluon/multimodal/optim/lit_module.py:305, in LitModule._shared_step(self, batch)
    303     self.mixup_fn.mixup_enabled = self.training & (self.current_epoch < self.hparams.mixup_off_epoch)
    304     batch, label = multimodel_mixup(batch=batch, model=self.model, mixup_fn=self.mixup_fn)
--> 305 output = run_model(self.model, batch)
    306 loss = self._compute_loss(output=output, label=label)
    307 return output, loss

File ~/autogluon/multimodal/src/autogluon/multimodal/models/utils.py:865, in run_model(model, batch, trt_model)
    863         output_vec = pure_model(*tuple(input_vec))
    864     else:
--> 865         output_vec = model(*tuple(input_vec))
    867     output = pure_model.get_output_dict(*output_vec)
    868 else:

File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File ~/autogluon/multimodal/src/autogluon/multimodal/models/hf_text.py:230, in HFAutoModelForTextPrediction.forward(self, text_token_ids, text_segment_ids, text_valid_length, text_column_names, text_column_indices)
    228 else:
    229     if "token_type_ids" in self.tokenizer.model_input_names:
--> 230         outputs = self.model(
    231             input_ids=text_token_ids,
    232             token_type_ids=text_segment_ids,
    233             attention_mask=text_masks,
    234         )
    235     else:
    236         outputs = self.model(
    237             input_ids=text_token_ids,
    238             attention_mask=text_masks,
    239         )

File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File ~/opt/venv/lib/python3.11/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:870, in DebertaV2Model.forward(self, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds, output_attentions, output_hidden_states, return_dict)
    860     token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
    862 embedding_output = self.embeddings(
    863     input_ids=input_ids,
    864     token_type_ids=token_type_ids,
   (...)
    867     inputs_embeds=inputs_embeds,
    868 )
--> 870 encoder_outputs = self.encoder(
    871     embedding_output,
    872     attention_mask,
    873     output_hidden_states=True,
    874     output_attentions=output_attentions,
    875     return_dict=return_dict,
    876 )
    877 encoded_layers = encoder_outputs[1]
    879 if self.z_steps > 1:

File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File ~/opt/venv/lib/python3.11/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:674, in DebertaV2Encoder.forward(self, hidden_states, attention_mask, output_hidden_states, output_attentions, query_states, relative_pos, return_dict)
    664     output_states, attn_weights = self._gradient_checkpointing_func(
    665         layer_module.__call__,
    666         next_kv,
   (...)
    671         output_attentions,
    672     )
    673 else:
--> 674     output_states, attn_weights = layer_module(
    675         next_kv,
    676         attention_mask,
    677         query_states=query_states,
    678         relative_pos=relative_pos,
    679         rel_embeddings=rel_embeddings,
    680         output_attentions=output_attentions,
    681     )
    683 if output_attentions:
    684     all_attentions = all_attentions + (attn_weights,)

File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File ~/opt/venv/lib/python3.11/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:442, in DebertaV2Layer.forward(self, hidden_states, attention_mask, query_states, relative_pos, rel_embeddings, output_attentions)
    433 def forward(
    434     self,
    435     hidden_states,
   (...)
    440     output_attentions: bool = False,
    441 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
--> 442     attention_output, att_matrix = self.attention(
    443         hidden_states,
    444         attention_mask,
    445         output_attentions=output_attentions,
    446         query_states=query_states,
    447         relative_pos=relative_pos,
    448         rel_embeddings=rel_embeddings,
    449     )
    450     intermediate_output = self.intermediate(attention_output)
    451     layer_output = self.output(intermediate_output, attention_output)

File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File ~/opt/venv/lib/python3.11/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:375, in DebertaV2Attention.forward(self, hidden_states, attention_mask, output_attentions, query_states, relative_pos, rel_embeddings)
    366 def forward(
    367     self,
    368     hidden_states,
   (...)
    373     rel_embeddings=None,
    374 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
--> 375     self_output, att_matrix = self.self(
    376         hidden_states,
    377         attention_mask,
    378         output_attentions,
    379         query_states=query_states,
    380         relative_pos=relative_pos,
    381         rel_embeddings=rel_embeddings,
    382     )
    383     if query_states is None:
    384         query_states = hidden_states

File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File ~/opt/venv/lib/python3.11/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:267, in DisentangledSelfAttention.forward(self, hidden_states, attention_mask, output_attentions, query_states, relative_pos, rel_embeddings)
    262 attention_scores = attention_scores.view(
    263     -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
    264 )
    266 attention_mask = attention_mask.bool()
--> 267 attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
    268 # bsz x height x length x dimension
    269 attention_probs = nn.functional.softmax(attention_scores, dim=-1)

RuntimeError: value cannot be converted to type at::BFloat16 without overflow

The fraction of tunable parameters is around 0.5% of all parameters. In fact, the model trained purely on English data can achieve decent performance on the test sets, even on the German and Japanese ones. It obtains results comparable to full finetuning as shown in AutoMM for Text - Multilingual Problems.

score_in_en = predictor.evaluate(test_en_df)
score_in_de = predictor.evaluate(test_de_df)
score_in_jp = predictor.evaluate(test_jp_df)
print('Score in the English Testset:', score_in_en)
print('Score in the German Testset:', score_in_de)
print('Score in the Japanese Testset:', score_in_jp)
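The tunable-parameter fraction mentioned above (also visible in the Lightning model summary printed during training) can be double-checked by counting trainable parameters on the underlying torch module. The sketch below is hypothetical: the attribute exposing MultiModalPredictor's internal model (predictor._model here) is an internal detail and may differ across AutoGluon versions.

# Hypothetical check of the tunable-parameter fraction.
# NOTE: `predictor._model` is an assumption about AutoMM internals and may not
# exist in your AutoGluon version; any torch.nn.Module works with this helper.
import torch.nn as nn

def trainable_fraction(module: nn.Module) -> float:
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in module.parameters())
    return trainable / total

# print(f"Tunable fraction: {trainable_fraction(predictor._model):.2%}")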

Training FLAN-T5-XL on a Single GPU

By combining gradient checkpointing and parameter-efficient finetuning, it is feasible to finetune google/flan-t5-xl, which has close to two billion parameters, with a single T4 GPU available in AWS G4 instances. To turn on gradient checkpointing, simply set "model.hf_text.gradient_checkpointing" to True. To speed up training, we downsample the number of training samples to 200.
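For context, gradient checkpointing trades extra compute for memory: intermediate activations are recomputed during the backward pass instead of being kept around. The short sketch below shows, in plain Hugging Face transformers code and purely for illustration, how the same mechanism is typically switched on outside of AutoMM; in the fit call that follows, AutoMM handles this for you via the "model.hf_text.gradient_checkpointing" flag.

# Illustrative sketch (outside of AutoMM) of the memory-saving tricks used below:
# load the checkpoint without materializing a redundant full copy in CPU RAM,
# then enable gradient checkpointing so activations are recomputed in backward
# instead of being stored.
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xl", low_cpu_mem_usage=True)
model.gradient_checkpointing_enable()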

# Just to free up disk space
clear_cache()
shutil.rmtree(model_path)
from autogluon.multimodal import MultiModalPredictor

train_en_df_downsample = train_en_df.sample(200, random_state=123)

new_model_path = f"./tmp/{uuid.uuid4().hex}-multilingual_ia3_gradient_checkpoint"
predictor = MultiModalPredictor(label="label",
                                path=new_model_path)
predictor.fit(train_en_df_downsample,
              presets="multilingual",
              hyperparameters={
                  "model.hf_text.checkpoint_name": "google/flan-t5-xl",
                  "model.hf_text.gradient_checkpointing": True,
                  "model.hf_text.low_cpu_mem_usage": True,
                  "optim.peft": "ia3_bias",
                  "optim.lr_decay": 0.9,
                  "optim.lr": 3e-03,
                  "optim.end_lr": 3e-03,
                  "optim.max_epochs": 1,
                  "optim.warmup_steps": 0,
                  "env.batch_size": 1,
                  "env.inference_batch_size_ratio": 1
              })

Global seed set to 123
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                         | Params
-------------------------------------------------------------------
0 | model             | HFAutoModelForTextPrediction | 1.2 B 
1 | validation_metric | AUROC                        | 0     
2 | loss_func         | CrossEntropyLoss             | 0     
-------------------------------------------------------------------
203 K     Trainable params
1.2 B     Non-trainable params
1.2 B     Total params
4,894.913 Total estimated model params size (MB)
Epoch 0, global step 20: 'val_roc_auc' reached 0.88802 (best 0.88802), saving model to '/home/ubuntu/autogluon/docs/tutorials/multimodal/advanced_topics/multilingual_ia3_gradient_checkpoint/epoch=0-step=20.ckpt' as top 1
Epoch 0, global step 40: 'val_roc_auc' reached 0.94531 (best 0.94531), saving model to '/home/ubuntu/autogluon/docs/tutorials/multimodal/advanced_topics/multilingual_ia3_gradient_checkpoint/epoch=0-step=40.ckpt' as top 1
`Trainer.fit` stopped: `max_epochs=1` reached.





<autogluon.multimodal.predictor.MultiModalPredictor at 0x7fd58c4dbca0>
score_in_en = predictor.evaluate(test_en_df)
print('Score in the English Testset:', score_in_en)
Score in the English Testset: {'roc_auc': 0.931263189629183}
# Just to free up disk space
clear_cache()
shutil.rmtree(new_model_path)

Other Examples

You may go to AutoMM Examples to explore other examples about AutoMM.

Customization

To learn how to customize AutoMM, refer to Customize AutoMM.