Single GPU Billion-scale Model Training via Parameter-Efficient Finetuning¶
As pointed out by a recent paper from the Stanford Institute for Human-Centered Artificial Intelligence, AI is undergoing a paradigm shift with the rise of "foundation models", i.e., giant models trained on a diverse collection of datasets, usually in a self-supervised way. These foundation models are the key to AutoMM and can be easily adapted to downstream tasks. However, as these foundation models grow in size, finetuning them becomes increasingly difficult. Below is a figure from the Microsoft Research blog that demonstrates this trend.
The goal of AutoMM is to help anyone solve machine learning problems via open source foundation models, including these giant models. To finetune such large-scale models, we adopt the recently popularized parameter-efficient finetuning techniques. The idea is to either finetune a small subset of the weights in the foundation model (e.g., BitFit), or add a tiny tunable structure on top of the fixed backbone (e.g., Prompt Tuning, LoRA, Adapter, MAM Adapter, IA^3). These techniques can effectively reduce peak memory usage and model training time, while maintaining performance.
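To make the idea concrete, below is a minimal PyTorch sketch (not AutoMM's internal implementation) of what "ia3_bias" stands for: freeze the backbone, keep only the bias terms trainable (BitFit), and add tiny learned scaling vectors on top of frozen layers (IA^3). The `IA3Linear` wrapper and the toy `backbone` are illustrative only.

```python
import torch
import torch.nn as nn

class IA3Linear(nn.Module):
    """Wrap a frozen linear layer with a learned per-feature scaling vector (IA^3)."""
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear
        self.scale = nn.Parameter(torch.ones(linear.out_features))  # the only new weights

    def forward(self, x):
        return self.linear(x) * self.scale  # rescale the frozen layer's output

# Toy backbone standing in for a pretrained transformer.
backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))

# BitFit: freeze everything, then re-enable gradients only for the bias terms.
for name, param in backbone.named_parameters():
    param.requires_grad = name.endswith("bias")

# IA^3: add a tiny learned scaling vector on top of a frozen layer.
backbone[0] = IA3Linear(backbone[0])

trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
total = sum(p.numel() for p in backbone.parameters())
print(f"trainable parameters: {trainable} / {total}")
```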
In this tutorial, we introduce how to apply parameter-efficient finetuning in MultiModalPredictor. We first introduce how to adopt the "ia3_bias" algorithm for parameter-efficient finetuning. Then, we show how to simply combine "ia3_bias" with gradient checkpointing to finetune the XL variant of Google's FLAN-T5 with a single NVIDIA T4 GPU.
Prepare Dataset¶
The Cross-Lingual Amazon Product Review Sentiment dataset contains Amazon product reviews in four languages. Here, we load the English and German folds of the dataset. In the label column, 0 means negative sentiment and 1 means positive sentiment. For demonstration purposes, we downsample the training data to 1000 samples. We will train the model on the English dataset and directly evaluate its performance on the German and Japanese test sets.
!wget --quiet https://automl-mm-bench.s3.amazonaws.com/multilingual-datasets/amazon_review_sentiment_cross_lingual.zip -O amazon_review_sentiment_cross_lingual.zip
!unzip -q -o amazon_review_sentiment_cross_lingual.zip -d .
import os
import shutil
os.environ["TRANSFORMERS_CACHE"] = "cache"
def clear_cache():
if os.path.exists("cache"):
shutil.rmtree("cache")
clear_cache()
import pandas as pd
import warnings
warnings.filterwarnings("ignore")
train_en_df = pd.read_csv("amazon_review_sentiment_cross_lingual/en_train.tsv",
sep="\t",
header=None,
names=["label", "text"]) \
.sample(1000, random_state=123).reset_index(drop=True)
test_en_df = pd.read_csv("amazon_review_sentiment_cross_lingual/en_test.tsv",
sep="\t",
header=None,
names=["label", "text"]) \
.sample(200, random_state=123).reset_index(drop=True)
test_de_df = pd.read_csv("amazon_review_sentiment_cross_lingual/de_test.tsv",
sep="\t", header=None, names=["label", "text"]) \
.sample(200, random_state=123).reset_index(drop=True)
test_jp_df = pd.read_csv('amazon_review_sentiment_cross_lingual/jp_test.tsv',
sep='\t', header=None, names=['label', 'text']) \
.sample(200, random_state=123).reset_index(drop=True)
train_en_df.head(5)
|   | label | text |
|---|---|---|
| 0 | 0 | This is a film that literally has almost nothing wron... |
| 1 | 0 | This music is pretty intelligent, but not very... |
| 2 | 0 | One of the best pieces of rock ever recorded... |
| 3 | 0 | Reading the reviews posted here is like revisiting... |
| 4 | 1 | I've just finished page 341, the last page. It... |
test_jp_df.head(5)
|   | label | text |
|---|---|---|
| 0 | 1 | 原作はビクトル・ユーゴの長編小説だが、私が子供の頃読んだのは短縮版の「ああ無情」。それでもこ... |
| 1 | 1 | ほかの作品のレビューにみんな書いているのに、何故この作品について書いている人が一人しかいない... |
| 2 | 0 | 一番の問題点は青島が出ていない事でしょう。 TV番組では『芸人が出ていればバラエティだから... |
| 3 | 0 | 昔、 りんたろう監督によるアニメ「カムイの剣」があった。 「カムイの剣」…を観た人なら本作... |
| 4 | 1 | 以前のアルバムを聴いていないのでなんとも言えないが、クラシックなメタルを聞いてきた耳には、と... |
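Optionally, you can quickly sanity-check the size and label balance of each split with plain pandas; a small sketch using the dataframes loaded above:

```python
# Quick sanity check on the splits defined above: sizes and label balance.
for name, df in [("en_train", train_en_df), ("en_test", test_en_df),
                 ("de_test", test_de_df), ("jp_test", test_jp_df)]:
    print(name, df.shape, df["label"].value_counts().to_dict())
```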
Finetuning Multilingual Model with IA3 + BitFit¶
In AutoMM, to enable efficient finetuning, just specify optim.peft to be "ia3_bias".
from autogluon.multimodal import MultiModalPredictor
import uuid
model_path = f"./tmp/{uuid.uuid4().hex}-multilingual_ia3"
predictor = MultiModalPredictor(label="label",
path=model_path)
predictor.fit(train_en_df,
presets="multilingual",
hyperparameters={
"optim.peft": "ia3_bias",
"optim.lr_decay": 0.9,
"optim.lr": 3e-03,
"optim.end_lr": 3e-03,
"optim.max_epochs": 2,
"optim.warmup_steps": 0,
"env.batch_size": 32,
})
=================== System Info ===================
AutoGluon Version: 1.3.1b20250508
Python Version: 3.11.9
Operating System: Linux
Platform Machine: x86_64
Platform Version: #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count: 8
Pytorch Version: 2.6.0+cu124
CUDA Version: 12.4
Memory Avail: 28.40 GB / 30.95 GB (91.8%)
Disk Space Avail: 181.97 GB / 255.99 GB (71.1%)
===================================================
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
2 unique label values: [np.int64(0), np.int64(1)]
If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
AutoMM starts to create your model. ✨✨✨
To track the learning progress, you can open a terminal and launch Tensorboard:
```shell
# Assume you have installed tensorboard
tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/tmp/dd8a69b4dd8f41d58155c12aaf6ec843-multilingual_ia3
```
Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params | Mode
---------------------------------------------------------------------------
0 | model | HFAutoModelForTextPrediction | 278 M | train
1 | validation_metric | BinaryAUROC | 0 | train
2 | loss_func | CrossEntropyLoss | 0 | train
---------------------------------------------------------------------------
122 K Trainable params
278 M Non-trainable params
278 M Total params
1,112.955 Total estimated model params size (MB)
241 Modules in train mode
0 Modules in eval mode
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[6], line 7
4 model_path = f"./tmp/{uuid.uuid4().hex}-multilingual_ia3"
5 predictor = MultiModalPredictor(label="label",
6 path=model_path)
----> 7 predictor.fit(train_en_df,
8 presets="multilingual",
9 hyperparameters={
10 "optim.peft": "ia3_bias",
11 "optim.lr_decay": 0.9,
12 "optim.lr": 3e-03,
13 "optim.end_lr": 3e-03,
14 "optim.max_epochs": 2,
15 "optim.warmup_steps": 0,
16 "env.batch_size": 32,
17 })
File ~/autogluon/multimodal/src/autogluon/multimodal/predictor.py:540, in MultiModalPredictor.fit(self, train_data, presets, tuning_data, max_num_tuning_data, id_mappings, time_limit, save_path, hyperparameters, column_types, holdout_frac, teacher_predictor, seed, standalone, hyperparameter_tune_kwargs, clean_ckpts, predictions, labels, predictors)
537 assert isinstance(predictors, list)
538 learners = [ele if isinstance(ele, str) else ele._learner for ele in predictors]
--> 540 self._learner.fit(
541 train_data=train_data,
542 presets=presets,
543 tuning_data=tuning_data,
544 max_num_tuning_data=max_num_tuning_data,
545 time_limit=time_limit,
546 save_path=save_path,
547 hyperparameters=hyperparameters,
548 column_types=column_types,
549 holdout_frac=holdout_frac,
550 teacher_learner=teacher_learner,
551 seed=seed,
552 standalone=standalone,
553 hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
554 clean_ckpts=clean_ckpts,
555 id_mappings=id_mappings,
556 predictions=predictions,
557 labels=labels,
558 learners=learners,
559 )
561 return self
File ~/autogluon/multimodal/src/autogluon/multimodal/learners/base.py:665, in BaseLearner.fit(self, train_data, presets, tuning_data, time_limit, save_path, hyperparameters, column_types, holdout_frac, teacher_learner, seed, standalone, hyperparameter_tune_kwargs, clean_ckpts, **kwargs)
658 self.fit_sanity_check()
659 self.prepare_fit_args(
660 time_limit=time_limit,
661 seed=seed,
662 standalone=standalone,
663 clean_ckpts=clean_ckpts,
664 )
--> 665 fit_returns = self.execute_fit()
666 self.on_fit_end(
667 training_start=training_start,
668 strategy=fit_returns.get("strategy", None),
(...)
671 clean_ckpts=clean_ckpts,
672 )
674 return self
File ~/autogluon/multimodal/src/autogluon/multimodal/learners/base.py:577, in BaseLearner.execute_fit(self)
575 return dict()
576 else:
--> 577 attributes = self.fit_per_run(**self._fit_args)
578 self.update_attributes(**attributes) # only update attributes for non-HPO mode
579 return attributes
File ~/autogluon/multimodal/src/autogluon/multimodal/learners/base.py:1358, in BaseLearner.fit_per_run(self, max_time, save_path, ckpt_path, resume, enable_progress_bar, seed, hyperparameters, advanced_hyperparameters, config, df_preprocessor, data_processors, model, standalone, clean_ckpts)
1339 config = self.post_update_config_per_run(
1340 config=config,
1341 num_gpus=num_gpus,
1342 precision=precision,
1343 strategy=strategy,
1344 )
1345 trainer = self.init_trainer_per_run(
1346 num_gpus=num_gpus,
1347 config=config,
(...)
1355 enable_progress_bar=enable_progress_bar,
1356 )
-> 1358 self.run_trainer(
1359 trainer=trainer,
1360 litmodule=litmodule,
1361 datamodule=datamodule,
1362 ckpt_path=ckpt_path,
1363 resume=resume,
1364 )
1365 self.on_fit_per_run_end(
1366 save_path=save_path,
1367 standalone=standalone,
(...)
1372 model=model,
1373 )
1375 best_score = (
1376 trainer.callback_metrics[f"val_{self._validation_metric_name}"].item()
1377 if f"val_{self._validation_metric_name}" in trainer.callback_metrics
1378 else self._best_score
1379 ) # https://github.com/autogluon/autogluon/issues/4428
File ~/autogluon/multimodal/src/autogluon/multimodal/learners/base.py:1211, in BaseLearner.run_trainer(self, trainer, litmodule, datamodule, ckpt_path, resume, pred_writer, is_train)
1209 warnings.filterwarnings("ignore", filter)
1210 if is_train:
-> 1211 trainer.fit(
1212 litmodule,
1213 datamodule=datamodule,
1214 ckpt_path=ckpt_path if resume else None, # this is to resume training that was broken accidentally
1215 )
1216 else:
1217 blacklist_msgs = []
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:561, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
559 self.training = True
560 self.should_stop = False
--> 561 call._call_and_handle_interrupt(
562 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
563 )
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:48, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
46 if trainer.strategy.launcher is not None:
47 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 48 return trainer_fn(*args, **kwargs)
50 except _TunerExitException:
51 _call_teardown_hook(trainer)
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:599, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
592 download_model_from_registry(ckpt_path, self)
593 ckpt_path = self._checkpoint_connector._select_ckpt_path(
594 self.state.fn,
595 ckpt_path,
596 model_provided=True,
597 model_connected=self.lightning_module is not None,
598 )
--> 599 self._run(model, ckpt_path=ckpt_path)
601 assert self.state.stopped
602 self.training = False
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:1012, in Trainer._run(self, model, ckpt_path)
1007 self._signal_connector.register_signal_handlers()
1009 # ----------------------------
1010 # RUN THE TRAINER
1011 # ----------------------------
-> 1012 results = self._run_stage()
1014 # ----------------------------
1015 # POST-Training CLEAN UP
1016 # ----------------------------
1017 log.debug(f"{self.__class__.__name__}: trainer tearing down")
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:1054, in Trainer._run_stage(self)
1052 if self.training:
1053 with isolate_rng():
-> 1054 self._run_sanity_check()
1055 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
1056 self.fit_loop.run()
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:1083, in Trainer._run_sanity_check(self)
1080 call._call_callback_hooks(self, "on_sanity_check_start")
1082 # run eval step
-> 1083 val_loop.run()
1085 call._call_callback_hooks(self, "on_sanity_check_end")
1087 # reset logger connector
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/loops/utilities.py:179, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
177 context_manager = torch.no_grad
178 with context_manager():
--> 179 return loop_run(self, *args, **kwargs)
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py:145, in _EvaluationLoop.run(self)
143 self.batch_progress.is_last_batch = data_fetcher.done
144 # run step hooks
--> 145 self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
146 except StopIteration:
147 # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
148 break
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py:437, in _EvaluationLoop._evaluation_step(self, batch, batch_idx, dataloader_idx, dataloader_iter)
431 hook_name = "test_step" if trainer.testing else "validation_step"
432 step_args = (
433 self._build_step_args_from_hook_kwargs(hook_kwargs, hook_name)
434 if not using_dataloader_iter
435 else (dataloader_iter,)
436 )
--> 437 output = call._call_strategy_hook(trainer, hook_name, *step_args)
439 self.batch_progress.increment_processed()
441 if using_dataloader_iter:
442 # update the hook kwargs now that the step method might have consumed the iterator
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:328, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
325 return None
327 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 328 output = fn(*args, **kwargs)
330 # restore current_fx when nested context
331 pl_module._current_fx_name = prev_fx_name
File ~/opt/venv/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py:412, in Strategy.validation_step(self, *args, **kwargs)
410 if self.model != self.lightning_module:
411 return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
--> 412 return self.lightning_module.validation_step(*args, **kwargs)
File ~/autogluon/multimodal/src/autogluon/multimodal/optim/lit_module.py:381, in LitModule.validation_step(self, batch, batch_idx)
365 def validation_step(self, batch, batch_idx):
366 """
367 Per validation step. This function is registered by LightningModule.
368 Refer to https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#validation-loop
(...)
379 Index of mini-batch.
380 """
--> 381 output, loss = self._shared_step(batch)
382 if self.model_postprocess_fn:
383 output = self.model_postprocess_fn(output)
File ~/autogluon/multimodal/src/autogluon/multimodal/optim/lit_module.py:305, in LitModule._shared_step(self, batch)
303 self.mixup_fn.mixup_enabled = self.training & (self.current_epoch < self.hparams.mixup_off_epoch)
304 batch, label = multimodel_mixup(batch=batch, model=self.model, mixup_fn=self.mixup_fn)
--> 305 output = run_model(self.model, batch)
306 loss = self._compute_loss(output=output, label=label)
307 return output, loss
File ~/autogluon/multimodal/src/autogluon/multimodal/models/utils.py:865, in run_model(model, batch, trt_model)
863 output_vec = pure_model(*tuple(input_vec))
864 else:
--> 865 output_vec = model(*tuple(input_vec))
867 output = pure_model.get_output_dict(*output_vec)
868 else:
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
1745 # If we don't have any hooks, we want to skip the rest of the logic in
1746 # this function, and just call forward.
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1752 result = None
1753 called_always_called_hooks = set()
File ~/autogluon/multimodal/src/autogluon/multimodal/models/hf_text.py:230, in HFAutoModelForTextPrediction.forward(self, text_token_ids, text_segment_ids, text_valid_length, text_column_names, text_column_indices)
228 else:
229 if "token_type_ids" in self.tokenizer.model_input_names:
--> 230 outputs = self.model(
231 input_ids=text_token_ids,
232 token_type_ids=text_segment_ids,
233 attention_mask=text_masks,
234 )
235 else:
236 outputs = self.model(
237 input_ids=text_token_ids,
238 attention_mask=text_masks,
239 )
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
1745 # If we don't have any hooks, we want to skip the rest of the logic in
1746 # this function, and just call forward.
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1752 result = None
1753 called_always_called_hooks = set()
File ~/opt/venv/lib/python3.11/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:870, in DebertaV2Model.forward(self, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds, output_attentions, output_hidden_states, return_dict)
860 token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
862 embedding_output = self.embeddings(
863 input_ids=input_ids,
864 token_type_ids=token_type_ids,
(...)
867 inputs_embeds=inputs_embeds,
868 )
--> 870 encoder_outputs = self.encoder(
871 embedding_output,
872 attention_mask,
873 output_hidden_states=True,
874 output_attentions=output_attentions,
875 return_dict=return_dict,
876 )
877 encoded_layers = encoder_outputs[1]
879 if self.z_steps > 1:
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
1745 # If we don't have any hooks, we want to skip the rest of the logic in
1746 # this function, and just call forward.
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1752 result = None
1753 called_always_called_hooks = set()
File ~/opt/venv/lib/python3.11/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:674, in DebertaV2Encoder.forward(self, hidden_states, attention_mask, output_hidden_states, output_attentions, query_states, relative_pos, return_dict)
664 output_states, attn_weights = self._gradient_checkpointing_func(
665 layer_module.__call__,
666 next_kv,
(...)
671 output_attentions,
672 )
673 else:
--> 674 output_states, attn_weights = layer_module(
675 next_kv,
676 attention_mask,
677 query_states=query_states,
678 relative_pos=relative_pos,
679 rel_embeddings=rel_embeddings,
680 output_attentions=output_attentions,
681 )
683 if output_attentions:
684 all_attentions = all_attentions + (attn_weights,)
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
1745 # If we don't have any hooks, we want to skip the rest of the logic in
1746 # this function, and just call forward.
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1752 result = None
1753 called_always_called_hooks = set()
File ~/opt/venv/lib/python3.11/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:442, in DebertaV2Layer.forward(self, hidden_states, attention_mask, query_states, relative_pos, rel_embeddings, output_attentions)
433 def forward(
434 self,
435 hidden_states,
(...)
440 output_attentions: bool = False,
441 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
--> 442 attention_output, att_matrix = self.attention(
443 hidden_states,
444 attention_mask,
445 output_attentions=output_attentions,
446 query_states=query_states,
447 relative_pos=relative_pos,
448 rel_embeddings=rel_embeddings,
449 )
450 intermediate_output = self.intermediate(attention_output)
451 layer_output = self.output(intermediate_output, attention_output)
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
1745 # If we don't have any hooks, we want to skip the rest of the logic in
1746 # this function, and just call forward.
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1752 result = None
1753 called_always_called_hooks = set()
File ~/opt/venv/lib/python3.11/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:375, in DebertaV2Attention.forward(self, hidden_states, attention_mask, output_attentions, query_states, relative_pos, rel_embeddings)
366 def forward(
367 self,
368 hidden_states,
(...)
373 rel_embeddings=None,
374 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
--> 375 self_output, att_matrix = self.self(
376 hidden_states,
377 attention_mask,
378 output_attentions,
379 query_states=query_states,
380 relative_pos=relative_pos,
381 rel_embeddings=rel_embeddings,
382 )
383 if query_states is None:
384 query_states = hidden_states
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
File ~/opt/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
1745 # If we don't have any hooks, we want to skip the rest of the logic in
1746 # this function, and just call forward.
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1752 result = None
1753 called_always_called_hooks = set()
File ~/opt/venv/lib/python3.11/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:267, in DisentangledSelfAttention.forward(self, hidden_states, attention_mask, output_attentions, query_states, relative_pos, rel_embeddings)
262 attention_scores = attention_scores.view(
263 -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
264 )
266 attention_mask = attention_mask.bool()
--> 267 attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
268 # bsz x height x length x dimension
269 attention_probs = nn.functional.softmax(attention_scores, dim=-1)
RuntimeError: value cannot be converted to type at::BFloat16 without overflow
The fraction of tunable parameters is around 0.5% of all parameters. In fact, the model trained purely on English data can achieve good performance on the test sets, even on the German / Japanese test sets. It obtains results comparable to full finetuning as in AutoMM for Text - Multilingual Problems.
score_in_en = predictor.evaluate(test_en_df)
score_in_de = predictor.evaluate(test_de_df)
score_in_jp = predictor.evaluate(test_jp_df)
print('Score in the English Testset:', score_in_en)
print('Score in the German Testset:', score_in_de)
print('Score in the Japanese Testset:', score_in_jp)
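If you want to verify the trainable-parameter ratio quoted above yourself, counting parameters on a PyTorch module is straightforward. The sketch below is generic: it is demonstrated on a plain module, and how you obtain AutoMM's underlying torch model (an internal attribute) is an assumption that may differ across AutoGluon versions.

```python
import torch.nn as nn

def count_parameters(module: nn.Module):
    """Return (trainable, total) parameter counts for any PyTorch module."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in module.parameters())
    return trainable, total

# Demonstrated on a plain module; with AutoMM you would pass the predictor's
# underlying torch model (an internal attribute, so accessing it is an
# assumption and may change between versions).
demo = nn.Linear(10, 2)
trainable, total = count_parameters(demo)
print(f"trainable: {trainable:,} / total: {total:,} ({100 * trainable / total:.2f}%)")
```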
Training FLAN-T5-XL on a Single GPU¶
By combining gradient checkpointing and parameter-efficient finetuning, it is feasible to finetune google/flan-t5-xl, which has close to two billion parameters, with a single T4 GPU available in AWS G4 instances. To turn on gradient checkpointing, you simply need to set "model.hf_text.gradient_checkpointing" to True. To accelerate the training, we downsample the number of training samples to 200.
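To see what gradient checkpointing does independently of AutoMM, here is a minimal PyTorch sketch using torch.utils.checkpoint (this is illustrative only, not AutoMM code): activations inside the checkpointed segments are discarded after the forward pass and recomputed during backward, trading extra compute for lower peak GPU memory.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Toy stack of blocks standing in for transformer layers.
layers = nn.Sequential(*[nn.Sequential(nn.Linear(256, 256), nn.GELU()) for _ in range(8)])
x = torch.randn(4, 256, requires_grad=True)

# Run the forward pass in 2 checkpointed segments: intermediate activations are
# recomputed during backward instead of being kept in memory.
out = checkpoint_sequential(layers, 2, x, use_reentrant=False)
out.sum().backward()
```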
# Just to clean up the space
clear_cache()
shutil.rmtree(model_path)
from autogluon.multimodal import MultiModalPredictor
train_en_df_downsample = train_en_df.sample(200, random_state=123)
new_model_path = f"./tmp/{uuid.uuid4().hex}-multilingual_ia3_gradient_checkpoint"
predictor = MultiModalPredictor(label="label",
path=new_model_path)
predictor.fit(train_en_df_downsample,
presets="multilingual",
hyperparameters={
"model.hf_text.checkpoint_name": "google/flan-t5-xl",
"model.hf_text.gradient_checkpointing": True,
"model.hf_text.low_cpu_mem_usage": True,
"optim.peft": "ia3_bias",
"optim.lr_decay": 0.9,
"optim.lr": 3e-03,
"optim.end_lr": 3e-03,
"optim.max_epochs": 1,
"optim.warmup_steps": 0,
"env.batch_size": 1,
"env.inference_batch_size_ratio": 1
})
Global seed set to 123
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
-------------------------------------------------------------------
0 | model | HFAutoModelForTextPrediction | 1.2 B
1 | validation_metric | AUROC | 0
2 | loss_func | CrossEntropyLoss | 0
-------------------------------------------------------------------
203 K Trainable params
1.2 B Non-trainable params
1.2 B Total params
4,894.913 Total estimated model params size (MB)
Epoch 0, global step 20: 'val_roc_auc' reached 0.88802 (best 0.88802), saving model to '/home/ubuntu/autogluon/docs/tutorials/multimodal/advanced_topics/multilingual_ia3_gradient_checkpoint/epoch=0-step=20.ckpt' as top 1
Epoch 0, global step 40: 'val_roc_auc' reached 0.94531 (best 0.94531), saving model to '/home/ubuntu/autogluon/docs/tutorials/multimodal/advanced_topics/multilingual_ia3_gradient_checkpoint/epoch=0-step=40.ckpt' as top 1
`Trainer.fit` stopped: `max_epochs=1` reached.
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7fd58c4dbca0>
score_in_en = predictor.evaluate(test_en_df)
print('Score in the English Testset:', score_in_en)
Score in the English Testset: {'roc_auc': 0.931263189629183}
# Just to clean up the space
clear_cache()
shutil.rmtree(new_model_path)
Other Examples¶
You may go to AutoMM Examples to explore other examples about AutoMM.
Customization¶
To learn how to customize AutoMM, please refer to Customize AutoMM.