Adding a custom model to AutoGluon


Tip: If you are new to AutoGluon, review Predicting Columns in a Table - Quick Start to learn the basics of the AutoGluon API.

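This tutorial describes how to add a custom model to AutoGluon so that it can be trained, hyperparameter-tuned, and ensembled alongside the default models (default model documentation).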

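In this example, we create a custom random forest model for use in AutoGluon. All models in AutoGluon inherit from the AbstractModel class (AbstractModel source code) and must follow its API to work alongside other models.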

Note that while this tutorial provides a basic model implementation, it does not cover many aspects that most of the built-in models use.

To better understand how to implement more advanced functionality, refer to the source code of the following models:

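Feature | Reference implementation
Respecting time limit / early stopping logic | LGBModel and RFModel
Respecting memory usage limit | LGBModel and RFModel
Sample weight support | LGBModel
Validation data and eval_metric usage | LGBModel
GPU training support | LGBModel
Save / load logic for non-serializable models | NNFastAiTabularModel
Advanced problem type support (Softclass, Quantile) | RFModel
Text feature type support | TextPredictorModel
Image feature type support | ImagePredictorModel
Lazy import of package dependencies | LGBModel
Custom HPO logic | LGBModel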

Implementing a custom model

Here we define the custom model that we will use for the rest of the tutorial.

The most important methods that must be implemented are _fit and _preprocess.
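To make this division of labor concrete, here is a heavily simplified, hypothetical sketch of the pattern. Public fit, preprocess, and predict methods on a base class delegate to the overridable _fit and _preprocess hooks. SketchAbstractModel, MajorityClassModel, and _ConstantPredictor are made-up names for illustration only. AutoGluon's real AbstractModel does far more (paths, time limits, metrics, saving), so treat this as a sketch of the idea, not of its implementation.

```python
from typing import Any, List


class _ConstantPredictor:
    """Toy inner estimator that always predicts the same value."""

    def __init__(self, value: Any):
        self.value = value

    def predict(self, X: List[list]) -> list:
        return [self.value for _ in X]


class SketchAbstractModel:
    """Hypothetical, heavily simplified stand-in for a base class like AbstractModel."""

    def __init__(self):
        self.model = None  # set by _fit to the trained inner model

    def preprocess(self, X, **kwargs):
        # Public entry point used at both fit time and predict time.
        return self._preprocess(X, **kwargs)

    def _preprocess(self, X, **kwargs):
        # Default hook: return the data unchanged.
        return X

    def fit(self, X, y, **kwargs):
        self._fit(X, y, **kwargs)
        return self

    def predict(self, X):
        # Inference goes through the same preprocessing hook as training.
        return self.model.predict(self.preprocess(X))

    def _fit(self, X, y, **kwargs):
        raise NotImplementedError('subclasses must implement _fit')


class MajorityClassModel(SketchAbstractModel):
    """Toy subclass that only overrides the two hooks."""

    def _preprocess(self, X, is_train=False, **kwargs):
        # Convert every value to float, mimicking a conversion to numeric input.
        return [[float(v) for v in row] for row in X]

    def _fit(self, X, y, **kwargs):
        X = self.preprocess(X, is_train=True)
        # Most frequent label; ties resolved by sorted order for determinism.
        majority = max(sorted(set(y)), key=y.count)
        self.model = _ConstantPredictor(majority)
```

Because predict routes through preprocess, any conversion done in _preprocess during training is applied identically at inference. This is why the model below calls self.preprocess near the start of _fit.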

To compare with AutoGluon's official random forest implementation, refer to the RFModel source code.

Read along with the code comments to better understand how the code works.

import numpy as np
import pandas as pd

from autogluon.core.models import AbstractModel
from autogluon.features.generators import LabelEncoderFeatureGenerator

class CustomRandomForestModel(AbstractModel):
    def __init__(self, **kwargs):
        # Simply pass along kwargs to parent, and init our internal `_feature_generator` variable to None
        super().__init__(**kwargs)
        self._feature_generator = None

    # The `_preprocess` method takes the input data and transforms it to the internal representation usable by the model.
    # `_preprocess` is called by `preprocess` and is used during model fit and model inference.
    def _preprocess(self, X: pd.DataFrame, is_train=False, **kwargs) -> np.ndarray:
        print(f'Entering the `_preprocess` method: {len(X)} rows of data (is_train={is_train})')
        X = super()._preprocess(X, **kwargs)

        if is_train:
            # X will be the training data.
            self._feature_generator = LabelEncoderFeatureGenerator(verbosity=0)
            self._feature_generator.fit(X=X)
        if self._feature_generator.features_in:
            # This converts categorical features to numeric via stateful label encoding.
            X = X.copy()
            X[self._feature_generator.features_in] = self._feature_generator.transform(X=X)
        # Add a fillna call to handle missing values.
        # Some algorithms will be able to handle NaN values internally (LightGBM).
        # In those cases, you can simply pass the NaN values into the inner model.
        # Finally, convert to numpy for optimized memory usage and because sklearn RF works with raw numpy input.
        return X.fillna(0).to_numpy(dtype=np.float32)

    # The `_fit` method takes the input training data (and optionally the validation data) and trains the model.
    def _fit(self,
             X: pd.DataFrame,  # training data
             y: pd.Series,  # training labels
             # X_val=None,  # val data (unused in RF model)
             # y_val=None,  # val labels (unused in RF model)
             # time_limit=None,  # time limit in seconds (ignored in tutorial)
             **kwargs):  # kwargs includes many other potential inputs, refer to AbstractModel documentation for details
        print('Entering the `_fit` method')

        # First we import the required dependencies for the model. Note that we do not import them outside of the method.
        # This enables AutoGluon to be highly extensible and modular.
        # For an example of best practices when importing model dependencies, refer to LGBModel.
        from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

        # Valid self.problem_type values include ['binary', 'multiclass', 'regression', 'quantile', 'softclass']
        if self.problem_type in ['regression', 'softclass']:
            model_cls = RandomForestRegressor
        else:
            model_cls = RandomForestClassifier

        # Make sure to call preprocess on X near the start of `_fit`.
        # This is necessary because the data is converted via preprocess during predict, and needs to be in the same format as during fit.
        X = self.preprocess(X, is_train=True)
        # This fetches the user-specified (and default) hyperparameters for the model.
        params = self._get_model_params()
        print(f'Hyperparameters: {params}')
        # self.model should be set to the trained inner model, so that internally during predict we can call `self.model.predict(...)`
        self.model = model_cls(**params)
        self.model.fit(X, y)
        print('Exiting the `_fit` method')

    # The `_set_default_params` method defines the default hyperparameters of the model.
    # User-specified parameters will override these values on a key-by-key basis.
    def _set_default_params(self):
        default_params = {
            'n_estimators': 300,
            'n_jobs': -1,
            'random_state': 0,
        }
        for param, val in default_params.items():
            self._set_default_param_value(param, val)

    # The `_get_default_auxiliary_params` method defines various model-agnostic parameters such as maximum memory usage and valid input column dtypes.
    # For most users who build custom models, they will only need to specify the valid/invalid dtypes to the model here.
    def _get_default_auxiliary_params(self) -> dict:
        default_auxiliary_params = super()._get_default_auxiliary_params()
        extra_auxiliary_params = dict(
            # the total set of raw dtypes are: ['int', 'float', 'category', 'object', 'datetime']
            # object feature dtypes include raw text and image paths, which should only be handled by specialized models
            # datetime raw dtypes are generally converted to int in upstream pre-processing,
            # so models generally shouldn't need to explicitly support datetime dtypes.
            valid_raw_types=['int', 'float', 'category'],
            # Other options include `valid_special_types`, `ignored_type_group_raw`, and `ignored_type_group_special`.
            # Refer to AbstractModel for more details on available options.
        )
        default_auxiliary_params.update(extra_auxiliary_params)
        return default_auxiliary_params
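The comment on _set_default_params says that user-specified parameters override the defaults key by key. That behaves like a plain dictionary update. The sketch below uses a hypothetical helper, merge_hyperparameters, which is not AutoGluon's internal code. It reproduces the merged dictionary that this tutorial prints later for {'max_depth': 10}.

```python
from typing import Optional


def merge_hyperparameters(defaults: dict, user: Optional[dict] = None) -> dict:
    """Hypothetical helper: apply user hyperparameters over defaults, key by key."""
    merged = dict(defaults)    # copy so the defaults dict is never mutated
    merged.update(user or {})  # user keys override matching defaults or add new keys
    return merged


DEFAULTS = {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
print(merge_hyperparameters(DEFAULTS, {'max_depth': 10}))
# → {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 10}
print(merge_hyperparameters(DEFAULTS, {'n_estimators': 50}))
# → {'n_estimators': 50, 'n_jobs': -1, 'random_state': 0}
```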

Loading the data

Next, we load the data. For this tutorial we use the Adult Income dataset because it contains a mix of integer, float, and categorical features.

from autogluon.tabular import TabularDataset

train_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv')  # can be local CSV file as well, returns Pandas DataFrame
test_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv')  # another Pandas DataFrame
label = 'class'  # specifies which column we want to predict
train_data = train_data.sample(n=1000, random_state=0)  # subsample for faster demo

train_data.head(5)
age workclass fnlwgt education education-num marital-status occupation relationship race sex capital-gain capital-loss hours-per-week native-country class
6118 51 Private 39264 Some-college 10 Married-civ-spouse Exec-managerial Wife White Female 0 0 40 United-States >50K
23204 58 Private 51662 10th 6 Married-civ-spouse Other-service Wife White Female 0 0 8 United-States <=50K
29590 40 Private 326310 Some-college 10 Married-civ-spouse Craft-repair Husband White Male 0 0 44 United-States <=50K
18116 37 Private 222450 HS-grad 9 Never-married Sales Not-in-family White Male 0 2339 40 El-Salvador <=50K
33964 62 Private 109190 Bachelors 13 Married-civ-spouse Exec-managerial Husband White Male 15024 0 40 United-States >50K

Training the custom model without TabularPredictor

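Below we show how to train the model without using TabularPredictor. This is useful for debugging and for minimizing the amount of code you need to understand while implementing the model.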

This process is similar to what happens internally when calling fit on TabularPredictor, but it is simplified and minimal.

If the data were already clean (all numeric), we could call fit directly with the data, but that is not the case for the Adult dataset.

Cleaning the labels

The first step to making the input data valid for the model is to clean the labels.

Currently they are strings, but for binary classification we need to convert them to numeric values (0 and 1).

Fortunately, AutoGluon already implements the logic both to detect that this is a binary classification problem (infer_problem_type) and to convert the labels to 0 and 1 (LabelCleaner):

# Separate features and labels
X = train_data.drop(columns=[label])
y = train_data[label]
X_test = test_data.drop(columns=[label])
y_test = test_data[label]

from autogluon.core.data import LabelCleaner
from autogluon.core.utils import infer_problem_type
# Construct a LabelCleaner to neatly convert labels to float/integers during model training/inference, can also use to inverse_transform back to original.
problem_type = infer_problem_type(y=y)  # Infer problem type (or else specify directly)
label_cleaner = LabelCleaner.construct(problem_type=problem_type, y=y)
y_clean = label_cleaner.transform(y)

print(f'Labels cleaned: {label_cleaner.inv_map}')
print(f'inferred problem type as: {problem_type}')
print('Cleaned label values:')
y_clean.head(5)
Labels cleaned: {' <=50K': 0, ' >50K': 1}
inferred problem type as: binary
Cleaned label values:
6118     1
23204    0
29590    0
18116    0
33964    1
Name: class, dtype: uint8
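To see what label cleaning amounts to, the sketch below reproduces the mapping printed above ({' <=50K': 0, ' >50K': 1}) with the standard library. infer_problem_type_sketch and BinaryLabelCleanerSketch are hypothetical stand-ins, not AutoGluon's infer_problem_type or LabelCleaner. Two choices are assumptions: the crude problem-type heuristic, and treating the second of the two sorted classes as positive. The sorted order happens to match here. The TabularPredictor log later in this tutorial notes that AutoGluon's choice of positive class is arbitrary unless positive_class is set.

```python
from typing import Sequence


def infer_problem_type_sketch(y: Sequence) -> str:
    """Very rough heuristic; AutoGluon's infer_problem_type applies more rules."""
    unique = set(y)
    if len(unique) == 2:
        return 'binary'
    # Assumption: any non-boolean numeric label with 3+ values is treated as regression.
    if all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in unique):
        return 'regression'
    return 'multiclass'


class BinaryLabelCleanerSketch:
    """Hypothetical binary label cleaner: original label <-> {0, 1}."""

    def __init__(self, y: Sequence):
        # Assumption: sort the two classes; the second one becomes the positive class 1.
        negative, positive = sorted(set(y))
        self.to_int = {negative: 0, positive: 1}
        self.to_label = {0: negative, 1: positive}

    def transform(self, y: Sequence) -> list:
        return [self.to_int[v] for v in y]

    def inverse_transform(self, y_encoded: Sequence) -> list:
        return [self.to_label[int(v)] for v in y_encoded]
```

The inverse_transform method is the counterpart of label_cleaner.inverse_transform, which is used later to turn 0/1 predictions back into ' <=50K' and ' >50K'.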

Cleaning the features

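Next, we need to clean the features. Currently, features such as "workclass" have object dtype (strings), but we actually want to use them as categorical features. Most models do not accept string inputs, so we need to convert the strings to numbers.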

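AutoGluon contains an entire module dedicated to cleaning, transforming, and generating features, called autogluon.features. Here we use the same feature generator that TabularPredictor uses internally to convert object dtypes to categorical and minimize memory usage.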
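Conceptually, turning strings into numbers that a model can consume means learning a category-to-code mapping on the training data and reusing it unchanged on any later data. CategoryEncoderSketch is a hypothetical class that illustrates this with sorted codes. Mapping unseen or missing values to NaN is an assumption of this sketch. It does not describe how AutoMLPipelineFeatureGenerator or LabelEncoderFeatureGenerator treat such values.

```python
import math


class CategoryEncoderSketch:
    """Hypothetical stateful encoder: learns categories on training data only."""

    def fit(self, values):
        # Sort for a deterministic code assignment; ignore missing values (None).
        self.categories_ = sorted({v for v in values if v is not None})
        self.codes_ = {cat: i for i, cat in enumerate(self.categories_)}
        return self

    def transform(self, values):
        # Categories unseen during fit, and missing values, become NaN.
        return [float(self.codes_[v]) if v in self.codes_ else math.nan for v in values]


encoder = CategoryEncoderSketch().fit([' Private', ' State-gov', ' Private'])
print(encoder.transform([' State-gov', ' Private']))  # → [1.0, 0.0]
```

Calling fit again on the test data would produce a different mapping whenever the test data contains a different set of categories. This is why the fitted generator, rather than a freshly fitted one, is later applied to the test data.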

from autogluon.common.utils.log_utils import set_logger_verbosity
from autogluon.features.generators import AutoMLPipelineFeatureGenerator
set_logger_verbosity(2)  # Set logger so more detailed logging is shown for tutorial

feature_generator = AutoMLPipelineFeatureGenerator()
X_clean = feature_generator.fit_transform(X)

X_clean.head(5)
Fitting AutoMLPipelineFeatureGenerator...
	Available Memory:                    29463.95 MB
	Train Data (Original)  Memory Usage: 0.57 MB (0.0% of available memory)
	Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
	Stage 1 Generators:
		Fitting AsTypeFeatureGenerator...
			Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
	Stage 2 Generators:
		Fitting FillNaFeatureGenerator...
	Stage 3 Generators:
		Fitting IdentityFeatureGenerator...
		Fitting CategoryFeatureGenerator...
			Fitting CategoryMemoryMinimizeFeatureGenerator...
	Stage 4 Generators:
		Fitting DropUniqueFeatureGenerator...
	Stage 5 Generators:
		Fitting DropDuplicatesFeatureGenerator...
	Types of features in original data (raw dtype, special dtypes):
		('int', [])    : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
		('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
	Types of features in processed data (raw dtype, special dtypes):
		('category', [])  : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
		('int', [])       : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
		('int', ['bool']) : 1 | ['sex']
	0.1s = Fit runtime
	14 features in original data used to generate 14 features in processed data.
	Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
age fnlwgt education-num sex capital-gain capital-loss hours-per-week workclass education marital-status occupation relationship race native-country
6118 51 39264 10 0 0 0 40 3 14 1 4 5 4 24
23204 58 51662 6 0 0 0 8 3 0 1 8 5 4 24
29590 40 326310 10 1 0 0 44 3 14 1 3 0 4 24
18116 37 222450 9 1 0 2339 40 3 11 3 12 1 4 6
33964 62 109190 13 1 15024 0 40 3 9 1 4 0 4 24

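Note that AutoMLPipelineFeatureGenerator does not fill missing values in numeric features, rescale numeric features, or one-hot encode categorical features. If your model requires these operations, you need to add them to its _preprocess method; you may find some of the FeatureGenerator classes useful for this.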
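One way to add such steps inside _preprocess is to compute the statistics only when is_train=True and reuse them on every later call. CustomRandomForestModel does the same thing with its _feature_generator. The sketch below applies median imputation followed by standardization to a single numeric column. NumericPreprocessorSketch is a hypothetical, standard-library-only helper. Median imputation is just one choice; the tutorial's own model simply uses fillna(0).

```python
import math
from statistics import mean, median, pstdev


class NumericPreprocessorSketch:
    """Hypothetical helper: median imputation then standardization, fit on train only."""

    def fit(self, column):
        observed = [v for v in column if not math.isnan(v)]
        self.fill_value_ = median(observed)
        filled = [self.fill_value_ if math.isnan(v) else v for v in column]
        self.mean_ = mean(filled)
        # Guard against constant columns to avoid dividing by zero.
        self.std_ = pstdev(filled) or 1.0
        return self

    def transform(self, column):
        # Reuse the training statistics; never recompute them on new data.
        filled = [self.fill_value_ if math.isnan(v) else v for v in column]
        return [(v - self.mean_) / self.std_ for v in filled]


prep = NumericPreprocessorSketch().fit([1.0, math.nan, 3.0])
print(prep.transform([2.0, math.nan]))  # → [0.0, 0.0]
```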

Fitting the model

We are now ready to fit the model using the cleaned features and labels.

custom_model = CustomRandomForestModel()
# We could also specify hyperparameters to override defaults
# custom_model = CustomRandomForestModel(hyperparameters={'max_depth': 10})
custom_model.fit(X=X_clean, y=y_clean)  # Fit custom model

# To save to disk and load the model, do the following:
# load_path = custom_model.path
# custom_model.save()
# del custom_model
# custom_model = CustomRandomForestModel.load(path=load_path)
Entering the `_fit` method
Entering the `_preprocess` method: 1000 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Warning: No name was specified for model, defaulting to class name: CustomRandomForestModel
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_205558/CustomRandomForestModel"
Warning: No path was specified for model, defaulting to: /home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205558/
Selected class <--> label mapping:  class 1 = 1, class 0 = 0
Model CustomRandomForestModel's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
<__main__.CustomRandomForestModel at 0x7f2ae4aa7990>

Predicting with the trained model

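Now that the model is fit, we can make predictions on new data. Remember that we need to apply the same feature and label transformations to the new data as we did to the training data.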

# Prepare test data
X_test_clean = feature_generator.transform(X_test)
y_test_clean = label_cleaner.transform(y_test)

X_test.head(5)
age workclass fnlwgt education education-num marital-status occupation relationship race sex capital-gain capital-loss hours-per-week native-country
0 31 Private 169085 11th 7 Married-civ-spouse Sales Wife White Female 0 0 20 United-States
1 17 Self-emp-not-inc 226203 12th 8 Never-married Sales Own-child White Male 0 0 45 United-States
2 47 Private 54260 Assoc-voc 11 Married-civ-spouse Exec-managerial Husband White Male 0 1887 60 United-States
3 21 Private 176262 Some-college 10 Never-married Exec-managerial Own-child White Female 0 0 30 United-States
4 17 Private 241185 12th 8 Never-married Prof-specialty Own-child White Male 0 0 20 United-States

Getting raw predictions on the test data

y_pred = custom_model.predict(X_test_clean)
print(y_pred[:5])
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
[0 0 1 0 0]

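Note that these predictions are in the cleaned label space: 1 is the positive class (whichever label was mapped to 1) and 0 is the negative class. To get more interpretable results, convert them back to the original labels: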

y_pred_orig = label_cleaner.inverse_transform(y_pred)
y_pred_orig.head(5)
0     <=50K
1     <=50K
2      >50K
3     <=50K
4     <=50K
dtype: object

Scoring the trained model

By default, a model uses an eval_metric specific to its problem_type. For binary classification, it uses accuracy.
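Accuracy is the fraction of rows whose predicted label equals the true label. The sketch pairs that computation with a lookup of default metrics by problem type. This tutorial only states the binary entry. The multiclass and regression entries, and the names DEFAULT_METRIC_BY_PROBLEM_TYPE and accuracy, are assumptions for illustration, not AutoGluon's API.

```python
from typing import Sequence

# Assumed mapping for illustration; only 'binary' -> 'accuracy' is stated in this tutorial.
DEFAULT_METRIC_BY_PROBLEM_TYPE = {
    'binary': 'accuracy',
    'multiclass': 'accuracy',
    'regression': 'root_mean_squared_error',
}


def accuracy(y_true: Sequence, y_pred: Sequence) -> float:
    """Fraction of positions where the prediction matches the true label."""
    if len(y_true) != len(y_pred):
        raise ValueError('y_true and y_pred must have the same length')
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


print(DEFAULT_METRIC_BY_PROBLEM_TYPE['binary'], accuracy([0, 1, 1, 0], [0, 1, 0, 0]))  # → accuracy 0.75
```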

We can get the model's accuracy score on the test data as follows:

score = custom_model.score(X_test_clean, y_test_clean)
print(f'Test score ({custom_model.eval_metric.name}) = {score}')
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Test score (accuracy) = 0.8424608455317842

Training a bagged custom model without TabularPredictor

Some of AutoGluon's more advanced functionality, such as bagging, is easy to use with any model that inherits from AbstractModel.
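Bagging with k_fold=10 on 1,000 rows trains each child model on 900 rows and holds out the remaining 100. This matches the row counts in the bagging log below. A minimal, unshuffled k-fold splitter shows the bookkeeping. kfold_indices is a hypothetical helper; real implementations typically shuffle the rows and, for classification, often stratify by label.

```python
from typing import Iterator, List, Tuple


def kfold_indices(n_rows: int, k: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_idx, val_idx) pairs; every row lands in exactly one validation fold."""
    # Spread the remainder over the first folds so sizes differ by at most one.
    fold_sizes = [n_rows // k + (1 if i < n_rows % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_rows))
        yield train, val
        start += size


for train_idx, val_idx in kfold_indices(1000, 10):
    print(len(train_idx), len(val_idx))  # prints "900 100" for each of the 10 folds
```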

You can even bag your custom model in a few lines of code. This is a quick way to improve the quality of nearly any model:

from autogluon.core.models import BaggedEnsembleModel
bagged_custom_model = BaggedEnsembleModel(CustomRandomForestModel())
# Parallel folding currently doesn't work with a class not defined in a separate module because of underlying pickle serialization issue
# You don't need this following line if you put your custom model in a separate file and import it.
bagged_custom_model.params['fold_fitting_strategy'] = 'sequential_local' 
bagged_custom_model.fit(X=X_clean, y=y_clean, k_fold=10)  # Perform 10-fold bagging
bagged_score = bagged_custom_model.score(X_test_clean, y_test_clean)
print(f'Test score ({bagged_custom_model.eval_metric.name}) = {bagged_score} (bagged)')
print(f'Bagging increased model accuracy by {round((bagged_score - score) * 100, 2)}%!')
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Test score (accuracy) = 0.8435868563824342 (bagged)
Bagging increased model accuracy by 0.11%!
Warning: No name was specified for model, defaulting to class name: CustomRandomForestModel
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_205558-001/CustomRandomForestModel"
Warning: No path was specified for model, defaulting to: /home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205558-001/
Warning: No name was specified for model, defaulting to class name: BaggedEnsembleModel
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_205558/BaggedEnsembleModel"
Warning: No path was specified for model, defaulting to: /home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205558/
Selected class <--> label mapping:  class 1 = 1, class 0 = 0
Selected class <--> label mapping:  class 1 = 1, class 0 = 0
Model CustomRandomForestModel's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model 's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
	Fitting 10 child models (S1F1 - S1F10) | Fitting with SequentialLocalFoldFittingStrategy
Model S1F1's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F2's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F3's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F4's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F5's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F6's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F7's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F8's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F9's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F10's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.

Note that the bagged model trained 10 CustomRandomForestModels, each on a different split of the training data. When predicting, the bagged model averages the predictions of these 10 models.
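In the binary case, averaging the predictions can be pictured as averaging each child model's positive-class probability row by row and then applying a threshold. The sketch assumes that each child returns positive-class probabilities and that the threshold is 0.5. average_predictions and to_labels are hypothetical helpers, not BaggedEnsembleModel internals.

```python
from typing import List, Sequence


def average_predictions(child_probas: Sequence[Sequence[float]]) -> List[float]:
    """Average positive-class probabilities row by row across the child models."""
    n_models = len(child_probas)
    return [sum(row) / n_models for row in zip(*child_probas)]


def to_labels(probas: Sequence[float], threshold: float = 0.5) -> List[int]:
    """Convert averaged probabilities to 0/1 labels."""
    return [int(p >= threshold) for p in probas]


# Three hypothetical children scoring the same two rows.
avg = average_predictions([[0.2, 0.9], [0.4, 0.7], [0.3, 0.2]])
print(to_labels(avg))  # → [0, 1]
```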

Training the custom model with TabularPredictor

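While training without TabularPredictor reduces the amount of code we need to focus on when developing and debugging a model, ultimately we want to use TabularPredictor to get the most out of our model.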

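The code to train on the raw data with TabularPredictor is very simple. There is no need to specify a LabelCleaner, a FeatureGenerator, or a validation set; all of these are handled internally.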

Here we train 3 CustomRandomForestModels with different hyperparameters.
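Passing a list of configurations for one model class yields one model per configuration. The fit log that follows names them CustomRandomForestModel, CustomRandomForestModel_2, and CustomRandomForestModel_3. The hypothetical expand_hyperparameters helper below mimics that expansion and naming. It is an illustration of the observed behavior, not AutoGluon's actual code.

```python
from typing import Any, Dict, List, Tuple, Union

Config = Dict[str, Any]


def expand_hyperparameters(
    hyperparameters: Dict[type, Union[Config, List[Config]]]
) -> List[Tuple[str, type, Config]]:
    """Turn {ModelCls: config or [configs]} into (name, cls, config) triples."""
    specs = []
    for model_cls, configs in hyperparameters.items():
        if isinstance(configs, dict):
            configs = [configs]  # a single dict means a single model
        base = model_cls.__name__
        for i, config in enumerate(configs):
            # Mirror the naming seen in the fit log: the first model keeps the bare
            # class name, later ones get a numeric suffix starting at _2.
            name = base if i == 0 else f'{base}_{i + 1}'
            specs.append((name, model_cls, dict(config)))
    return specs
```

For example, calling expand_hyperparameters(custom_hyperparameters) with the dictionary defined below would return three (name, class, config) triples, one per configuration.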

from autogluon.tabular import TabularPredictor

# custom_hyperparameters = {CustomRandomForestModel: {}}  # train 1 CustomRandomForestModel with default hyperparameters
custom_hyperparameters = {CustomRandomForestModel: [{}, {'max_depth': 10}, {'max_features': 0.9, 'max_depth': 20}]}  # Train 3 CustomRandomForestModel with different hyperparameters
predictor = TabularPredictor(label=label).fit(train_data, hyperparameters=custom_hyperparameters)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 10}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_features': 0.9, 'max_depth': 20}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_205605"
Verbosity: 2 (Standard Logging)
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Memory Avail:       28.73 GB / 30.95 GB (92.8%)
Disk Space Avail:   212.07 GB / 255.99 GB (82.8%)
===================================================
No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets. Defaulting to `'medium'`...
	Recommended Presets (For more details refer to https://autogluon.cn/stable/tutorials/tabular/tabular-essentials.html#presets):
	presets='experimental' : New in v1.2: Pre-trained foundation model + parallel fits. The absolute best accuracy without consideration for inference speed. Does not support GPU.
	presets='best'         : Maximize accuracy. Recommended for most users. Use in competitions and benchmarks.
	presets='high'         : Strong accuracy with fast inference speed.
	presets='good'         : Good accuracy with very fast inference speed.
	presets='medium'       : Fast training time, ideal for initial prototyping.
Beginning AutoGluon training ...
AutoGluon will save models to "/home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205605"
Train Data Rows:    1000
Train Data Columns: 14
Label Column:       class
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
	2 unique label values:  [' >50K', ' <=50K']
	If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
Problem Type:       binary
Preprocessing data ...
Selected class <--> label mapping:  class 1 =  >50K, class 0 =  <=50K
	Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.
	To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
	Available Memory:                    29415.33 MB
	Train Data (Original)  Memory Usage: 0.56 MB (0.0% of available memory)
	Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
	Stage 1 Generators:
		Fitting AsTypeFeatureGenerator...
			Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
	Stage 2 Generators:
		Fitting FillNaFeatureGenerator...
	Stage 3 Generators:
		Fitting IdentityFeatureGenerator...
		Fitting CategoryFeatureGenerator...
			Fitting CategoryMemoryMinimizeFeatureGenerator...
	Stage 4 Generators:
		Fitting DropUniqueFeatureGenerator...
	Stage 5 Generators:
		Fitting DropDuplicatesFeatureGenerator...
	Types of features in original data (raw dtype, special dtypes):
		('int', [])    : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
		('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
	Types of features in processed data (raw dtype, special dtypes):
		('category', [])  : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
		('int', [])       : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
		('int', ['bool']) : 1 | ['sex']
	0.1s = Fit runtime
	14 features in original data used to generate 14 features in processed data.
	Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
Data preprocessing and feature engineering runtime = 0.09s ...
AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
	To change this, specify the eval_metric parameter of Predictor()
Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 800, Val Rows: 200
User-specified model hyperparameters to be fit:
{
	'<class '__main__.CustomRandomForestModel'>': [{}, {'max_depth': 10}, {'max_features': 0.9, 'max_depth': 20}],
}
Custom Model Type Detected: <class '__main__.CustomRandomForestModel'>
Custom Model Type Detected: <class '__main__.CustomRandomForestModel'>
Custom Model Type Detected: <class '__main__.CustomRandomForestModel'>
Fitting 3 L1 models, fit_strategy="sequential" ...
Fitting model: CustomRandomForestModel ...
	0.835	 = Validation score   (accuracy)
	0.55s	 = Training   runtime
	0.06s	 = Validation runtime
Fitting model: CustomRandomForestModel_2 ...
	0.845	 = Validation score   (accuracy)
	0.52s	 = Training   runtime
	0.06s	 = Validation runtime
Fitting model: CustomRandomForestModel_3 ...
	0.84	 = Validation score   (accuracy)
	0.54s	 = Training   runtime
	0.06s	 = Validation runtime
Fitting model: WeightedEnsemble_L2 ...
	Ensemble Weights: {'CustomRandomForestModel_2': 1.0}
	0.845	 = Validation score   (accuracy)
	0.0s	 = Training   runtime
	0.0s	 = Validation runtime
AutoGluon training complete, total runtime = 1.96s ... Best model: WeightedEnsemble_L2 | Estimated inference throughput: 3342.2 rows/s (200 batch size)
Disabling decision threshold calibration for metric `accuracy` due to having fewer than 10000 rows of validation data for calibration, to avoid overfitting (200 rows).
	`accuracy` is generally not improved through threshold calibration. Force calibration via specifying `calibrate_decision_threshold=True`.
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("/home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205605")

Predictor leaderboard

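The leaderboard shows statistics for each trained model. Note that a WeightedEnsemble model was also trained. This model tries to combine the predictions of the other models through ensembling to achieve a better validation score.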

predictor.leaderboard(test_data)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
model score_test score_val eval_metric pred_time_test pred_time_val fit_time pred_time_test_marginal pred_time_val_marginal fit_time_marginal stack_level can_infer fit_order
0 CustomRandomForestModel_2 0.846044 0.845 accuracy 0.097788 0.059052 0.519446 0.097788 0.059052 0.519446 1 True 2
1 WeightedEnsemble_L2 0.846044 0.845 accuracy 0.100083 0.059841 0.522067 0.002295 0.000789 0.002620 2 True 4
2 CustomRandomForestModel 0.840414 0.835 accuracy 0.108968 0.058160 0.551860 0.108968 0.058160 0.551860 1 True 1
3 CustomRandomForestModel_3 0.828846 0.840 accuracy 0.098568 0.059306 0.544473 0.098568 0.059306 0.544473 1 True 3
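In the fit log above, WeightedEnsemble_L2 ended up with the weights {'CustomRandomForestModel_2': 1.0}. To show how such weights can arise, here is a simplified sketch of greedy ensemble selection with replacement (Caruana et al.). As far as we know, this is the general technique AutoGluon's weighted ensemble is based on. The function names, the fixed 0.5 threshold, accuracy as the objective, and the tie-breaking rule are all assumptions of this sketch, not AutoGluon's implementation.

```python
from typing import Dict, List, Sequence


def _accuracy_at_half(probas: Sequence[float], y_true: Sequence[int]) -> float:
    # Threshold positive-class probabilities at 0.5 and compare with 0/1 labels.
    return sum((p >= 0.5) == bool(t) for p, t in zip(probas, y_true)) / len(y_true)


def greedy_ensemble_selection(preds: Dict[str, List[float]], y_true: Sequence[int],
                              n_iters: int = 25) -> Dict[str, float]:
    """Greedy forward selection with replacement over validation predictions."""
    n = len(y_true)
    chosen: List[str] = []
    running = [0.0] * n  # average prediction of the models chosen so far
    for _ in range(n_iters):
        k = len(chosen)
        best_name, best_score = None, -1.0
        for name, p in preds.items():
            # Candidate ensemble: current average plus one more copy of `name`.
            candidate = [(running[i] * k + p[i]) / (k + 1) for i in range(n)]
            score = _accuracy_at_half(candidate, y_true)
            if score > best_score:  # ties keep the model seen first
                best_name, best_score = name, score
        chosen.append(best_name)
        running = [(running[i] * k + preds[best_name][i]) / (k + 1) for i in range(n)]
    # A model's weight is the fraction of selection rounds that picked it.
    return {name: chosen.count(name) / len(chosen) for name in preds if name in chosen}
```

When one model is best on its own and adding any other model never improves the validation score, every round picks that model, and it receives a weight of 1.0.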

Predicting with the fitted predictor

Here we predict with the fitted predictor. This automatically uses the best model (the one with the highest score_val) to make predictions.

y_pred = predictor.predict(test_data)
# y_pred = predictor.predict(test_data, model='CustomRandomForestModel_3')  # If we want a specific model to predict
y_pred.head(5)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
0     <=50K
1     <=50K
2      >50K
3     <=50K
4     <=50K
Name: class, dtype: object

Hyperparameter tuning the custom model with TabularPredictor

We can easily hyperparameter-tune a custom model by specifying hyperparameter search spaces instead of exact values.
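To illustrate what replacing exact values with a search space means, the sketch below runs plain random search over the same three parameters used in the tuning code that follows. IntSpace, RealSpace, CategoricalSpace, and random_search_configs are hypothetical stand-ins, not autogluon.common.space. The search strategy behind hyperparameter_tune_kwargs='auto' may differ. Trial 0 uses the lower bounds and the first choice because that matches the first configuration printed in the tuning log (max_depth 5, max_features 0.1, criterion 'gini').

```python
import random


class IntSpace:
    """Hypothetical integer range, inclusive on both ends."""

    def __init__(self, lower: int, upper: int):
        self.lower, self.upper = lower, upper

    def default(self):
        return self.lower

    def sample(self, rng: random.Random):
        return rng.randint(self.lower, self.upper)


class RealSpace:
    """Hypothetical continuous range, sampled uniformly."""

    def __init__(self, lower: float, upper: float):
        self.lower, self.upper = lower, upper

    def default(self):
        return self.lower

    def sample(self, rng: random.Random):
        return rng.uniform(self.lower, self.upper)


class CategoricalSpace:
    """Hypothetical set of discrete choices."""

    def __init__(self, *choices):
        self.choices = list(choices)

    def default(self):
        return self.choices[0]

    def sample(self, rng: random.Random):
        return rng.choice(self.choices)


def random_search_configs(search_space: dict, n_trials: int, seed: int = 0) -> list:
    """Plain random search: trial 0 uses each space's default, later trials sample."""
    rng = random.Random(seed)
    configs = []
    for trial in range(n_trials):
        configs.append({
            name: space.default() if trial == 0 else space.sample(rng)
            for name, space in search_space.items()
        })
    return configs


search_space = {
    'max_depth': IntSpace(5, 30),
    'max_features': RealSpace(0.1, 1.0),
    'criterion': CategoricalSpace('gini', 'entropy'),
}
for config in random_search_configs(search_space, n_trials=3):
    print(config)  # the first line is {'max_depth': 5, 'max_features': 0.1, 'criterion': 'gini'}
```

Each sampled configuration would then be merged with the model's defaults and trained, and the best one would be kept according to its validation score.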

Here we hyperparameter-tune the custom model for 20 seconds:

from autogluon.common import space
custom_hyperparameters_hpo = {CustomRandomForestModel: {
    'max_depth': space.Int(lower=5, upper=30),
    'max_features': space.Real(lower=0.1, upper=1.0),
    'criterion': space.Categorical('gini', 'entropy'),
}}
# Hyperparameter tune CustomRandomForestModel for 20 seconds
predictor = TabularPredictor(label=label).fit(train_data,
                                              hyperparameters=custom_hyperparameters_hpo,
                                              hyperparameter_tune_kwargs='auto',  # enables HPO
                                              time_limit=20)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 5, 'max_features': 0.1, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 20, 'max_features': 0.7436704297351775, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 8, 'max_features': 0.8625265649057129, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 11, 'max_features': 0.15104167958569886, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 6, 'max_features': 0.8125525342743981, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 19, 'max_features': 0.6112401049845391, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 30, 'max_features': 0.16393245237809825, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 25, 'max_features': 0.11819655769629316, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 10, 'max_features': 0.8003410758548655, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 5, 'max_features': 0.9807565080094875, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 22, 'max_features': 0.5153314260276387, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 24, 'max_features': 0.20644698328203992, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 6, 'max_features': 0.22901795866814179, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 5, 'max_features': 0.5696634895750645, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 28, 'max_features': 0.3381000508941643, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 23, 'max_features': 0.5105352989948937, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 5, 'max_features': 0.11691082039271963, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 10, 'max_features': 0.6508861504501793, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 22, 'max_features': 0.9493732706631618, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 15, 'max_features': 0.42355711051640743, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 6, 'max_features': 0.7278680763345383, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 30, 'max_features': 0.7000900439011009, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 16, 'max_features': 0.2893443049664568, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 5, 'max_features': 0.38388551583176544, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 17, 'max_features': 0.6131770933760917, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 16, 'max_features': 0.9895364542533036, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_205608"
Verbosity: 2 (Standard Logging)
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Memory Avail:       28.72 GB / 30.95 GB (92.8%)
Disk Space Avail:   212.06 GB / 255.99 GB (82.8%)
===================================================
No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets. Defaulting to `'medium'`...
	Recommended Presets (For more details refer to https://autogluon.cn/stable/tutorials/tabular/tabular-essentials.html#presets):
	presets='experimental' : New in v1.2: Pre-trained foundation model + parallel fits. The absolute best accuracy without consideration for inference speed. Does not support GPU.
	presets='best'         : Maximize accuracy. Recommended for most users. Use in competitions and benchmarks.
	presets='high'         : Strong accuracy with fast inference speed.
	presets='good'         : Good accuracy with very fast inference speed.
	presets='medium'       : Fast training time, ideal for initial prototyping.
Warning: hyperparameter tuning is currently experimental and may cause the process to hang.
Beginning AutoGluon training ... Time limit = 20s
AutoGluon will save models to "/home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205608"
Train Data Rows:    1000
Train Data Columns: 14
Label Column:       class
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
	2 unique label values:  [' >50K', ' <=50K']
	If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
Problem Type:       binary
Preprocessing data ...
Selected class <--> label mapping:  class 1 =  >50K, class 0 =  <=50K
	Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.
	To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
	Available Memory:                    29414.06 MB
	Train Data (Original)  Memory Usage: 0.56 MB (0.0% of available memory)
	Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
	Stage 1 Generators:
		Fitting AsTypeFeatureGenerator...
			Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
	Stage 2 Generators:
		Fitting FillNaFeatureGenerator...
	Stage 3 Generators:
		Fitting IdentityFeatureGenerator...
		Fitting CategoryFeatureGenerator...
			Fitting CategoryMemoryMinimizeFeatureGenerator...
	Stage 4 Generators:
		Fitting DropUniqueFeatureGenerator...
	Stage 5 Generators:
		Fitting DropDuplicatesFeatureGenerator...
	Types of features in original data (raw dtype, special dtypes):
		('int', [])    : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
		('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
	Types of features in processed data (raw dtype, special dtypes):
		('category', [])  : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
		('int', [])       : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
		('int', ['bool']) : 1 | ['sex']
	0.1s = Fit runtime
	14 features in original data used to generate 14 features in processed data.
	Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
Data preprocessing and feature engineering runtime = 0.09s ...
AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
	To change this, specify the eval_metric parameter of Predictor()
Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 800, Val Rows: 200
User-specified model hyperparameters to be fit:
{
	'<class '__main__.CustomRandomForestModel'>': [{'max_depth': Int: lower=5, upper=30, 'max_features': Real: lower=0.1, upper=1.0, 'criterion': Categorical['gini', 'entropy']}],
}
Custom Model Type Detected: <class '__main__.CustomRandomForestModel'>
Fitting 1 L1 models, fit_strategy="sequential" ...
Hyperparameter tuning model: CustomRandomForestModel ... Tuning model for up to 17.92s of the 19.91s of remaining time.
	Stopping HPO to satisfy time limit...
Fitted model: CustomRandomForestModel/T1 ...
	0.805	 = Validation score   (accuracy)
	0.51s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T2 ...
	0.835	 = Validation score   (accuracy)
	0.54s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T3 ...
	0.825	 = Validation score   (accuracy)
	0.53s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T4 ...
	0.855	 = Validation score   (accuracy)
	0.53s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T5 ...
	0.835	 = Validation score   (accuracy)
	0.52s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T6 ...
	0.83	 = Validation score   (accuracy)
	0.52s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T7 ...
	0.845	 = Validation score   (accuracy)
	0.53s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T8 ...
	0.845	 = Validation score   (accuracy)
	0.53s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T9 ...
	0.835	 = Validation score   (accuracy)
	0.52s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T10 ...
	0.845	 = Validation score   (accuracy)
	0.53s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T11 ...
	0.85	 = Validation score   (accuracy)
	0.52s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T12 ...
	0.835	 = Validation score   (accuracy)
	0.53s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T13 ...
	0.84	 = Validation score   (accuracy)
	0.52s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T14 ...
	0.835	 = Validation score   (accuracy)
	0.52s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T15 ...
	0.845	 = Validation score   (accuracy)
	0.51s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T16 ...
	0.85	 = Validation score   (accuracy)
	0.51s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T17 ...
	0.85	 = Validation score   (accuracy)
	0.53s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T18 ...
	0.805	 = Validation score   (accuracy)
	0.51s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T19 ...
	0.845	 = Validation score   (accuracy)
	0.53s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T20 ...
	0.835	 = Validation score   (accuracy)
	0.53s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T21 ...
	0.85	 = Validation score   (accuracy)
	0.52s	 = Training   runtime
	0.05s	 = Validation runtime
Fitted model: CustomRandomForestModel/T22 ...
	0.83	 = Validation score   (accuracy)
	0.53s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T23 ...
	0.84	 = Validation score   (accuracy)
	0.53s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T24 ...
	0.845	 = Validation score   (accuracy)
	0.52s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T25 ...
	0.845	 = Validation score   (accuracy)
	0.59s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T26 ...
	0.845	 = Validation score   (accuracy)
	0.53s	 = Training   runtime
	0.06s	 = Validation runtime
Fitted model: CustomRandomForestModel/T27 ...
	0.835	 = Validation score   (accuracy)
	0.54s	 = Training   runtime
	0.06s	 = Validation runtime
Fitting model: WeightedEnsemble_L2 ... Training model for up to 19.91s of the -0.11s of remaining time.
	Ensemble Weights: {'CustomRandomForestModel/T4': 1.0}
	0.855	 = Validation score   (accuracy)
	0.0s	 = Training   runtime
	0.0s	 = Validation runtime
AutoGluon training complete, total runtime = 20.15s ... Best model: WeightedEnsemble_L2 | Estimated inference throughput: 3432.0 rows/s (200 batch size)
Disabling decision threshold calibration for metric `accuracy` due to having fewer than 10000 rows of validation data for calibration, to avoid overfitting (200 rows).
	`accuracy` is generally not improved through threshold calibration. Force calibration via specifying `calibrate_decision_threshold=True`.
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("/home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205608")
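With `hyperparameter_tune_kwargs='auto'`, AutoGluon kept drawing configurations from the search space printed in the log (`max_depth` Int 5 to 30, `max_features` Real 0.1 to 1.0, `criterion` Categorical gini/entropy) and stopped after 27 trials (T1 to T27) to respect the 20-second limit. Trial T1 used the lower bounds and the first category, which suggests that the first trial starts from the search space's default values. The sketch below shows this budgeted trial loop as plain random search. It is a simplified illustration, not AutoGluon's searcher, whose strategy may differ by model and version. `toy_score` is a hypothetical stand-in for "fit a random forest with this config and return its validation accuracy".

```python
import random
import time

# Search space copied from the HPO log above: name -> (kind, spec).
SEARCH_SPACE = {
    "max_depth": ("int", (5, 30)),
    "max_features": ("real", (0.1, 1.0)),
    "criterion": ("categorical", ("gini", "entropy")),
}


def first_config(space):
    """Trial 1: lower bound of each range, first option of each categorical (assumed convention)."""
    return {name: spec[0] for name, (_, spec) in space.items()}


def sample_config(space, rng):
    """Draw one configuration uniformly at random from the space."""
    config = {}
    for name, (kind, spec) in space.items():
        if kind == "int":
            config[name] = rng.randint(spec[0], spec[1])  # both bounds inclusive
        elif kind == "real":
            config[name] = rng.uniform(spec[0], spec[1])
        else:
            config[name] = rng.choice(spec)
    return config


def random_search(space, score_fn, time_limit, max_trials=100, seed=0):
    """Evaluate configs until the time budget or trial cap runs out; return trials best-first."""
    rng = random.Random(seed)
    start = time.monotonic()
    trials = []
    for i in range(max_trials):
        if time.monotonic() - start >= time_limit:
            break  # similar in spirit to "Stopping HPO to satisfy time limit..."
        config = first_config(space) if i == 0 else sample_config(space, rng)
        trials.append((f"T{i + 1}", config, score_fn(config)))
    return sorted(trials, key=lambda trial: trial[2], reverse=True)


def toy_score(config):
    """Hypothetical objective standing in for a real fit + validation score."""
    score = 0.85 - abs(config["max_depth"] - 20) / 100 - abs(config["max_features"] - 0.5) / 10
    return score + (0.005 if config["criterion"] == "entropy" else 0.0)


trials = random_search(SEARCH_SPACE, toy_score, time_limit=5.0, max_trials=27)
best_name, best_config, best_score = trials[0]
print(best_name, best_config, round(best_score, 4))
```

Random search is used here only because it is the simplest strategy to show. The point is the shape of the loop: sample a config, fit, score on the validation split, and stop when the time budget is spent.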

Predictor leaderboard (HPO)

In the leaderboard of an HPO run, model names carry a '/Tx' suffix, where x identifies the HPO trial that produced the model.
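Because the trial number is encoded in the model name, it can be recovered with ordinary pandas string operations. The sketch below builds a small DataFrame by hand from a few rows of this tutorial's HPO leaderboard; in practice you would apply the same steps to the DataFrame returned by `predictor.leaderboard()`. The `model_type` and `trial` columns are our own additions, not AutoGluon columns.

```python
import pandas as pd

# A few rows copied from the HPO leaderboard in this tutorial.
lb_sample = pd.DataFrame({
    "model": [
        "CustomRandomForestModel/T4",
        "WeightedEnsemble_L2",
        "CustomRandomForestModel/T21",
        "CustomRandomForestModel/T1",
    ],
    "score_val": [0.855, 0.855, 0.850, 0.805],
    "stack_level": [1, 2, 1, 1],
})

# Split "<model type>/T<trial>"; names without the suffix get NaN for the trial.
parts = lb_sample["model"].str.extract(r"^(?P<model_type>[^/]+)(?:/T(?P<trial>\d+))?$")
lb_sample["model_type"] = parts["model_type"]
lb_sample["trial"] = pd.to_numeric(parts["trial"])  # numeric, so T21 sorts after T4

# Keep only HPO trials and order them by trial number.
hpo_trials = lb_sample.dropna(subset=["trial"]).sort_values("trial")
print(hpo_trials[["model_type", "trial", "score_val"]])
```

Converting the trial to a number matters: sorting the raw strings would place "T21" before "T4".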

leaderboard_hpo = predictor.leaderboard()
leaderboard_hpo
    model                        score_val  eval_metric  pred_time_val  fit_time  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
 0  CustomRandomForestModel/T4       0.855  accuracy          0.057505  0.527691                0.057505           0.527691            1       True          4
 1  WeightedEnsemble_L2              0.855  accuracy          0.058275  0.530208                0.000770           0.002517            2       True         28
 2  CustomRandomForestModel/T21      0.850  accuracy          0.049113  0.516646                0.049113           0.516646            1       True         21
 3  CustomRandomForestModel/T16      0.850  accuracy          0.057665  0.509290                0.057665           0.509290            1       True         16
 4  CustomRandomForestModel/T17      0.850  accuracy          0.058555  0.527065                0.058555           0.527065            1       True         17
 5  CustomRandomForestModel/T11      0.850  accuracy          0.060928  0.523556                0.060928           0.523556            1       True         11
 6  CustomRandomForestModel/T24      0.845  accuracy          0.057226  0.519612                0.057226           0.519612            1       True         24
 7  CustomRandomForestModel/T25      0.845  accuracy          0.057345  0.590406                0.057345           0.590406            1       True         25
 8  CustomRandomForestModel/T19      0.845  accuracy          0.057515  0.530065                0.057515           0.530065            1       True         19
 9  CustomRandomForestModel/T26      0.845  accuracy          0.057742  0.528203                0.057742           0.528203            1       True         26
10  CustomRandomForestModel/T15      0.845  accuracy          0.057770  0.507537                0.057770           0.507537            1       True         15
11  CustomRandomForestModel/T7       0.845  accuracy          0.058108  0.525845                0.058108           0.525845            1       True          7
12  CustomRandomForestModel/T10      0.845  accuracy          0.058333  0.533082                0.058333           0.533082            1       True         10
13  CustomRandomForestModel/T8       0.845  accuracy          0.059577  0.528987                0.059577           0.528987            1       True          8
14  CustomRandomForestModel/T23      0.840  accuracy          0.057624  0.528516                0.057624           0.528516            1       True         23
15  CustomRandomForestModel/T13      0.840  accuracy          0.057679  0.519133                0.057679           0.519133            1       True         13
16  CustomRandomForestModel/T20      0.835  accuracy          0.057330  0.534520                0.057330           0.534520            1       True         20
17  CustomRandomForestModel/T12      0.835  accuracy          0.057552  0.529909                0.057552           0.529909            1       True         12
18  CustomRandomForestModel/T14      0.835  accuracy          0.057713  0.516292                0.057713           0.516292            1       True         14
19  CustomRandomForestModel/T2       0.835  accuracy          0.058793  0.535451                0.058793           0.535451            1       True          2
20  CustomRandomForestModel/T27      0.835  accuracy          0.058857  0.536120                0.058857           0.536120            1       True         27
21  CustomRandomForestModel/T5       0.835  accuracy          0.058996  0.515126                0.058996           0.515126            1       True          5
22  CustomRandomForestModel/T9       0.835  accuracy          0.060557  0.519985                0.060557           0.519985            1       True          9
23  CustomRandomForestModel/T22      0.830  accuracy          0.058770  0.525597                0.058770           0.525597            1       True         22
24  CustomRandomForestModel/T6       0.830  accuracy          0.058837  0.524674                0.058837           0.524674            1       True          6
25  CustomRandomForestModel/T3       0.825  accuracy          0.057525  0.526406                0.057525           0.526406            1       True          3
26  CustomRandomForestModel/T1       0.805  accuracy          0.057927  0.506985                0.057927           0.506985            1       True          1
27  CustomRandomForestModel/T18      0.805  accuracy          0.058503  0.513918                0.058503           0.513918            1       True         18
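WeightedEnsemble_L2 has exactly the same validation score as T4 because, as the HPO log reports (`Ensemble Weights: {'CustomRandomForestModel/T4': 1.0}`), the ensemble put all of its weight on T4. To our understanding, AutoGluon builds this weighted ensemble with greedy ensemble selection. Models are added one at a time, with replacement, each time picking the one whose addition most improves the validation metric. The weights are then the selection counts normalized to sum to 1. The sketch below is a simplified NumPy illustration of that idea on made-up validation probabilities. The function name, the fixed iteration count (20, an arbitrary choice), the 0.5 decision threshold, and the tie-breaking rule are our own assumptions, not AutoGluon's implementation.

```python
import numpy as np


def greedy_ensemble_selection(val_probs, y_val, n_iterations=20):
    """Greedily pick models (with replacement) to maximize validation accuracy.

    val_probs: dict mapping model name -> array of P(class 1) on the validation rows.
    y_val: array of 0/1 validation labels.
    Returns {model name: weight}, where weight = selection count / n_iterations.
    """
    names = list(val_probs)
    preds = np.stack([val_probs[name] for name in names])  # shape (n_models, n_rows)
    counts = np.zeros(len(names), dtype=int)
    running_sum = np.zeros(preds.shape[1])

    for step in range(1, n_iterations + 1):
        # Averaged probabilities if each candidate were added as the next member.
        candidate_avg = (running_sum + preds) / step
        scores = ((candidate_avg >= 0.5).astype(int) == y_val).mean(axis=1)
        best = int(np.argmax(scores))  # ties go to the earliest model in `names`
        counts[best] += 1
        running_sum += preds[best]

    weights = counts / n_iterations
    return {name: float(w) for name, w in zip(names, weights) if w > 0}


y_val = np.array([1, 0, 1, 1, 0, 0])
val_probs = {
    "model_a": np.array([0.9, 0.1, 0.8, 0.7, 0.2, 0.3]),  # made-up, classifies all 6 rows correctly
    "model_b": np.array([0.6, 0.4, 0.4, 0.6, 0.6, 0.4]),  # made-up, 4 of 6 rows correct
}
print(greedy_ensemble_selection(val_probs, y_val))  # {'model_a': 1.0}
```

When one model is already at least as good as any blend, as here (and as with T4 in the HPO run), adding other models never improves the metric. The ensemble then collapses to that single model with weight 1.0.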

Getting the hyperparameters of a trained model

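Let's retrieve the hyperparameters of the model with the highest validation score. The code filters on stack_level == 1 so that the WeightedEnsemble_L2 model is excluded and only the tuned CustomRandomForestModel trials are considered.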

best_model_name = leaderboard_hpo[leaderboard_hpo['stack_level'] == 1]['model'].iloc[0]

predictor_info = predictor.info()
best_model_info = predictor_info['model_info'][best_model_name]

print(best_model_info)

print(f'Best Model Hyperparameters ({best_model_name}):')
print(best_model_info['hyperparameters'])
{'name': 'CustomRandomForestModel/T4', 'model_type': 'CustomRandomForestModel', 'problem_type': 'binary', 'eval_metric': 'accuracy', 'stopping_metric': 'accuracy', 'fit_time': 0.5276908874511719, 'num_classes': 2, 'quantile_levels': None, 'predict_time': 0.05750536918640137, 'val_score': 0.855, 'hyperparameters': {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}, 'hyperparameters_user': {'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy', 'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}, 'hyperparameters_fit': {}, 'hyperparameters_nondefault': ['max_depth', 'max_features', 'criterion', 'n_estimators', 'n_jobs', 'random_state'], 'ag_args_fit': {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None}, 'num_features': 14, 'features': ['age', 'fnlwgt', 'education-num', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'native-country'], 'feature_metadata': <autogluon.common.features.feature_metadata.FeatureMetadata object at 0x7f2ae329be50>, 'memory_size': 4803239, 'compile_time': None, 'is_initialized': True, 'is_fit': True, 'is_valid': True, 'can_infer': True, 'has_learning_curves': False, 'num_samples': 800, 'val_in_fit': True, 'unlabeled_in_fit': False, 'num_cpus': 8, 'num_gpus': 0.0}
Best Model Hyperparameters (CustomRandomForestModel/T4):
{'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}
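The cell above takes `.iloc[0]` and so relies on the leaderboard already being sorted best-first. In the HPO leaderboard, rows with equal `score_val` appear ordered by ascending `pred_time_val` (for example T21, T16, T17, T11 at 0.850). If you would rather not depend on that ordering, you can sort explicitly. The sketch below does this on a hand-copied, deliberately shuffled subset of the leaderboard. `pick_best_l1_model` is a hypothetical helper name, and breaking ties on `pred_time_val` is our choice.

```python
import pandas as pd


def pick_best_l1_model(leaderboard: pd.DataFrame) -> str:
    """Return the level-1 model with the highest score_val; ties go to the faster predictor."""
    level1 = leaderboard[leaderboard["stack_level"] == 1]  # drops WeightedEnsemble_L2 (stack_level 2)
    if level1.empty:
        raise ValueError("leaderboard contains no level-1 models")
    ranked = level1.sort_values(["score_val", "pred_time_val"], ascending=[False, True])
    return ranked["model"].iloc[0]


# Subset of the HPO leaderboard above, deliberately shuffled.
sample_leaderboard = pd.DataFrame({
    "model": ["CustomRandomForestModel/T11", "WeightedEnsemble_L2",
              "CustomRandomForestModel/T21", "CustomRandomForestModel/T4",
              "CustomRandomForestModel/T16"],
    "score_val": [0.850, 0.855, 0.850, 0.855, 0.850],
    "pred_time_val": [0.060928, 0.058275, 0.049113, 0.057505, 0.057665],
    "stack_level": [1, 2, 1, 1, 1],
})

print(pick_best_l1_model(sample_leaderboard))  # CustomRandomForestModel/T4
```

The returned name can then be used exactly as in the cell above, i.e. `predictor.info()['model_info'][name]['hyperparameters']`.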

Training the custom model alongside other models with TabularPredictor

Finally, we train the custom model (with its tuned hyperparameters) alongside the default AutoGluon models.

All this requires is retrieving the default models' hyperparameter dictionary via get_hyperparameter_config and adding CustomRandomForestModel as a key.

from autogluon.tabular.configs.hyperparameter_configs import get_hyperparameter_config

# Now we can add the custom model with tuned hyperparameters to be trained alongside the default models:
custom_hyperparameters = get_hyperparameter_config('default')

custom_hyperparameters[CustomRandomForestModel] = best_model_info['hyperparameters']

print(custom_hyperparameters)
{'NN_TORCH': {}, 'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, {'learning_rate': 0.03, 'num_leaves': 128, 'feature_fraction': 0.9, 'min_data_in_leaf': 3, 'ag_args': {'name_suffix': 'Large', 'priority': 0, 'hyperparameter_tune_kwargs': None}}], 'CAT': {}, 'XGB': {}, 'FASTAI': {}, 'RF': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}], 'XT': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}], 'KNN': [{'weights': 'uniform', 'ag_args': {'name_suffix': 'Unif'}}, {'weights': 'distance', 'ag_args': {'name_suffix': 'Dist'}}], <class '__main__.CustomRandomForestModel'>: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}}
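The printed dictionary shows the general shape of the `hyperparameters` argument. Each key names a model type, either a string such as 'GBM' or a model class such as CustomRandomForestModel. Each value is a single config dict or a list of them. As the training log shows for 'RF' and 'GBM', one model is trained per config, and `ag_args` with `name_suffix` keeps the variants' names distinct (for example LightGBMXT). The sketch below registers two variants of a custom model this way. It uses a hypothetical placeholder class `MyCustomModel` and a hand-written subset of the default config so it runs without AutoGluon. The 'Deep'/'Shallow' suffixes and the shallow `max_depth` are illustrative assumptions. We also assume `ag_args` works for custom models the same way it does for the built-in GBM entries.

```python
import copy


class MyCustomModel:
    """Hypothetical placeholder; in this tutorial you would use CustomRandomForestModel."""


# Hand-written subset of the structure printed above (not the full default config).
example_defaults = {
    "GBM": [{"extra_trees": True, "ag_args": {"name_suffix": "XT"}}, {}],
    "CAT": {},
}

# Copy first so the original dict stays untouched if it is reused elsewhere.
example_hyperparameters = copy.deepcopy(example_defaults)

tuned_rf = {"n_estimators": 300, "max_depth": 26, "max_features": 0.4459435365634299, "criterion": "entropy"}

# A list of configs means one model per entry; name_suffix keeps their names distinct.
example_hyperparameters[MyCustomModel] = [
    {**tuned_rf, "ag_args": {"name_suffix": "Deep"}},
    {**tuned_rf, "max_depth": 8, "ag_args": {"name_suffix": "Shallow"}},
]

for key, value in example_hyperparameters.items():
    configs = value if isinstance(value, list) else [value]
    label = key if isinstance(key, str) else key.__name__
    print(f"{label}: {len(configs)} config(s)")
```

To use this for real, swap the placeholder for CustomRandomForestModel, start from `get_hyperparameter_config('default')`, and pass the result as `hyperparameters=` to `TabularPredictor.fit`, as in the next cell.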
predictor = TabularPredictor(label=label).fit(train_data, hyperparameters=custom_hyperparameters)  # Train the default models plus a single tuned CustomRandomForestModel
# predictor = TabularPredictor(label=label).fit(train_data, hyperparameters=custom_hyperparameters, presets='best_quality')  # We can even use the custom model in a multi-layer stack ensemble
predictor.leaderboard(test_data)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_205630"
Verbosity: 2 (Standard Logging)
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Memory Avail:       28.70 GB / 30.95 GB (92.7%)
Disk Space Avail:   211.97 GB / 255.99 GB (82.8%)
===================================================
No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets. Defaulting to `'medium'`...
	Recommended Presets (For more details refer to https://autogluon.cn/stable/tutorials/tabular/tabular-essentials.html#presets):
	presets='experimental' : New in v1.2: Pre-trained foundation model + parallel fits. The absolute best accuracy without consideration for inference speed. Does not support GPU.
	presets='best'         : Maximize accuracy. Recommended for most users. Use in competitions and benchmarks.
	presets='high'         : Strong accuracy with fast inference speed.
	presets='good'         : Good accuracy with very fast inference speed.
	presets='medium'       : Fast training time, ideal for initial prototyping.
Beginning AutoGluon training ...
AutoGluon will save models to "/home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205630"
Train Data Rows:    1000
Train Data Columns: 14
Label Column:       class
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
	2 unique label values:  [' >50K', ' <=50K']
	If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
Problem Type:       binary
Preprocessing data ...
Selected class <--> label mapping:  class 1 =  >50K, class 0 =  <=50K
	Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.
	To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
	Available Memory:                    29388.29 MB
	Train Data (Original)  Memory Usage: 0.56 MB (0.0% of available memory)
	Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
	Stage 1 Generators:
		Fitting AsTypeFeatureGenerator...
			Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
	Stage 2 Generators:
		Fitting FillNaFeatureGenerator...
	Stage 3 Generators:
		Fitting IdentityFeatureGenerator...
		Fitting CategoryFeatureGenerator...
			Fitting CategoryMemoryMinimizeFeatureGenerator...
	Stage 4 Generators:
		Fitting DropUniqueFeatureGenerator...
	Stage 5 Generators:
		Fitting DropDuplicatesFeatureGenerator...
	Types of features in original data (raw dtype, special dtypes):
		('int', [])    : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
		('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
	Types of features in processed data (raw dtype, special dtypes):
		('category', [])  : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
		('int', [])       : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
		('int', ['bool']) : 1 | ['sex']
	0.1s = Fit runtime
	14 features in original data used to generate 14 features in processed data.
	Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
Data preprocessing and feature engineering runtime = 0.08s ...
AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
	To change this, specify the eval_metric parameter of Predictor()
Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 800, Val Rows: 200
User-specified model hyperparameters to be fit:
{
	'NN_TORCH': [{}],
	'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, {'learning_rate': 0.03, 'num_leaves': 128, 'feature_fraction': 0.9, 'min_data_in_leaf': 3, 'ag_args': {'name_suffix': 'Large', 'priority': 0, 'hyperparameter_tune_kwargs': None}}],
	'CAT': [{}],
	'XGB': [{}],
	'FASTAI': [{}],
	'RF': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],
	'XT': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],
	'KNN': [{'weights': 'uniform', 'ag_args': {'name_suffix': 'Unif'}}, {'weights': 'distance', 'ag_args': {'name_suffix': 'Dist'}}],
	'<class '__main__.CustomRandomForestModel'>': [{'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}],
}
Custom Model Type Detected: <class '__main__.CustomRandomForestModel'>
Fitting 14 L1 models, fit_strategy="sequential" ...
Fitting model: KNeighborsUnif ...
	0.725	 = Validation score   (accuracy)
	0.03s	 = Training   runtime
	0.01s	 = Validation runtime
Fitting model: KNeighborsDist ...
	0.71	 = Validation score   (accuracy)
	0.01s	 = Training   runtime
	0.01s	 = Validation runtime
Fitting model: LightGBMXT ...
	0.85	 = Validation score   (accuracy)
	0.31s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: LightGBM ...
	0.84	 = Validation score   (accuracy)
	0.29s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: RandomForestGini ...
	0.84	 = Validation score   (accuracy)
	0.71s	 = Training   runtime
	0.06s	 = Validation runtime
Fitting model: RandomForestEntr ...
	0.835	 = Validation score   (accuracy)
	0.61s	 = Training   runtime
	0.06s	 = Validation runtime
Fitting model: CatBoost ...
	0.86	 = Validation score   (accuracy)
	1.87s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: ExtraTreesGini ...
	0.815	 = Validation score   (accuracy)
	0.65s	 = Training   runtime
	0.06s	 = Validation runtime
Fitting model: ExtraTreesEntr ...
	0.82	 = Validation score   (accuracy)
	0.63s	 = Training   runtime
	0.06s	 = Validation runtime
Fitting model: NeuralNetFastAI ...
No improvement since epoch 7: early stopping
	0.84	 = Validation score   (accuracy)
	3.11s	 = Training   runtime
	0.01s	 = Validation runtime
Fitting model: XGBoost ...
	0.855	 = Validation score   (accuracy)
	0.44s	 = Training   runtime
	0.01s	 = Validation runtime
Fitting model: NeuralNetTorch ...
	0.855	 = Validation score   (accuracy)
	3.63s	 = Training   runtime
	0.01s	 = Validation runtime
Fitting model: LightGBMLarge ...
	0.795	 = Validation score   (accuracy)
	0.78s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: CustomRandomForestModel ...
	0.855	 = Validation score   (accuracy)
	0.55s	 = Training   runtime
	0.06s	 = Validation runtime
Fitting model: WeightedEnsemble_L2 ...
	Ensemble Weights: {'LightGBMXT': 0.267, 'CustomRandomForestModel': 0.267, 'CatBoost': 0.2, 'RandomForestGini': 0.133, 'ExtraTreesEntr': 0.133}
	0.885	 = Validation score   (accuracy)
	0.1s	 = Training   runtime
	0.0s	 = Validation runtime
AutoGluon training complete, total runtime = 14.36s ... Best model: WeightedEnsemble_L2 | Estimated inference throughput: 1094.9 rows/s (200 batch size)
Disabling decision threshold calibration for metric `accuracy` due to having fewer than 10000 rows of validation data for calibration, to avoid overfitting (200 rows).
	`accuracy` is generally not improved through threshold calibration. Force calibration via specifying `calibrate_decision_threshold=True`.
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("/home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205630")
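The `Ensemble Weights` line in the log gives LightGBMXT and CustomRandomForestModel 0.267 each, CatBoost 0.2, and RandomForestGini and ExtraTreesEntr 0.133 each. These values are consistent with multiples of 1/15 (4/15 ≈ 0.267, 3/15 = 0.2, 2/15 ≈ 0.133, and 4 + 4 + 3 + 2 + 2 = 15). That pattern is what greedy ensemble selection with replacement (the Caruana et al. ensemble selection method) produces after 15 picks. Our understanding is that `WeightedEnsemble_L2` is fit this way, but the exact scoring, stopping rule, and number of rounds are AutoGluon internals not shown here. The following is a minimal sketch of the technique for binary classification, not AutoGluon's implementation. The function name `greedy_ensemble_selection` and the toy data are hypothetical.

```python
import numpy as np


def greedy_ensemble_selection(val_probs, y_val, n_rounds):
    """Greedy forward ensemble selection with replacement (illustrative sketch).

    val_probs: dict mapping model name -> positive-class probabilities on the
               validation set (binary classification assumed).
    y_val:     array of 0/1 validation labels.
    n_rounds:  number of greedy picks; each pick adds 1/n_rounds of weight.
    Returns a dict mapping model name -> weight for models picked at least once.
    """
    names = list(val_probs)
    counts = {name: 0 for name in names}
    ensemble_sum = np.zeros(len(y_val), dtype=float)  # running sum of picked predictions

    for step in range(1, n_rounds + 1):
        best_name, best_acc = None, -1.0
        for name in names:
            # Average of the models picked so far plus this candidate.
            candidate = (ensemble_sum + val_probs[name]) / step
            acc = np.mean((candidate >= 0.5).astype(int) == y_val)
            if acc > best_acc:  # strict '>' means ties keep the earlier model
                best_name, best_acc = name, acc
        counts[best_name] += 1
        ensemble_sum += val_probs[best_name]

    # Models never picked get zero weight and are dropped.
    return {name: c / n_rounds for name, c in counts.items() if c > 0}


# Hypothetical toy example: three models, four validation rows.
y_val = np.array([1, 0, 1, 0])
val_probs = {
    "a": np.array([0.9, 0.2, 0.4, 0.3]),
    "b": np.array([0.4, 0.1, 0.9, 0.2]),
    "c": np.array([0.1, 0.9, 0.1, 0.9]),  # a poor model that should never be picked
}
weights = greedy_ensemble_selection(val_probs, y_val, n_rounds=3)
print(weights)
```

Because picks are made with replacement, a strong model can be chosen several times, and every weight is a whole number of picks divided by `n_rounds`. Models that are never chosen drop out, which matches the ensemble above using only 5 of the 14 L1 models.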
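The log also reports that decision threshold calibration was skipped because there were fewer than 10000 validation rows, to avoid overfitting. For binary classification, threshold calibration means picking the probability cutoff used to turn positive-class probabilities into labels (instead of the default 0.5) by searching for the cutoff that scores best on validation data. Below is a minimal sketch of the idea, assuming balanced accuracy as the metric and both classes present in the labels. The helpers `balanced_accuracy` and `calibrate_threshold` are hypothetical and are not AutoGluon's implementation of `calibrate_decision_threshold`.

```python
import numpy as np


def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall for 0/1 labels (assumes both classes are present)."""
    recalls = []
    for cls in (0, 1):
        mask = y_true == cls
        recalls.append(np.mean(y_pred[mask] == cls))
    return float(np.mean(recalls))


def calibrate_threshold(y_true, proba, candidates):
    """Return (cutoff, score) for the candidate cutoff with the best balanced accuracy.

    Ties keep the first candidate in the list.
    """
    best_t, best_score = None, -1.0
    for t in candidates:
        score = balanced_accuracy(y_true, (proba >= t).astype(int))
        if score > best_score:
            best_t, best_score = t, score
    return best_t, best_score


# Hypothetical validation labels and predicted positive-class probabilities.
y_true = np.array([0, 0, 0, 0, 1, 1])
proba = np.array([0.1, 0.2, 0.3, 0.35, 0.4, 0.45])

default_score = balanced_accuracy(y_true, (proba >= 0.5).astype(int))
best_t, best_score = calibrate_threshold(y_true, proba, [0.3, 0.4, 0.5])
print(f"default cutoff 0.5: {default_score:.2f}, tuned cutoff {best_t}: {best_score:.2f}")
```

With a small validation set, the tuned cutoff can fit noise in those few rows rather than a real shift in the data, which is the overfitting risk the log message refers to.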
    model                    score_test  score_val  eval_metric  pred_time_test  pred_time_val  fit_time  pred_time_test_marginal  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0   CatBoost                 0.852902    0.860      accuracy     0.012000        0.004439       1.870484  0.012000                 0.004439                1.870484           1            True       7
1   WeightedEnsemble_L2      0.850957    0.885      accuracy     0.396561        0.182661       4.161953  0.003787                 0.000768                0.095763           2            True       15
2   LightGBMXT               0.850752    0.850      accuracy     0.022345        0.003736       0.307004  0.022345                 0.003736                0.307004           1            True       3
3   NeuralNetFastAI          0.848193    0.840      accuracy     0.138275        0.008939       3.112128  0.138275                 0.008939                3.112128           1            True       10
4   XGBoost                  0.846658    0.855      accuracy     0.048345        0.005962       0.435826  0.048345                 0.005962                0.435826           1            True       11
5   LightGBM                 0.841335    0.840      accuracy     0.016278        0.003465       0.287818  0.016278                 0.003465                0.287818           1            True       4
6   RandomForestGini         0.840004    0.840      accuracy     0.125455        0.058398       0.714421  0.125455                 0.058398                0.714421           1            True       5
7   RandomForestEntr         0.837240    0.835      accuracy     0.109810        0.057897       0.614643  0.109810                 0.057897                0.614643           1            True       6
8   CustomRandomForestModel  0.834988    0.855      accuracy     0.118730        0.057171       0.546502  0.118730                 0.057171                0.546502           1            True       14
9   NeuralNetTorch           0.833248    0.855      accuracy     0.058589        0.010674       3.629005  0.058589                 0.010674                3.629005           1            True       12
10  ExtraTreesGini           0.831917    0.815      accuracy     0.110340        0.057254       0.648913  0.110340                 0.057254                0.648913           1            True       8
11  LightGBMLarge            0.829461    0.795      accuracy     0.066144        0.004574       0.783208  0.066144                 0.004574                0.783208           1            True       13
12  ExtraTreesEntr           0.829358    0.820      accuracy     0.114245        0.058149       0.627778  0.114245                 0.058149                0.627778           1            True       9
13  KNeighborsUnif           0.744600    0.725      accuracy     0.025678        0.014041       0.031765  0.025678                 0.014041                0.031765           1            True       1
14  KNeighborsDist           0.710922    0.710      accuracy     0.025847        0.014201       0.011127  0.025847                 0.014201                0.011127           1            True       2
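In the leaderboard, `fit_time`, `pred_time_val`, and `pred_time_test` for a stacked model include the time of the base models it depends on, while the `*_marginal` columns count only the model itself. For stack level 1 models the two are identical, as every row except `WeightedEnsemble_L2` shows. The short check below uses only values copied from the leaderboard. Summing the marginal times of `WeightedEnsemble_L2` and its five weighted base models reproduces the ensemble's totals up to rounding in the last printed digit.

```python
# Marginal times (seconds) copied from the leaderboard above.
# model: (pred_time_test_marginal, pred_time_val_marginal, fit_time_marginal)
base_models = {
    "LightGBMXT": (0.022345, 0.003736, 0.307004),
    "CustomRandomForestModel": (0.118730, 0.057171, 0.546502),
    "CatBoost": (0.012000, 0.004439, 1.870484),
    "RandomForestGini": (0.125455, 0.058398, 0.714421),
    "ExtraTreesEntr": (0.114245, 0.058149, 0.627778),
}
ensemble_marginal = (0.003787, 0.000768, 0.095763)  # WeightedEnsemble_L2 on its own

# Total time = the ensemble's own marginal time + marginal times of its base models.
totals = [
    ensemble_marginal[i] + sum(times[i] for times in base_models.values())
    for i in range(3)
]
pred_time_test, pred_time_val, fit_time = totals
print(f"pred_time_test={pred_time_test:.6f}  pred_time_val={pred_time_val:.6f}  fit_time={fit_time:.6f}")
```

The sums agree with the reported totals (0.396561, 0.182661, 4.161953) to within 1e-5 seconds. The marginal columns therefore tell you what the ensemble layer itself costs (here about 0.1 s to fit). The total columns tell you what it costs to actually use `WeightedEnsemble_L2` for inference.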
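The `score_val` and `score_test` rankings disagree. `WeightedEnsemble_L2` has the highest validation accuracy (0.885), but CatBoost has the highest test accuracy (0.852902). One reason is that validation scores come from only 200 holdout rows (`holdout_frac=0.2`, Train Rows: 800, Val Rows: 200). Each validation accuracy is therefore a multiple of 1/200 = 0.005 and carries sizeable sampling error. The ensemble's validation score is likely further inflated, since to our understanding its weights are chosen on those same rows. Below is a rough estimate of that error using the binomial standard error sqrt(p(1 - p) / n). Treating validation rows as independent draws is a simplifying assumption, and `accuracy_std_error` is a hypothetical helper.

```python
import math

n_val = 200  # "Val Rows" reported in the log


def accuracy_std_error(p, n):
    """Binomial standard error of an accuracy estimate p measured on n rows."""
    return math.sqrt(p * (1 - p) / n)


results = {}
for name, score_val in [("WeightedEnsemble_L2", 0.885), ("CatBoost", 0.860), ("LightGBMXT", 0.850)]:
    correct = round(score_val * n_val)  # number of correctly classified validation rows
    se = accuracy_std_error(score_val, n_val)
    results[name] = (correct, se)
    print(f"{name}: {correct}/{n_val} correct, standard error ~ {se:.3f}")
```

The standard errors come out around 0.023 to 0.025. The 0.025 validation gap between `WeightedEnsemble_L2` and CatBoost is therefore roughly one standard error, and single-row differences of 0.005 between models say little on their own. The test scores, computed on a larger held-out set, are the more reliable basis for comparing models here.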

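Summary

That is all it takes to add a custom model to AutoGluon. If you create a custom model, consider submitting a PR so that we can officially add it to AutoGluon!

For more tutorials, see Predicting Columns in a Table - Quick Start and Predicting Columns in a Table - In Depth.

For a tutorial on advanced custom models, see Adding a custom model to AutoGluon (Advanced).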