TabularPredictor.distill¶
- TabularPredictor.distill(train_data: DataFrame | str = None, tuning_data: DataFrame | str = None, augmentation_data: DataFrame = None, time_limit: float = None, hyperparameters: dict | str = None, holdout_frac: float = None, teacher_preds: str = 'soft', augment_method: str = 'spunge', augment_args: dict = {'max_size': 100000, 'size_factor': 5}, models_name_suffix: str = None, verbosity: int = None)[source]¶
[实验性] 将 AutoGluon 最准确的集成预测器蒸馏为更简单/更快速、需要更少内存/计算的单个模型。蒸馏产生的模型可能比直接使用原始训练数据训练的相同模型更准确。调用 distill() 后,此 Predictor 中将提供更多模型,这些模型可以使用 predictor.leaderboard(test_data) 进行评估,并使用 predictor.predict(test_data, model=MODEL_NAME) 进行部署。如果之前在 fit() 中设置了 cache_data=False,则会引发异常。
注意:在 catboost v0.24 发布之前,在多类别分类中使用 CatBoost 作为学生模型进行 distill() 需要您先安装 catboost-dev:pip install catboost-dev
- 参数:
train_data (str 或
pd.DataFrame
, 默认值 = None) – 与 fit() 的 train_data 参数相同。如果为 None,将从用于生成此 Predictor 的 fit() 调用中加载相同的训练数据。tuning_data (str 或
pd.DataFrame
, 默认值 = None) – 与 fit() 的 tuning_data 参数相同。如果 tuning_data = None 且 train_data = None:将从用于生成此 Predictor 的 fit() 调用中加载相同的训练/验证分割,除非之前使用了 bagging/stacking,在这种情况下会执行新的训练/验证分割。augmentation_data (
pd.DataFrame
, 默认值 = None) – 一个可选的额外无标签行数据集,可用于在蒸馏过程中增强用于训练学生模型的数据集(如果为 None 则忽略)。time_limit (int, 默认值 = None) – 蒸馏过程应运行的大致时长(以秒为单位)。如果为 None,则不强制时间限制,允许蒸馏的模型充分训练。
hyperparameters (dict 或 str, 默认值 = None) – 指定要用作学生模型的模型以及它们使用的超参数值。与 fit() 的 hyperparameters 参数相同。如果 = None,则学生模型将使用用于生成此 Predictor 的 fit() 中的相同超参数。注意:蒸馏目前仅支持 [‘GBM’,’NN_TORCH’,’RF’,’CAT’] 学生模型,其他模型及其超参数在此处会被忽略。
holdout_frac (float) – 与
TabularPredictor.fit()
的 holdout_frac 参数相同。teacher_preds (str, 默认值 = 'soft') – 指定从何种形式的教师预测进行蒸馏(教师指最准确的 AutoGluon 集成预测器)。如果为 None,我们只使用原始标签进行训练(无数据增强)。如果为 ‘hard’,标签是硬教师预测,由 teacher.predict() 给出。如果为 ‘soft’,标签是软教师预测,由 teacher.predict_proba() 给出。注意:对于回归问题,‘hard’ 和 ‘soft’ 是等效的。如果 augment_method 不为 None,教师预测仅用于标注增强数据(训练数据保留原始标签)。要应用标签平滑:teacher_preds=’onehot’ 将使用原始训练数据标签转换为 one-hot 向量用于多类别问题(无数据增强)。
augment_method (str, 默认值='spunge') –
指定用于生成增强数据以蒸馏学生模型的方法。选项包括
None : 不执行数据增强。‘munge’ : MUNGE 算法 (https://www.cs.cornell.edu/~caruana/compression.kdd06.pdf)。‘spunge’ : MUNGE 算法的一个更简单、更高效的变体。
augment_args (dict, 默认值 = {'size_factor':5, 'max_size': int(1e5)}) –
- 包含控制所选 augment_method 的以下 kwargs(如果 augment_method=None 则忽略)
’num_augmented_samples’: int,蒸馏期间使用的增强数据点数量。如果指定,则覆盖 ‘size_factor’ 和 ‘max_size’。‘max_size’: float,要添加的最大增强数据点数量(如果指定了 ‘num_augmented_samples’ 则忽略)。‘size_factor’: float,如果 n = 训练数据样本大小,我们最多添加 int(n * size_factor) 个增强数据点,上限为 ‘max_size’。’augment_args’ 中的较大值会减慢 distill() 的运行时间,如果提供的 time_limit 太小,可能会产生更差的结果。您还可以传入 kwargs 给 autogluon.tabular.augmentation.distill_utils 中的 spunge_augment、munge_augment 函数。
models_name_suffix (str, 默认值 = None) – 可选后缀,可以附加在所有蒸馏学生模型的名称末尾。注意:所有蒸馏模型的名称默认会包含 ‘_DSTL’ 子字符串。
verbosity (int, 默认值 = None) – 控制蒸馏期间打印输出的数量(4 = 最高,0 = 最低)。与
TabularPredictor
的 verbosity 参数相同。如果为 None,则再次使用上次 fit 中使用的 verbosity。
- 返回类型:
与蒸馏模型对应的名称列表 (str)。
示例
>>> from autogluon.tabular import TabularDataset, TabularPredictor >>> train_data = TabularDataset('train.csv') >>> predictor = TabularPredictor(label='class').fit(train_data, auto_stack=True) >>> distilled_model_names = predictor.distill() >>> test_data = TabularDataset('test.csv') >>> ldr = predictor.leaderboard(test_data) >>> model_to_deploy = distilled_model_names[0] >>> predictor.predict(test_data, model=model_to_deploy)