Knowledge Distillation in AutoMM


Pre-trained foundation models are becoming increasingly large. However, these models are hard to deploy due to the limited resources available in deployment scenarios. To benefit from large models under this constraint, you can transfer the knowledge of a large teacher model to a small student model via knowledge distillation. In this way, the small student model can be practically deployed in real scenarios, and its performance will be better than training the student model from scratch thanks to the teacher's guidance.

In this tutorial, we introduce how to use knowledge distillation with MultiModalPredictor. For demonstration, we use the Question-Answering NLI (QNLI) dataset, which contains 104,743 question-answer pairs sampled from question-answering datasets. We will demonstrate how to use a large model to guide the learning of a small model in AutoGluon and improve its performance.

Load Dataset

The QNLI dataset contains sentence pairs in English. In the label column, 0 means that the sentence is relevant to the question (it contains the answer), and 1 means that it is not.

import datasets
from datasets import load_dataset

datasets.logging.disable_progress_bar()

dataset = load_dataset("glue", "qnli")

dataset['train']

from sklearn.model_selection import train_test_split

# Subsample 1,000 examples for a quick demo; hold out 20% of them for validation
train_valid_df = dataset["train"].to_pandas()[["question", "sentence", "label"]].sample(1000, random_state=123)
train_df, valid_df = train_test_split(train_valid_df, test_size=0.2, random_state=123)
test_df = dataset["validation"].to_pandas()[["question", "sentence", "label"]].sample(1000, random_state=123)
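
Before training, it can help to verify the split sizes and glance at a few rows (a minimal sanity check using the dataframes created above):

print(f"train: {len(train_df)}, valid: {len(valid_df)}, test: {len(test_df)}")
train_df.head(3)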

Load the Teacher Model

In our example, we will directly load a teacher model with the google/bert_uncased_L-12_H-768_A-12 backbone that has been trained on QNLI, and distill it into a student model with the google/bert_uncased_L-6_H-768_A-12 backbone.

# Download and unpack the pretrained teacher checkpoint
!wget --quiet https://automl-mm-bench.s3.amazonaws.com/unit-tests/distillation_sample_teacher.zip -O distillation_sample_teacher.zip
!unzip -q -o distillation_sample_teacher.zip -d .

from autogluon.multimodal import MultiModalPredictor

teacher_predictor = MultiModalPredictor.load("ag_distillation_sample_teacher/")
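
Optionally, you can evaluate the teacher on the test split first, so that there is a reference score to compare the student against (this reuses the same evaluate call shown later in this tutorial):

print(teacher_predictor.evaluate(data=test_df))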

Distill to Student

Training the student model is simple. You just need to add the teacher_predictor argument when calling .fit(). Internally, the student model will be trained by matching the predictions/feature maps of the teacher, which performs better than directly fine-tuning the student model.

student_predictor = MultiModalPredictor(label="label")
student_predictor.fit(
    train_df,
    tuning_data=valid_df,
    teacher_predictor=teacher_predictor,  # enables knowledge distillation
    hyperparameters={
        "model.hf_text.checkpoint_name": "google/bert_uncased_L-6_H-768_A-12",  # smaller 6-layer student backbone
        "optim.max_epochs": 2,
    }
)

print(student_predictor.evaluate(data=test_df))
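
To see the benefit of distillation for yourself, you can train the same student backbone without a teacher and compare the two scores (a sketch under the same settings; nodistill_predictor is just an illustrative name):

nodistill_predictor = MultiModalPredictor(label="label")
nodistill_predictor.fit(
    train_df,
    tuning_data=valid_df,
    hyperparameters={
        "model.hf_text.checkpoint_name": "google/bert_uncased_L-6_H-768_A-12",
        "optim.max_epochs": 2,
    }
)
print(nodistill_predictor.evaluate(data=test_df))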

More about Knowledge Distillation

To learn how to customize the distillation and how it compares with direct fine-tuning, see the distillation examples and README in AutoMM Distillation Examples, especially the multilingual distillation example, which contains more details and customization options.

Other Examples

You may go to AutoMM Examples to explore other examples about AutoMM.

Customization

To learn how to customize AutoMM, refer to Customize AutoMM.