Knowledge Distillation in AutoMM¶
Pretrained foundation models are becoming increasingly large. However, these models are hard to deploy due to the limited resources available in deployment scenarios. To benefit from large models under this constraint, you can transfer the knowledge of a large teacher model to a smaller student model via knowledge distillation. In this way, the small student model can be practically deployed in real-world scenarios, and, thanks to the teacher's guidance, it performs better than a student trained from scratch.
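To make the idea concrete, the snippet below sketches the classic soft-label distillation loss (following Hinton et al., 2015). The function name, temperature T, and mixing weight alpha are illustrative choices of ours; AutoMM implements its own distillation losses internally, so you do not need to write this yourself.
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    # Soft loss: KL divergence between the student's and the teacher's
    # temperature-softened distributions, scaled by T^2.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    # Hard loss: ordinary cross-entropy against the ground-truth labels.
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard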
In this tutorial, we introduce how to use MultiModalPredictor for knowledge distillation. For demonstration, we use the Question-answering NLI (QNLI) dataset, which contains 104,743 question-answer pairs sampled from question answering datasets. We will demonstrate how to use a large model in AutoGluon to guide the learning of a smaller model and improve its performance.
Load Dataset¶
The QNLI dataset contains English question-sentence pairs. In the label column, 0 (entailment) means that the sentence contains the answer to the question, and 1 (not entailment) means that it does not.
import datasets
from datasets import load_dataset

# Silence the download progress bars.
datasets.logging.disable_progress_bar()
# Load the QNLI task of the GLUE benchmark from the Hugging Face Hub.
# If this raises an "Invalid pattern: '**' ..." error, upgrading the
# datasets package usually resolves the fsspec incompatibility.
dataset = load_dataset("glue", "qnli")
dataset['train']
from sklearn.model_selection import train_test_split

# Subsample 1,000 examples to keep the tutorial fast; hold out 20% for validation.
train_valid_df = dataset["train"].to_pandas()[["question", "sentence", "label"]].sample(1000, random_state=123)
train_df, valid_df = train_test_split(train_valid_df, test_size=0.2, random_state=123)
# GLUE's test labels are hidden, so we sample the labeled validation split for testing.
test_df = dataset["validation"].to_pandas()[["question", "sentence", "label"]].sample(1000, random_state=123)
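Optionally (this check is ours, not part of the original tutorial flow), you can verify the split sizes and the label balance before training:
# Inspect split sizes and the label distribution of the training data.
print(train_df.shape, valid_df.shape, test_df.shape)
print(train_df["label"].value_counts())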
Load the Teacher Model¶
In our example, we will directly load a teacher model with the google/bert_uncased_L-12_H-768_A-12 backbone that has been trained on QNLI, and distill it into a student model with the google/bert_uncased_L-6_H-768_A-12 backbone.
!wget --quiet https://automl-mm-bench.s3.amazonaws.com/unit-tests/distillation_sample_teacher.zip -O distillation_sample_teacher.zip
!unzip -q -o distillation_sample_teacher.zip -d .
from autogluon.multimodal import MultiModalPredictor
teacher_predictor = MultiModalPredictor.load("ag_distillation_sample_teacher/")
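Optionally, you can evaluate the teacher on the test set first, so that you have a reference score to compare the distilled student against (an extra step we add here; evaluate() reports the predictor's default metric):
print(teacher_predictor.evaluate(data=test_df))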
Distill to the Student Model¶
Training the student model is straightforward. You just need to add the teacher_predictor argument when calling .fit(). Internally, the student will be trained by matching the predictions/feature maps of the teacher. This performs better than directly finetuning the student.
student_predictor = MultiModalPredictor(label="label")
student_predictor.fit(
    train_df,
    tuning_data=valid_df,
    teacher_predictor=teacher_predictor,  # enables knowledge distillation
    hyperparameters={
        "model.hf_text.checkpoint_name": "google/bert_uncased_L-6_H-768_A-12",  # smaller student backbone
        "optim.max_epochs": 2,
    }
)
print(student_predictor.evaluate(data=test_df))
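The distilled student is a regular MultiModalPredictor, so you can run inference and persist it like any other predictor (the save path below is illustrative):
predictions = student_predictor.predict(test_df)
student_predictor.save("./ag_distillation_student")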
More about Knowledge Distillation¶
To learn how to customize distillation and how it compares with direct finetuning, see the distillation examples and README in AutoMM Distillation Examples. In particular, the multilingual distillation example contains more details and customization options.
Other Examples¶
You may go to AutoMM Examples to explore other examples about AutoMM.
Customization¶
To learn how to customize AutoMM, refer to Customize AutoMM.