Adding a custom model to AutoGluon
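Tip: If you are new to AutoGluon, review Predicting Columns in a Table - Quick Start to learn the basics of the AutoGluon API.
This tutorial describes how to add a custom model to AutoGluon so that it can be trained, hyperparameter-tuned, and ensembled alongside the default models (default model documentation).
In this example, we create a custom random forest model for use in AutoGluon. All models in AutoGluon inherit from the AbstractModel class (AbstractModel source code), and they must follow its API to work alongside other models.
Note that this tutorial provides only a basic model implementation. It does not cover many aspects that most of the implemented models use.
To learn how to implement more advanced functionality, refer to the source code of the following models: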
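Feature | Reference implementation |
---|---|
Respecting time limits / early stopping logic | |
Respecting memory usage limits | LGBModel and RFModel |
Sample weight support | LGBModel |
Validation data and eval_metric usage | LGBModel |
GPU training support | LGBModel |
Save / load logic for non-serializable models | |
Advanced problem type support (Softclass, Quantile) | RFModel |
Text feature type support | |
Image feature type support | |
Lazy import of package dependencies | LGBModel |
Custom HPO logic | LGBModel |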
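Two places describe the same practice: the lazy-import row of the table, and the comments inside `_fit` further below. A model imports its third-party packages inside its methods rather than at module import time. As a result, a missing optional dependency only matters when that particular model is actually trained. Below is a minimal sketch of the pattern. It assumes a hypothetical helper, `lazy_import`, which is not an AutoGluon function; it converts a missing package into an actionable error message.

```python
import importlib
from types import ModuleType


def lazy_import(module_name: str, install_hint: str) -> ModuleType:
    """Import a module on first use and raise a helpful error if it is missing.

    Hypothetical helper for illustration only; it is not part of AutoGluon.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        raise ImportError(
            f"`{module_name}` is required for this model. {install_hint}"
        ) from err


class LazyModelSketch:
    """Hypothetical toy model whose dependency is imported only inside fit()."""

    def fit(self, values):
        # The import happens here, not at module import time, so importing
        # this file never fails because of a missing optional package.
        statistics = lazy_import("statistics", "It ships with Python.")
        self.mean_ = statistics.mean(values)
        return self


model = LazyModelSketch().fit([1.0, 2.0, 3.0])
print(model.mean_)  # 2.0
```

The standard-library `statistics` module stands in for a heavy optional dependency such as scikit-learn. With this pattern, the error appears only if someone tries to fit the model.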
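Implementing a custom model
Here we define the custom model used in the rest of this tutorial.
The most important methods to implement are _fit and _preprocess.
To compare with AutoGluon's official random forest implementation, see the RFModel source code.
Read the code comments alongside the code to better understand how it works.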
import numpy as np
import pandas as pd
from autogluon.core.models import AbstractModel
from autogluon.features.generators import LabelEncoderFeatureGenerator
class CustomRandomForestModel(AbstractModel):
    def __init__(self, **kwargs):
        # Simply pass along kwargs to parent, and init our internal `_feature_generator` variable to None
        super().__init__(**kwargs)
        self._feature_generator = None

    # The `_preprocess` method takes the input data and transforms it to the internal representation usable by the model.
    # `_preprocess` is called by `preprocess` and is used during model fit and model inference.
    def _preprocess(self, X: pd.DataFrame, is_train=False, **kwargs) -> np.ndarray:
        print(f'Entering the `_preprocess` method: {len(X)} rows of data (is_train={is_train})')
        X = super()._preprocess(X, **kwargs)

        if is_train:
            # X will be the training data.
            self._feature_generator = LabelEncoderFeatureGenerator(verbosity=0)
            self._feature_generator.fit(X=X)
        if self._feature_generator.features_in:
            # This converts categorical features to numeric via stateful label encoding.
            X = X.copy()
            X[self._feature_generator.features_in] = self._feature_generator.transform(X=X)
        # Add a fillna call to handle missing values.
        # Some algorithms will be able to handle NaN values internally (LightGBM).
        # In those cases, you can simply pass the NaN values into the inner model.
        # Finally, convert to numpy for optimized memory usage and because sklearn RF works with raw numpy input.
        return X.fillna(0).to_numpy(dtype=np.float32)

    # The `_fit` method takes the input training data (and optionally the validation data) and trains the model.
    def _fit(self,
             X: pd.DataFrame,  # training data
             y: pd.Series,  # training labels
             # X_val=None,  # val data (unused in RF model)
             # y_val=None,  # val labels (unused in RF model)
             # time_limit=None,  # time limit in seconds (ignored in tutorial)
             **kwargs):  # kwargs includes many other potential inputs, refer to AbstractModel documentation for details
        print('Entering the `_fit` method')

        # First we import the required dependencies for the model. Note that we do not import them outside of the method.
        # This enables AutoGluon to be highly extensible and modular.
        # For an example of best practices when importing model dependencies, refer to LGBModel.
        from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

        # Valid self.problem_type values include ['binary', 'multiclass', 'regression', 'quantile', 'softclass']
        if self.problem_type in ['regression', 'softclass']:
            model_cls = RandomForestRegressor
        else:
            model_cls = RandomForestClassifier

        # Make sure to call preprocess on X near the start of `_fit`.
        # This is necessary because the data is converted via preprocess during predict, and needs to be in the same format as during fit.
        X = self.preprocess(X, is_train=True)
        # This fetches the user-specified (and default) hyperparameters for the model.
        params = self._get_model_params()
        print(f'Hyperparameters: {params}')
        # self.model should be set to the trained inner model, so that internally during predict we can call `self.model.predict(...)`
        self.model = model_cls(**params)
        self.model.fit(X, y)
        print('Exiting the `_fit` method')

    # The `_set_default_params` method defines the default hyperparameters of the model.
    # User-specified parameters will override these values on a key-by-key basis.
    def _set_default_params(self):
        default_params = {
            'n_estimators': 300,
            'n_jobs': -1,
            'random_state': 0,
        }
        for param, val in default_params.items():
            self._set_default_param_value(param, val)

    # The `_get_default_auxiliary_params` method defines various model-agnostic parameters such as maximum memory usage and valid input column dtypes.
    # For most users who build custom models, they will only need to specify the valid/invalid dtypes to the model here.
    def _get_default_auxiliary_params(self) -> dict:
        default_auxiliary_params = super()._get_default_auxiliary_params()
        extra_auxiliary_params = dict(
            # the total set of raw dtypes are: ['int', 'float', 'category', 'object', 'datetime']
            # object feature dtypes include raw text and image paths, which should only be handled by specialized models
            # datetime raw dtypes are generally converted to int in upstream pre-processing,
            # so models generally shouldn't need to explicitly support datetime dtypes.
            valid_raw_types=['int', 'float', 'category'],
            # Other options include `valid_special_types`, `ignored_type_group_raw`, and `ignored_type_group_special`.
            # Refer to AbstractModel for more details on available options.
        )
        default_auxiliary_params.update(extra_auxiliary_params)
        return default_auxiliary_params
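Why does the class implement `_fit` and `_preprocess` instead of `fit` and `preprocess`? In AutoGluon, the public methods on `AbstractModel` call these underscore-prefixed hooks. Shared bookkeeping stays in the base class, and a subclass only supplies model-specific logic. The sketch below is a heavily simplified imitation of that split. It uses a hypothetical `MiniAbstractModel`, which is not AutoGluon's code, and a toy majority-class inner model instead of scikit-learn. It also shows the key-by-key override of default hyperparameters that the `_set_default_params` comment describes.

```python
class MiniAbstractModel:
    """Hypothetical, heavily simplified stand-in for AbstractModel."""

    def __init__(self, hyperparameters=None):
        self._user_params = dict(hyperparameters or {})
        self.params = {}
        self.model = None

    # Public entry points: shared logic lives here; subclasses override the hooks.
    def fit(self, X, y):
        self._set_default_params()
        # User-specified values override the defaults key by key.
        self.params.update(self._user_params)
        self._fit(X, y)
        return self

    def preprocess(self, X, is_train=False):
        return self._preprocess(X, is_train=is_train)

    def predict(self, X):
        # Inference goes through the same preprocessing as training.
        return self.model.predict(self.preprocess(X))

    # Hooks for subclasses.
    def _set_default_params(self):
        pass

    def _set_default_param_value(self, name, value):
        self.params.setdefault(name, value)

    def _preprocess(self, X, is_train=False):
        return X

    def _fit(self, X, y):
        raise NotImplementedError


class _MajorityClassifier:
    """Hypothetical tiny inner model that keeps the sketch dependency-free."""

    def fit(self, X, y):
        self.label_ = max(set(y), key=y.count)
        return self

    def predict(self, X):
        return [self.label_ for _ in X]


class MiniCustomModel(MiniAbstractModel):
    def _set_default_params(self):
        for name, value in {"n_estimators": 300, "random_state": 0}.items():
            self._set_default_param_value(name, value)

    def _preprocess(self, X, is_train=False):
        # Example transformation: replace missing values (None) with 0.
        return [[0 if v is None else v for v in row] for row in X]

    def _fit(self, X, y):
        X = self.preprocess(X, is_train=True)
        self.model = _MajorityClassifier().fit(X, y)


m = MiniCustomModel(hyperparameters={"max_depth": 10, "n_estimators": 50})
m.fit([[1, None], [2, 3], [None, 4]], [1, 0, 1])
print(m.params)  # {'n_estimators': 50, 'random_state': 0, 'max_depth': 10}
print(m.predict([[5, None]]))  # [1]
```

User keys replace defaults in place, and new keys are appended. The real model's output further below follows the same pattern, for example `{'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 10}`.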
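Loading the data
Next, we load the data. This tutorial uses the Adult Income dataset because it contains a mix of integer and categorical (string) features.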
from autogluon.tabular import TabularDataset
train_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv') # can be local CSV file as well, returns Pandas DataFrame
test_data = TabularDataset('https://autogluon.s3.amazonaws.com/datasets/Inc/test.csv') # another Pandas DataFrame
label = 'class' # specifies which column we want to predict
train_data = train_data.sample(n=1000, random_state=0) # subsample for faster demo
train_data.head(5)
| | age | workclass | fnlwgt | education | education-num | marital-status | occupation | relationship | race | sex | capital-gain | capital-loss | hours-per-week | native-country | class |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
6118 | 51 | Private | 39264 | Some-college | 10 | Married-civ-spouse | Exec-managerial | Wife | White | Female | 0 | 0 | 40 | United-States | >50K |
23204 | 58 | Private | 51662 | 10th | 6 | Married-civ-spouse | Other-service | Wife | White | Female | 0 | 0 | 8 | United-States | <=50K |
29590 | 40 | Private | 326310 | Some-college | 10 | Married-civ-spouse | Craft-repair | Husband | White | Male | 0 | 0 | 44 | United-States | <=50K |
18116 | 37 | Private | 222450 | HS-grad | 9 | Never-married | Sales | Not-in-family | White | Male | 0 | 2339 | 40 | El-Salvador | <=50K |
33964 | 62 | Private | 109190 | Bachelors | 13 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 15024 | 0 | 40 | United-States | >50K |
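Training a custom model without TabularPredictor
Below, we show how to train the model without TabularPredictor. This is useful for debugging, and it minimizes the amount of code you need to understand while implementing the model.
The process resembles what happens internally when you call fit on TabularPredictor, but it is simplified and minimal.
If the data were already clean (all numeric), we could call fit on it directly. That is not the case for the Adult dataset.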
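Cleaning labels
The first step toward making the input data valid for the model is to clean the labels.
The labels are currently strings, but binary classification requires numeric values (0 and 1).
Fortunately, AutoGluon already implements logic for both steps: detecting that this is a binary classification problem (infer_problem_type), and converting the labels to 0 and 1 (LabelCleaner).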
# Separate features and labels
X = train_data.drop(columns=[label])
y = train_data[label]
X_test = test_data.drop(columns=[label])
y_test = test_data[label]
from autogluon.core.data import LabelCleaner
from autogluon.core.utils import infer_problem_type
# Construct a LabelCleaner to neatly convert labels to float/integers during model training/inference, can also use to inverse_transform back to original.
problem_type = infer_problem_type(y=y) # Infer problem type (or else specify directly)
label_cleaner = LabelCleaner.construct(problem_type=problem_type, y=y)
y_clean = label_cleaner.transform(y)
print(f'Labels cleaned: {label_cleaner.inv_map}')
print(f'inferred problem type as: {problem_type}')
print('Cleaned label values:')
y_clean.head(5)
Labels cleaned: {' <=50K': 0, ' >50K': 1}
inferred problem type as: binary
Cleaned label values:
6118 1
23204 0
29590 0
18116 0
33964 1
Name: class, dtype: uint8
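For a binary problem, the label cleaning above boils down to two steps. First, recognize that the label has exactly two distinct values. Then, map those values to 0 and 1 and keep the mapping for the reverse direction. The sketch below uses hypothetical helpers (`infer_problem_type_sketch`, `BinaryLabelMapper`). AutoGluon's real `infer_problem_type` applies more heuristics. The sorted ordering used here is an assumption of the sketch, though it happens to reproduce the `{' <=50K': 0, ' >50K': 1}` mapping printed above.

```python
def infer_problem_type_sketch(labels):
    """Very rough, hypothetical stand-in for problem-type inference."""
    unique = set(labels)
    if len(unique) == 2:
        return "binary"
    if all(isinstance(v, float) for v in unique):
        return "regression"
    return "multiclass"


class BinaryLabelMapper:
    """Hypothetical helper: maps two label values to 0/1 and back."""

    def __init__(self, labels):
        # Assumption of this sketch: the sorted-first value becomes 0.
        negative, positive = sorted(set(labels))
        self.inv_map = {negative: 0, positive: 1}
        self.cat_map = {0: negative, 1: positive}

    def transform(self, labels):
        return [self.inv_map[v] for v in labels]

    def inverse_transform(self, codes):
        return [self.cat_map[c] for c in codes]


raw_labels = [" >50K", " <=50K", " <=50K", " <=50K", " >50K"]
print(infer_problem_type_sketch(raw_labels))  # binary
mapper = BinaryLabelMapper(raw_labels)
print(mapper.inv_map)  # {' <=50K': 0, ' >50K': 1}
print(mapper.transform(raw_labels))  # [1, 0, 0, 0, 1]
print(mapper.inverse_transform([0, 0, 1]))  # [' <=50K', ' <=50K', ' >50K']
```

The five sample labels mirror the first five training rows shown earlier, which is why the encoded output matches the cleaned values printed above.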
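Cleaning features
Next, we need to clean the features. Features such as "workclass" currently have object dtype (strings), but we actually want to treat them as categorical features. Most models do not accept string inputs, so we need to convert the strings to numbers.
AutoGluon has an entire module for cleaning, transforming, and generating features, called autogluon.features. Here we use the same feature generator that TabularPredictor uses internally. It converts object dtypes to categorical types and minimizes memory usage.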
from autogluon.common.utils.log_utils import set_logger_verbosity
from autogluon.features.generators import AutoMLPipelineFeatureGenerator
set_logger_verbosity(2) # Set logger so more detailed logging is shown for tutorial
feature_generator = AutoMLPipelineFeatureGenerator()
X_clean = feature_generator.fit_transform(X)
X_clean.head(5)
Fitting AutoMLPipelineFeatureGenerator...
Available Memory: 29463.95 MB
Train Data (Original) Memory Usage: 0.57 MB (0.0% of available memory)
Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
Stage 1 Generators:
Fitting AsTypeFeatureGenerator...
Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
Stage 2 Generators:
Fitting FillNaFeatureGenerator...
Stage 3 Generators:
Fitting IdentityFeatureGenerator...
Fitting CategoryFeatureGenerator...
Fitting CategoryMemoryMinimizeFeatureGenerator...
Stage 4 Generators:
Fitting DropUniqueFeatureGenerator...
Stage 5 Generators:
Fitting DropDuplicatesFeatureGenerator...
Types of features in original data (raw dtype, special dtypes):
('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
Types of features in processed data (raw dtype, special dtypes):
('category', []) : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('int', ['bool']) : 1 | ['sex']
0.1s = Fit runtime
14 features in original data used to generate 14 features in processed data.
Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
| | age | fnlwgt | education-num | sex | capital-gain | capital-loss | hours-per-week | workclass | education | marital-status | occupation | relationship | race | native-country |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
6118 | 51 | 39264 | 10 | 0 | 0 | 0 | 40 | 3 | 14 | 1 | 4 | 5 | 4 | 24 |
23204 | 58 | 51662 | 6 | 0 | 0 | 0 | 8 | 3 | 0 | 1 | 8 | 5 | 4 | 24 |
29590 | 40 | 326310 | 10 | 1 | 0 | 0 | 44 | 3 | 14 | 1 | 3 | 0 | 4 | 24 |
18116 | 37 | 222450 | 9 | 1 | 0 | 2339 | 40 | 3 | 11 | 3 | 12 | 1 | 4 | 6 |
33964 | 62 | 109190 | 13 | 1 | 15024 | 0 | 40 | 3 | 9 | 1 | 4 | 0 | 4 | 24 |
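The category columns above now hold integer codes. The key property of this kind of encoding is that the code table is learned once, on the training data, and then reused unchanged at inference. That is why the same fitted `feature_generator` is later applied to the test data. Below is a minimal sketch with a hypothetical `CategoryCodeEncoder`. Two choices are assumptions of the sketch and do not describe AutoGluon's internals: categories are sorted, and unseen or missing values map to `None`.

```python
class CategoryCodeEncoder:
    """Hypothetical encoder: learns a category -> integer code table on training data."""

    def fit(self, column):
        categories = sorted({v for v in column if v is not None})
        self.codes_ = {cat: i for i, cat in enumerate(categories)}
        return self

    def transform(self, column):
        # Categories not seen during fit (and missing values) get no code.
        return [self.codes_.get(v) for v in column]


train_col = ["Private", "Self-emp-not-inc", "Private", "Local-gov"]
test_col = ["Private", "Never-worked", None]

encoder = CategoryCodeEncoder().fit(train_col)
print(encoder.codes_)  # {'Local-gov': 0, 'Private': 1, 'Self-emp-not-inc': 2}
print(encoder.transform(test_col))  # [1, None, None]
```

Fitting a fresh encoder on the test data instead could assign different codes to the same category, so the model would silently receive inconsistent inputs.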
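AutoMLPipelineFeatureGenerator does not fill missing values in numeric features, rescale numeric features, or one-hot encode categorical features. If your model requires any of these operations, add them to the _preprocess method. Some of the FeatureGenerator classes may help with this.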
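If your model does need such steps, learn their statistics only when `is_train=True` and reuse them afterwards. This mirrors how the tutorial's `_preprocess` fits its `LabelEncoderFeatureGenerator` only on training data. The dependency-free sketch below assumes that the model needs mean imputation plus standardization. The names `NumericScaler` and `PreprocessSketch` are hypothetical.

```python
import math


class NumericScaler:
    """Hypothetical helper: mean-imputes and standardizes one numeric column."""

    def fit(self, column):
        observed = [v for v in column if v is not None and not math.isnan(v)]
        self.mean_ = sum(observed) / len(observed)
        variance = sum((v - self.mean_) ** 2 for v in observed) / len(observed)
        # Guard against constant columns to avoid dividing by zero.
        self.std_ = math.sqrt(variance) or 1.0
        return self

    def transform(self, column):
        out = []
        for v in column:
            if v is None or math.isnan(v):
                v = self.mean_  # impute with the *training* mean
            out.append((v - self.mean_) / self.std_)
        return out


class PreprocessSketch:
    """Hypothetical model fragment showing the is_train pattern."""

    def __init__(self):
        self._scaler = None

    def _preprocess(self, column, is_train=False):
        if is_train:
            # Statistics are learned only from training data...
            self._scaler = NumericScaler().fit(column)
        # ...and reused unchanged for validation and test data.
        return self._scaler.transform(column)


prep = PreprocessSketch()
print(prep._preprocess([40.0, 20.0, None, 60.0], is_train=True))  # ≈ [0.0, -1.2247, 0.0, 1.2247]
print(prep._preprocess([40.0, float("nan")]))  # [0.0, 0.0]
```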
Fitting the model
We can now fit the model using the cleaned features and labels.
custom_model = CustomRandomForestModel()
# We could also specify hyperparameters to override defaults
# custom_model = CustomRandomForestModel(hyperparameters={'max_depth': 10})
custom_model.fit(X=X_clean, y=y_clean) # Fit custom model
# To save to disk and load the model, do the following:
# load_path = custom_model.path
# custom_model.save()
# del custom_model
# custom_model = CustomRandomForestModel.load(path=load_path)
Entering the `_fit` method
Entering the `_preprocess` method: 1000 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Warning: No name was specified for model, defaulting to class name: CustomRandomForestModel
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_205558/CustomRandomForestModel"
Warning: No path was specified for model, defaulting to: /home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205558/
Selected class <--> label mapping: class 1 = 1, class 0 = 0
Model CustomRandomForestModel's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
<__main__.CustomRandomForestModel at 0x7f2ae4aa7990>
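Predicting with the trained model
Now that the model is fit, we can make predictions on new data. Remember that the new data needs the same feature and label transformations that were applied to the training data.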
# Prepare test data
X_test_clean = feature_generator.transform(X_test)
y_test_clean = label_cleaner.transform(y_test)
X_test.head(5)
| | age | workclass | fnlwgt | education | education-num | marital-status | occupation | relationship | race | sex | capital-gain | capital-loss | hours-per-week | native-country |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 31 | Private | 169085 | 11th | 7 | Married-civ-spouse | Sales | Wife | White | Female | 0 | 0 | 20 | United-States |
1 | 17 | Self-emp-not-inc | 226203 | 12th | 8 | Never-married | Sales | Own-child | White | Male | 0 | 0 | 45 | United-States |
2 | 47 | Private | 54260 | Assoc-voc | 11 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 1887 | 60 | United-States |
3 | 21 | Private | 176262 | Some-college | 10 | Never-married | Exec-managerial | Own-child | White | Female | 0 | 0 | 30 | United-States |
4 | 17 | Private | 241185 | 12th | 8 | Never-married | Prof-specialty | Own-child | White | Male | 0 | 0 | 20 | United-States |
# Get raw predictions for the test data
y_pred = custom_model.predict(X_test_clean)
print(y_pred[:5])
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
[0 0 1 0 0]
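Note that these predictions use the internal label encoding. 1 is the positive class (whichever label was mapped to 1, here ' >50K') and 0 is the negative class. To get a more interpretable result, convert the predictions back to the original labels: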
y_pred_orig = label_cleaner.inverse_transform(y_pred)
y_pred_orig.head(5)
0 <=50K
1 <=50K
2 >50K
3 <=50K
4 <=50K
dtype: object
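Scoring with the trained model
By default, the model uses an eval_metric specific to its problem_type. For binary classification, this is accuracy.
We can get the model's accuracy score as follows: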
score = custom_model.score(X_test_clean, y_test_clean)
print(f'Test score ({custom_model.eval_metric.name}) = {score}')
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Test score (accuracy) = 0.8424608455317842
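Accuracy is the fraction of rows whose predicted label equals the true label. It is the default `eval_metric` for binary classification here. A plain-Python sketch of the arithmetic follows. The `accuracy` function is illustrative and is not AutoGluon's scorer object.

```python
def accuracy(y_true, y_pred):
    """Fraction of positions where prediction and ground truth agree (illustrative)."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


print(accuracy([0, 0, 1, 0, 1], [0, 0, 1, 0, 0]))  # 0.8
```

By the same arithmetic, the test score of about 0.8425 above means the model classified roughly 84.2% of the 9,769 test rows correctly.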
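Training a bagged custom model without TabularPredictor
Some of AutoGluon's more advanced functionality, such as bagging, is straightforward to use with any model that inherits from AbstractModel.
You can bag your custom model with just a few lines of code. This is a quick way to improve the quality of nearly any model: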
from autogluon.core.models import BaggedEnsembleModel
bagged_custom_model = BaggedEnsembleModel(CustomRandomForestModel())
# Parallel folding currently doesn't work with a class not defined in a separate module because of underlying pickle serialization issue
# You don't need this following line if you put your custom model in a separate file and import it.
bagged_custom_model.params['fold_fitting_strategy'] = 'sequential_local'
bagged_custom_model.fit(X=X_clean, y=y_clean, k_fold=10) # Perform 10-fold bagging
bagged_score = bagged_custom_model.score(X_test_clean, y_test_clean)
print(f'Test score ({bagged_custom_model.eval_metric.name}) = {bagged_score} (bagged)')
print(f'Bagging increased model accuracy by {round(bagged_score - score, 4) * 100}%!')
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 900 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 100 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Test score (accuracy) = 0.8435868563824342 (bagged)
Bagging increased model accuracy by 0.11%!
Warning: No name was specified for model, defaulting to class name: CustomRandomForestModel
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_205558-001/CustomRandomForestModel"
Warning: No path was specified for model, defaulting to: /home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205558-001/
Warning: No name was specified for model, defaulting to class name: BaggedEnsembleModel
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_205558/BaggedEnsembleModel"
Warning: No path was specified for model, defaulting to: /home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205558/
Selected class <--> label mapping: class 1 = 1, class 0 = 0
Selected class <--> label mapping: class 1 = 1, class 0 = 0
Model CustomRandomForestModel's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model 's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Fitting 10 child models (S1F1 - S1F10) | Fitting with SequentialLocalFoldFittingStrategy
Model S1F1's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F2's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F3's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F4's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F5's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F6's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F7's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F8's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F9's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
Model S1F10's eval_metric inferred to be 'accuracy' because problem_type='binary' and eval_metric was not specified during init.
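Note that the bagged model trained 10 CustomRandomForestModel instances, each on a different split of the training data. At prediction time, the bagged model averages the predictions of these 10 models.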
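The bagging procedure can be sketched in three steps:

1. Split the rows into k folds.
2. Train one child model per fold on the other k-1 folds. With `k_fold=10` on 1,000 rows, this gives each child 900 training rows and 100 held-out rows, as the log shows.
3. At inference time, average the children's predictions.

The following is a simplified illustration with contiguous (unshuffled) folds and a toy child model that predicts a constant probability. AutoGluon's `BaggedEnsembleModel` handles many details this sketch omits, such as out-of-fold predictions and the fitting strategy. Treat all helper names here as hypothetical.

```python
def kfold_indices(n_rows, k):
    """Yield (train_idx, holdout_idx) pairs for k contiguous, unshuffled folds."""
    fold_sizes = [n_rows // k + (1 if i < n_rows % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        holdout = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_rows))
        yield train, holdout
        start += size


class MeanProbaModel:
    """Hypothetical toy child model: predicts the training positive rate for every row."""

    def fit(self, rows, labels):
        self.rate_ = sum(labels) / len(labels)
        return self

    def predict_proba(self, rows):
        return [self.rate_ for _ in rows]


def fit_bagged(rows, labels, k, make_model):
    """Train one child per fold on the rows outside that fold."""
    children = []
    for train_idx, _holdout_idx in kfold_indices(len(rows), k):
        child = make_model().fit([rows[i] for i in train_idx], [labels[i] for i in train_idx])
        children.append(child)
    return children


def predict_bagged(children, rows):
    # Average the positive-class probability across all child models.
    per_child = [child.predict_proba(rows) for child in children]
    return [sum(col) / len(children) for col in zip(*per_child)]


rows = [[i] for i in range(1000)]
labels = [i % 2 for i in range(1000)]
print([(len(tr), len(ho)) for tr, ho in kfold_indices(len(rows), k=10)][:2])  # [(900, 100), (900, 100)]
children = fit_bagged(rows, labels, k=10, make_model=MeanProbaModel)
print(len(children), predict_bagged(children, [[0], [1]]))  # 10 [0.5, 0.5]
```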
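Training a custom model with TabularPredictor
Skipping TabularPredictor reduces the amount of code to reason about while developing and debugging a model. Ultimately, though, we want to use TabularPredictor to get the most out of the model.
Training on raw data with TabularPredictor takes very little code. There is no need to specify a LabelCleaner, FeatureGenerator, or validation set, because all of these are handled internally.
Here we train 3 CustomRandomForestModel instances with different hyperparameters.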
from autogluon.tabular import TabularPredictor
# custom_hyperparameters = {CustomRandomForestModel: {}} # train 1 CustomRandomForestModel with default hyperparameters
custom_hyperparameters = {CustomRandomForestModel: [{}, {'max_depth': 10}, {'max_features': 0.9, 'max_depth': 20}]} # Train 3 CustomRandomForestModels with different hyperparameters
predictor = TabularPredictor(label=label).fit(train_data, hyperparameters=custom_hyperparameters)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 10}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_features': 0.9, 'max_depth': 20}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_205605"
Verbosity: 2 (Standard Logging)
=================== System Info ===================
AutoGluon Version: 1.3.1b20250508
Python Version: 3.11.9
Operating System: Linux
Platform Machine: x86_64
Platform Version: #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count: 8
Memory Avail: 28.73 GB / 30.95 GB (92.8%)
Disk Space Avail: 212.07 GB / 255.99 GB (82.8%)
===================================================
No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets. Defaulting to `'medium'`...
Recommended Presets (For more details refer to https://autogluon.cn/stable/tutorials/tabular/tabular-essentials.html#presets):
presets='experimental' : New in v1.2: Pre-trained foundation model + parallel fits. The absolute best accuracy without consideration for inference speed. Does not support GPU.
presets='best' : Maximize accuracy. Recommended for most users. Use in competitions and benchmarks.
presets='high' : Strong accuracy with fast inference speed.
presets='good' : Good accuracy with very fast inference speed.
presets='medium' : Fast training time, ideal for initial prototyping.
Beginning AutoGluon training ...
AutoGluon will save models to "/home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205605"
Train Data Rows: 1000
Train Data Columns: 14
Label Column: class
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
2 unique label values: [' >50K', ' <=50K']
If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
Problem Type: binary
Preprocessing data ...
Selected class <--> label mapping: class 1 = >50K, class 0 = <=50K
Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.
To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
Available Memory: 29415.33 MB
Train Data (Original) Memory Usage: 0.56 MB (0.0% of available memory)
Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
Stage 1 Generators:
Fitting AsTypeFeatureGenerator...
Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
Stage 2 Generators:
Fitting FillNaFeatureGenerator...
Stage 3 Generators:
Fitting IdentityFeatureGenerator...
Fitting CategoryFeatureGenerator...
Fitting CategoryMemoryMinimizeFeatureGenerator...
Stage 4 Generators:
Fitting DropUniqueFeatureGenerator...
Stage 5 Generators:
Fitting DropDuplicatesFeatureGenerator...
Types of features in original data (raw dtype, special dtypes):
('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
Types of features in processed data (raw dtype, special dtypes):
('category', []) : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('int', ['bool']) : 1 | ['sex']
0.1s = Fit runtime
14 features in original data used to generate 14 features in processed data.
Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
Data preprocessing and feature engineering runtime = 0.09s ...
AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
To change this, specify the eval_metric parameter of Predictor()
Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 800, Val Rows: 200
User-specified model hyperparameters to be fit:
{
'<class '__main__.CustomRandomForestModel'>': [{}, {'max_depth': 10}, {'max_features': 0.9, 'max_depth': 20}],
}
Custom Model Type Detected: <class '__main__.CustomRandomForestModel'>
Custom Model Type Detected: <class '__main__.CustomRandomForestModel'>
Custom Model Type Detected: <class '__main__.CustomRandomForestModel'>
Fitting 3 L1 models, fit_strategy="sequential" ...
Fitting model: CustomRandomForestModel ...
0.835 = Validation score (accuracy)
0.55s = Training runtime
0.06s = Validation runtime
Fitting model: CustomRandomForestModel_2 ...
0.845 = Validation score (accuracy)
0.52s = Training runtime
0.06s = Validation runtime
Fitting model: CustomRandomForestModel_3 ...
0.84 = Validation score (accuracy)
0.54s = Training runtime
0.06s = Validation runtime
Fitting model: WeightedEnsemble_L2 ...
Ensemble Weights: {'CustomRandomForestModel_2': 1.0}
0.845 = Validation score (accuracy)
0.0s = Training runtime
0.0s = Validation runtime
AutoGluon training complete, total runtime = 1.96s ... Best model: WeightedEnsemble_L2 | Estimated inference throughput: 3342.2 rows/s (200 batch size)
Disabling decision threshold calibration for metric `accuracy` due to having fewer than 10000 rows of validation data for calibration, to avoid overfitting (200 rows).
`accuracy` is generally not improved through threshold calibration. Force calibration via specifying `calibrate_decision_threshold=True`.
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("/home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205605")
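The `hyperparameters` argument maps each model to either one config dict or a list of them, and each dict becomes a separate model. The log above shows the resulting names `CustomRandomForestModel`, `CustomRandomForestModel_2`, and `CustomRandomForestModel_3`, each with the user values layered on top of the defaults. Below is a hypothetical sketch of that expansion. For simplicity it is keyed by a name string rather than the class object. The helper `expand_hyperparameters` and its naming rule are inferred from the log, not taken from AutoGluon's source.

```python
def expand_hyperparameters(hyperparameters, defaults):
    """Hypothetical: turn {model_name: dict | list[dict]} into (unique_name, params) jobs."""
    jobs = []
    for base_name, configs in hyperparameters.items():
        if isinstance(configs, dict):
            configs = [configs]
        for i, user_params in enumerate(configs, start=1):
            # Naming rule inferred from the training log: first is bare, then _2, _3, ...
            name = base_name if i == 1 else f"{base_name}_{i}"
            # User values override defaults key by key.
            jobs.append((name, {**defaults, **user_params}))
    return jobs


defaults = {"n_estimators": 300, "n_jobs": -1, "random_state": 0}
custom = {"CustomRandomForestModel": [{}, {"max_depth": 10}, {"max_features": 0.9, "max_depth": 20}]}
for name, params in expand_hyperparameters(custom, defaults):
    print(name, params)
```

The printed parameter dicts match the three `Hyperparameters:` lines emitted by `_fit` in the log above.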
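Predictor leaderboard
The leaderboard below shows statistics for each trained model. Note that a WeightedEnsemble model was also trained. This model combines the predictions of the other models through ensembling to try to achieve a better validation score.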
predictor.leaderboard(test_data)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
| | model | score_test | score_val | eval_metric | pred_time_test | pred_time_val | fit_time | pred_time_test_marginal | pred_time_val_marginal | fit_time_marginal | stack_level | can_infer | fit_order |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | CustomRandomForestModel_2 | 0.846044 | 0.845 | accuracy | 0.097788 | 0.059052 | 0.519446 | 0.097788 | 0.059052 | 0.519446 | 1 | True | 2 |
1 | WeightedEnsemble_L2 | 0.846044 | 0.845 | accuracy | 0.100083 | 0.059841 | 0.522067 | 0.002295 | 0.000789 | 0.002620 | 2 | True | 4 |
2 | CustomRandomForestModel | 0.840414 | 0.835 | accuracy | 0.108968 | 0.058160 | 0.551860 | 0.108968 | 0.058160 | 0.551860 | 1 | True | 1 |
3 | CustomRandomForestModel_3 | 0.828846 | 0.840 | accuracy | 0.098568 | 0.059306 | 0.544473 | 0.098568 | 0.059306 | 0.544473 | 1 | True | 3 |
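One common way to build a weighted ensemble like this is greedy forward ensemble selection. The procedure repeatedly adds, with replacement, whichever model most improves the validation score of the averaged predictions. Each model's selection frequency then becomes its weight. The sketch below implements that technique on made-up validation probabilities. It illustrates the idea only and is not AutoGluon's `WeightedEnsemble` code; the model names and numbers are hypothetical.

```python
def accuracy_from_proba(proba, labels, threshold=0.5):
    """Accuracy of thresholded positive-class probabilities."""
    return sum((p >= threshold) == bool(t) for p, t in zip(proba, labels)) / len(labels)


def greedy_ensemble_selection(val_proba, labels, n_rounds=10):
    """val_proba: {model_name: [positive-class probability per validation row]}."""
    selected = []
    n_rows = len(labels)
    for _ in range(n_rounds):
        best_name, best_score = None, -1.0
        for name, proba in val_proba.items():
            candidate = selected + [name]
            # Average the probabilities of all models picked so far (repeats allowed).
            avg = [sum(val_proba[m][i] for m in candidate) / len(candidate) for i in range(n_rows)]
            score = accuracy_from_proba(avg, labels)
            if score > best_score:  # ties keep the earlier model
                best_name, best_score = name, score
        selected.append(best_name)
    # Weight = fraction of rounds in which each model was picked.
    return {name: selected.count(name) / len(selected) for name in dict.fromkeys(selected)}


labels = [1, 1, 0, 0]
val_proba = {
    "model_a": [0.9, 0.9, 0.6, 0.1],  # hypothetical: wrong on row 2
    "model_b": [0.4, 0.9, 0.1, 0.1],  # hypothetical: wrong on row 0
}
print(greedy_ensemble_selection(val_proba, labels, n_rounds=2))  # {'model_a': 0.5, 'model_b': 0.5}
```

Here each model alone scores 0.75, but their average scores 1.0, so both get selected. If no addition ever beats the first pick, that model ends up with weight 1.0. That is the same shape of result as the `Ensemble Weights` line in the training log.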
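Predicting with the fitted predictor
Here we use the fitted predictor to make predictions. By default, it predicts with the best model (the one with the highest score_val).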
y_pred = predictor.predict(test_data)
# y_pred = predictor.predict(test_data, model='CustomRandomForestModel_3') # If we want a specific model to predict
y_pred.head(5)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
0 <=50K
1 <=50K
2 >50K
3 <=50K
4 <=50K
Name: class, dtype: object
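Hyperparameter tuning a custom model with TabularPredictor
We can easily hyperparameter-tune a custom model by specifying hyperparameter search spaces instead of exact values.
Here we tune the custom model for 20 seconds: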
from autogluon.common import space
custom_hyperparameters_hpo = {CustomRandomForestModel: {
'max_depth': space.Int(lower=5, upper=30),
'max_features': space.Real(lower=0.1, upper=1.0),
'criterion': space.Categorical('gini', 'entropy'),
}}
# Hyperparameter tune CustomRandomForestModel for 20 seconds
predictor = TabularPredictor(label=label).fit(train_data,
hyperparameters=custom_hyperparameters_hpo,
hyperparameter_tune_kwargs='auto', # enables HPO
time_limit=20)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 5, 'max_features': 0.1, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 20, 'max_features': 0.7436704297351775, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 8, 'max_features': 0.8625265649057129, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 11, 'max_features': 0.15104167958569886, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 6, 'max_features': 0.8125525342743981, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 19, 'max_features': 0.6112401049845391, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 30, 'max_features': 0.16393245237809825, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 25, 'max_features': 0.11819655769629316, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 10, 'max_features': 0.8003410758548655, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 5, 'max_features': 0.9807565080094875, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 22, 'max_features': 0.5153314260276387, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 24, 'max_features': 0.20644698328203992, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 6, 'max_features': 0.22901795866814179, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 5, 'max_features': 0.5696634895750645, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 28, 'max_features': 0.3381000508941643, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 23, 'max_features': 0.5105352989948937, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 5, 'max_features': 0.11691082039271963, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 10, 'max_features': 0.6508861504501793, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 22, 'max_features': 0.9493732706631618, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 15, 'max_features': 0.42355711051640743, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 6, 'max_features': 0.7278680763345383, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 30, 'max_features': 0.7000900439011009, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 16, 'max_features': 0.2893443049664568, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 5, 'max_features': 0.38388551583176544, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 17, 'max_features': 0.6131770933760917, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 16, 'max_features': 0.9895364542533036, 'criterion': 'gini'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 200 rows of data (is_train=False)
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_205608"
Verbosity: 2 (Standard Logging)
=================== System Info ===================
AutoGluon Version: 1.3.1b20250508
Python Version: 3.11.9
Operating System: Linux
Platform Machine: x86_64
Platform Version: #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count: 8
Memory Avail: 28.72 GB / 30.95 GB (92.8%)
Disk Space Avail: 212.06 GB / 255.99 GB (82.8%)
===================================================
No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets. Defaulting to `'medium'`...
Recommended Presets (For more details refer to https://autogluon.cn/stable/tutorials/tabular/tabular-essentials.html#presets):
presets='experimental' : New in v1.2: Pre-trained foundation model + parallel fits. The absolute best accuracy without consideration for inference speed. Does not support GPU.
presets='best' : Maximize accuracy. Recommended for most users. Use in competitions and benchmarks.
presets='high' : Strong accuracy with fast inference speed.
presets='good' : Good accuracy with very fast inference speed.
presets='medium' : Fast training time, ideal for initial prototyping.
Warning: hyperparameter tuning is currently experimental and may cause the process to hang.
Beginning AutoGluon training ... Time limit = 20s
AutoGluon will save models to "/home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205608"
Train Data Rows: 1000
Train Data Columns: 14
Label Column: class
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
2 unique label values: [' >50K', ' <=50K']
If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
Problem Type: binary
Preprocessing data ...
Selected class <--> label mapping: class 1 = >50K, class 0 = <=50K
Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.
To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
Available Memory: 29414.06 MB
Train Data (Original) Memory Usage: 0.56 MB (0.0% of available memory)
Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
Stage 1 Generators:
Fitting AsTypeFeatureGenerator...
Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
Stage 2 Generators:
Fitting FillNaFeatureGenerator...
Stage 3 Generators:
Fitting IdentityFeatureGenerator...
Fitting CategoryFeatureGenerator...
Fitting CategoryMemoryMinimizeFeatureGenerator...
Stage 4 Generators:
Fitting DropUniqueFeatureGenerator...
Stage 5 Generators:
Fitting DropDuplicatesFeatureGenerator...
Types of features in original data (raw dtype, special dtypes):
('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
Types of features in processed data (raw dtype, special dtypes):
('category', []) : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('int', ['bool']) : 1 | ['sex']
0.1s = Fit runtime
14 features in original data used to generate 14 features in processed data.
Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
Data preprocessing and feature engineering runtime = 0.09s ...
AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
To change this, specify the eval_metric parameter of Predictor()
Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 800, Val Rows: 200
User-specified model hyperparameters to be fit:
{
'<class '__main__.CustomRandomForestModel'>': [{'max_depth': Int: lower=5, upper=30, 'max_features': Real: lower=0.1, upper=1.0, 'criterion': Categorical['gini', 'entropy']}],
}
Custom Model Type Detected: <class '__main__.CustomRandomForestModel'>
Fitting 1 L1 models, fit_strategy="sequential" ...
Hyperparameter tuning model: CustomRandomForestModel ... Tuning model for up to 17.92s of the 19.91s of remaining time.
Stopping HPO to satisfy time limit...
Fitted model: CustomRandomForestModel/T1 ...
0.805 = Validation score (accuracy)
0.51s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T2 ...
0.835 = Validation score (accuracy)
0.54s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T3 ...
0.825 = Validation score (accuracy)
0.53s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T4 ...
0.855 = Validation score (accuracy)
0.53s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T5 ...
0.835 = Validation score (accuracy)
0.52s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T6 ...
0.83 = Validation score (accuracy)
0.52s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T7 ...
0.845 = Validation score (accuracy)
0.53s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T8 ...
0.845 = Validation score (accuracy)
0.53s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T9 ...
0.835 = Validation score (accuracy)
0.52s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T10 ...
0.845 = Validation score (accuracy)
0.53s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T11 ...
0.85 = Validation score (accuracy)
0.52s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T12 ...
0.835 = Validation score (accuracy)
0.53s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T13 ...
0.84 = Validation score (accuracy)
0.52s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T14 ...
0.835 = Validation score (accuracy)
0.52s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T15 ...
0.845 = Validation score (accuracy)
0.51s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T16 ...
0.85 = Validation score (accuracy)
0.51s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T17 ...
0.85 = Validation score (accuracy)
0.53s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T18 ...
0.805 = Validation score (accuracy)
0.51s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T19 ...
0.845 = Validation score (accuracy)
0.53s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T20 ...
0.835 = Validation score (accuracy)
0.53s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T21 ...
0.85 = Validation score (accuracy)
0.52s = Training runtime
0.05s = Validation runtime
Fitted model: CustomRandomForestModel/T22 ...
0.83 = Validation score (accuracy)
0.53s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T23 ...
0.84 = Validation score (accuracy)
0.53s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T24 ...
0.845 = Validation score (accuracy)
0.52s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T25 ...
0.845 = Validation score (accuracy)
0.59s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T26 ...
0.845 = Validation score (accuracy)
0.53s = Training runtime
0.06s = Validation runtime
Fitted model: CustomRandomForestModel/T27 ...
0.835 = Validation score (accuracy)
0.54s = Training runtime
0.06s = Validation runtime
Fitting model: WeightedEnsemble_L2 ... Training model for up to 19.91s of the -0.11s of remaining time.
Ensemble Weights: {'CustomRandomForestModel/T4': 1.0}
0.855 = Validation score (accuracy)
0.0s = Training runtime
0.0s = Validation runtime
AutoGluon training complete, total runtime = 20.15s ... Best model: WeightedEnsemble_L2 | Estimated inference throughput: 3432.0 rows/s (200 batch size)
Disabling decision threshold calibration for metric `accuracy` due to having fewer than 10000 rows of validation data for calibration, to avoid overfitting (200 rows).
`accuracy` is generally not improved through threshold calibration. Force calibration via specifying `calibrate_decision_threshold=True`.
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("/home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205608")
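Predictor leaderboard (HPO)¶
The leaderboard for the HPO run shows models whose names carry the suffix '/Tx', where x identifies the HPO trial in which the model was trained.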
leaderboard_hpo = predictor.leaderboard()
leaderboard_hpo
| | model | score_val | eval_metric | pred_time_val | fit_time | pred_time_val_marginal | fit_time_marginal | stack_level | can_infer | fit_order |
---|---|---|---|---|---|---|---|---|---|---|
0 | CustomRandomForestModel/T4 | 0.855 | accuracy | 0.057505 | 0.527691 | 0.057505 | 0.527691 | 1 | True | 4 |
1 | WeightedEnsemble_L2 | 0.855 | accuracy | 0.058275 | 0.530208 | 0.000770 | 0.002517 | 2 | True | 28 |
2 | CustomRandomForestModel/T21 | 0.850 | accuracy | 0.049113 | 0.516646 | 0.049113 | 0.516646 | 1 | True | 21 |
3 | CustomRandomForestModel/T16 | 0.850 | accuracy | 0.057665 | 0.509290 | 0.057665 | 0.509290 | 1 | True | 16 |
4 | CustomRandomForestModel/T17 | 0.850 | accuracy | 0.058555 | 0.527065 | 0.058555 | 0.527065 | 1 | True | 17 |
5 | CustomRandomForestModel/T11 | 0.850 | accuracy | 0.060928 | 0.523556 | 0.060928 | 0.523556 | 1 | True | 11 |
6 | CustomRandomForestModel/T24 | 0.845 | accuracy | 0.057226 | 0.519612 | 0.057226 | 0.519612 | 1 | True | 24 |
7 | CustomRandomForestModel/T25 | 0.845 | accuracy | 0.057345 | 0.590406 | 0.057345 | 0.590406 | 1 | True | 25 |
8 | CustomRandomForestModel/T19 | 0.845 | accuracy | 0.057515 | 0.530065 | 0.057515 | 0.530065 | 1 | True | 19 |
9 | CustomRandomForestModel/T26 | 0.845 | accuracy | 0.057742 | 0.528203 | 0.057742 | 0.528203 | 1 | True | 26 |
10 | CustomRandomForestModel/T15 | 0.845 | accuracy | 0.057770 | 0.507537 | 0.057770 | 0.507537 | 1 | True | 15 |
11 | CustomRandomForestModel/T7 | 0.845 | accuracy | 0.058108 | 0.525845 | 0.058108 | 0.525845 | 1 | True | 7 |
12 | CustomRandomForestModel/T10 | 0.845 | accuracy | 0.058333 | 0.533082 | 0.058333 | 0.533082 | 1 | True | 10 |
13 | CustomRandomForestModel/T8 | 0.845 | accuracy | 0.059577 | 0.528987 | 0.059577 | 0.528987 | 1 | True | 8 |
14 | CustomRandomForestModel/T23 | 0.840 | accuracy | 0.057624 | 0.528516 | 0.057624 | 0.528516 | 1 | True | 23 |
15 | CustomRandomForestModel/T13 | 0.840 | accuracy | 0.057679 | 0.519133 | 0.057679 | 0.519133 | 1 | True | 13 |
16 | CustomRandomForestModel/T20 | 0.835 | accuracy | 0.057330 | 0.534520 | 0.057330 | 0.534520 | 1 | True | 20 |
17 | CustomRandomForestModel/T12 | 0.835 | accuracy | 0.057552 | 0.529909 | 0.057552 | 0.529909 | 1 | True | 12 |
18 | CustomRandomForestModel/T14 | 0.835 | accuracy | 0.057713 | 0.516292 | 0.057713 | 0.516292 | 1 | True | 14 |
19 | CustomRandomForestModel/T2 | 0.835 | accuracy | 0.058793 | 0.535451 | 0.058793 | 0.535451 | 1 | True | 2 |
20 | CustomRandomForestModel/T27 | 0.835 | accuracy | 0.058857 | 0.536120 | 0.058857 | 0.536120 | 1 | True | 27 |
21 | CustomRandomForestModel/T5 | 0.835 | accuracy | 0.058996 | 0.515126 | 0.058996 | 0.515126 | 1 | True | 5 |
22 | CustomRandomForestModel/T9 | 0.835 | accuracy | 0.060557 | 0.519985 | 0.060557 | 0.519985 | 1 | True | 9 |
23 | CustomRandomForestModel/T22 | 0.830 | accuracy | 0.058770 | 0.525597 | 0.058770 | 0.525597 | 1 | True | 22 |
24 | CustomRandomForestModel/T6 | 0.830 | accuracy | 0.058837 | 0.524674 | 0.058837 | 0.524674 | 1 | True | 6 |
25 | CustomRandomForestModel/T3 | 0.825 | accuracy | 0.057525 | 0.526406 | 0.057525 | 0.526406 | 1 | True | 3 |
26 | CustomRandomForestModel/T1 | 0.805 | accuracy | 0.057927 | 0.506985 | 0.057927 | 0.506985 | 1 | True | 1 |
27 | CustomRandomForestModel/T18 | 0.805 | accuracy | 0.058503 | 0.513918 | 0.058503 | 0.513918 | 1 | True | 18 |
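If you need the trial number programmatically, for instance to group or filter HPO trials, you can parse the suffix out of the model name. Below is a minimal stdlib-only sketch. The helper name `split_trial_name` is hypothetical, and the assumption that every trial name ends in exactly `/T<number>` is inferred from the names in the leaderboard above rather than taken from AutoGluon's documentation.

```python
import re
from typing import Optional, Tuple

# Assumed naming pattern, inferred from leaderboard rows such as "CustomRandomForestModel/T4".
TRIAL_SUFFIX = re.compile(r"^(?P<base>.+)/T(?P<trial>\d+)$")


def split_trial_name(model_name: str) -> Tuple[str, Optional[int]]:
    """Split an HPO model name into (base model name, trial number).

    Names without a '/T<number>' suffix, such as 'WeightedEnsemble_L2',
    are returned unchanged with None as the trial number.
    """
    match = TRIAL_SUFFIX.match(model_name)
    if match is None:
        return model_name, None
    return match.group("base"), int(match.group("trial"))


for name in ["CustomRandomForestModel/T4", "WeightedEnsemble_L2", "CustomRandomForestModel/T21"]:
    print(name, "->", split_trial_name(name))
```

If you apply this helper to the `model` column of the leaderboard, it maps `WeightedEnsemble_L2`, which has no trial suffix, to a trial number of `None`.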
Getting the hyperparameters of the trained model¶
Let's get the hyperparameters of the model with the highest validation score.
best_model_name = leaderboard_hpo[leaderboard_hpo['stack_level'] == 1]['model'].iloc[0]
predictor_info = predictor.info()
best_model_info = predictor_info['model_info'][best_model_name]
print(best_model_info)
print(f'Best Model Hyperparameters ({best_model_name}):')
print(best_model_info['hyperparameters'])
{'name': 'CustomRandomForestModel/T4', 'model_type': 'CustomRandomForestModel', 'problem_type': 'binary', 'eval_metric': 'accuracy', 'stopping_metric': 'accuracy', 'fit_time': 0.5276908874511719, 'num_classes': 2, 'quantile_levels': None, 'predict_time': 0.05750536918640137, 'val_score': 0.855, 'hyperparameters': {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}, 'hyperparameters_user': {'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy', 'n_estimators': 300, 'n_jobs': -1, 'random_state': 0}, 'hyperparameters_fit': {}, 'hyperparameters_nondefault': ['max_depth', 'max_features', 'criterion', 'n_estimators', 'n_jobs', 'random_state'], 'ag_args_fit': {'max_memory_usage_ratio': 1.0, 'max_time_limit_ratio': 1.0, 'max_time_limit': None, 'min_time_limit': 0, 'valid_raw_types': ['int', 'float', 'category'], 'valid_special_types': None, 'ignored_type_group_special': None, 'ignored_type_group_raw': None, 'get_features_kwargs': None, 'get_features_kwargs_extra': None, 'predict_1_batch_size': None, 'temperature_scalar': None}, 'num_features': 14, 'features': ['age', 'fnlwgt', 'education-num', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'native-country'], 'feature_metadata': <autogluon.common.features.feature_metadata.FeatureMetadata object at 0x7f2ae329be50>, 'memory_size': 4803239, 'compile_time': None, 'is_initialized': True, 'is_fit': True, 'is_valid': True, 'can_infer': True, 'has_learning_curves': False, 'num_samples': 800, 'val_in_fit': True, 'unlabeled_in_fit': False, 'num_cpus': 8, 'num_gpus': 0.0}
Best Model Hyperparameters (CustomRandomForestModel/T4):
{'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}
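The `.iloc[0]` lookup above relies on the leaderboard already putting the best model first. In the table shown, rows are ordered by `score_val` descending and, among ties, by `pred_time_val` ascending. This is an observation from the output, not a documented guarantee. To make that choice explicit instead of dependent on row order, you could use something like the minimal stdlib sketch below. It works on plain tuples copied from the table instead of the real leaderboard DataFrame, and the helper name `best_base_model` is hypothetical.

```python
# A few rows copied from the HPO leaderboard above:
# (model, score_val, pred_time_val, stack_level)
rows = [
    ("CustomRandomForestModel/T21", 0.850, 0.049113, 1),
    ("WeightedEnsemble_L2", 0.855, 0.058275, 2),
    ("CustomRandomForestModel/T4", 0.855, 0.057505, 1),
    ("CustomRandomForestModel/T16", 0.850, 0.057665, 1),
    ("CustomRandomForestModel/T1", 0.805, 0.057927, 1),
]


def best_base_model(rows):
    """Return the name of the best stack-level-1 (non-ensemble) model.

    Assumed ordering, observed in the table above: higher score_val wins,
    and ties are broken by lower pred_time_val.
    """
    candidates = [row for row in rows if row[3] == 1]
    if not candidates:
        raise ValueError("no stack-level-1 models in the leaderboard")
    best = min(candidates, key=lambda row: (-row[1], row[2]))
    return best[0]


print(best_base_model(rows))
```

With pandas, you would get the same effect by sorting the filtered leaderboard with `sort_values(['score_val', 'pred_time_val'], ascending=[False, True])` before taking the first row.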
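Training the custom model alongside other models with TabularPredictor¶
Finally, we train the custom model (with its tuned hyperparameters) alongside the default AutoGluon models.
All this takes is retrieving the hyperparameter dictionary of the default models via get_hyperparameter_config and adding CustomRandomForestModel as a key.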
from autogluon.tabular.configs.hyperparameter_configs import get_hyperparameter_config
# Now we can add the custom model with tuned hyperparameters to be trained alongside the default models:
custom_hyperparameters = get_hyperparameter_config('default')
custom_hyperparameters[CustomRandomForestModel] = best_model_info['hyperparameters']
print(custom_hyperparameters)
{'NN_TORCH': {}, 'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, {'learning_rate': 0.03, 'num_leaves': 128, 'feature_fraction': 0.9, 'min_data_in_leaf': 3, 'ag_args': {'name_suffix': 'Large', 'priority': 0, 'hyperparameter_tune_kwargs': None}}], 'CAT': {}, 'XGB': {}, 'FASTAI': {}, 'RF': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}], 'XT': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}], 'KNN': [{'weights': 'uniform', 'ag_args': {'name_suffix': 'Unif'}}, {'weights': 'distance', 'ag_args': {'name_suffix': 'Dist'}}], <class '__main__.CustomRandomForestModel'>: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}}
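The dictionary above mixes string keys for built-in models with a class key for the custom model. The default 'GBM' entry also shows that a key can map to a list of configurations, each of which produces its own model. Below is a minimal stdlib sketch of the same pattern. It uses a reduced stand-in for the default configuration and a placeholder `CustomRandomForestModel` stub so that it runs without AutoGluon. The second, Gini-based variant and its `name_suffix` value are purely illustrative. Copying the defaults before editing them is a defensive choice, not something the tutorial requires.

```python
import copy

# Reduced stand-in for get_hyperparameter_config('default'); only two keys are kept here.
default_config = {
    "CAT": {},
    "XGB": {},
}


class CustomRandomForestModel:
    """Hypothetical placeholder for the tutorial's AbstractModel subclass."""


# Tuned values copied from best_model_info['hyperparameters'] above.
tuned = {"n_estimators": 300, "n_jobs": -1, "random_state": 0,
         "max_depth": 26, "max_features": 0.4459435365634299, "criterion": "entropy"}

# Work on a copy so the default dictionary itself is left untouched.
custom_hyperparameters = copy.deepcopy(default_config)

# A list value requests one model per entry, as with the default 'GBM' entry;
# 'name_suffix' keeps the resulting model names distinct (illustrative variant).
custom_hyperparameters[CustomRandomForestModel] = [
    tuned,
    {**tuned, "criterion": "gini", "ag_args": {"name_suffix": "Gini"}},
]

print(len(custom_hyperparameters))
```

In a real run, use the actual `CustomRandomForestModel` subclass defined earlier and the full dictionary from `get_hyperparameter_config('default')`. Then pass the result to `fit(..., hyperparameters=custom_hyperparameters)` as in the next cell.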
predictor = TabularPredictor(label=label).fit(train_data, hyperparameters=custom_hyperparameters) # Train the default models plus a single tuned CustomRandomForestModel
# predictor = TabularPredictor(label=label).fit(train_data, hyperparameters=custom_hyperparameters, presets='best_quality') # We can even use the custom model in a multi-layer stack ensemble
predictor.leaderboard(test_data)
Entering the `_fit` method
Entering the `_preprocess` method: 800 rows of data (is_train=True)
Hyperparameters: {'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}
Exiting the `_fit` method
Entering the `_preprocess` method: 200 rows of data (is_train=False)
Entering the `_preprocess` method: 9769 rows of data (is_train=False)
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_205630"
Verbosity: 2 (Standard Logging)
=================== System Info ===================
AutoGluon Version: 1.3.1b20250508
Python Version: 3.11.9
Operating System: Linux
Platform Machine: x86_64
Platform Version: #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count: 8
Memory Avail: 28.70 GB / 30.95 GB (92.7%)
Disk Space Avail: 211.97 GB / 255.99 GB (82.8%)
===================================================
No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets. Defaulting to `'medium'`...
Recommended Presets (For more details refer to https://autogluon.cn/stable/tutorials/tabular/tabular-essentials.html#presets):
presets='experimental' : New in v1.2: Pre-trained foundation model + parallel fits. The absolute best accuracy without consideration for inference speed. Does not support GPU.
presets='best' : Maximize accuracy. Recommended for most users. Use in competitions and benchmarks.
presets='high' : Strong accuracy with fast inference speed.
presets='good' : Good accuracy with very fast inference speed.
presets='medium' : Fast training time, ideal for initial prototyping.
Beginning AutoGluon training ...
AutoGluon will save models to "/home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205630"
Train Data Rows: 1000
Train Data Columns: 14
Label Column: class
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
2 unique label values: [' >50K', ' <=50K']
If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
Problem Type: binary
Preprocessing data ...
Selected class <--> label mapping: class 1 = >50K, class 0 = <=50K
Note: For your binary classification, AutoGluon arbitrarily selected which label-value represents positive ( >50K) vs negative ( <=50K) class.
To explicitly set the positive_class, either rename classes to 1 and 0, or specify positive_class in Predictor init.
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
Available Memory: 29388.29 MB
Train Data (Original) Memory Usage: 0.56 MB (0.0% of available memory)
Inferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.
Stage 1 Generators:
Fitting AsTypeFeatureGenerator...
Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
Stage 2 Generators:
Fitting FillNaFeatureGenerator...
Stage 3 Generators:
Fitting IdentityFeatureGenerator...
Fitting CategoryFeatureGenerator...
Fitting CategoryMemoryMinimizeFeatureGenerator...
Stage 4 Generators:
Fitting DropUniqueFeatureGenerator...
Stage 5 Generators:
Fitting DropDuplicatesFeatureGenerator...
Types of features in original data (raw dtype, special dtypes):
('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('object', []) : 8 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
Types of features in processed data (raw dtype, special dtypes):
('category', []) : 7 | ['workclass', 'education', 'marital-status', 'occupation', 'relationship', ...]
('int', []) : 6 | ['age', 'fnlwgt', 'education-num', 'capital-gain', 'capital-loss', ...]
('int', ['bool']) : 1 | ['sex']
0.1s = Fit runtime
14 features in original data used to generate 14 features in processed data.
Train Data (Processed) Memory Usage: 0.06 MB (0.0% of available memory)
Data preprocessing and feature engineering runtime = 0.08s ...
AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
To change this, specify the eval_metric parameter of Predictor()
Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 800, Val Rows: 200
User-specified model hyperparameters to be fit:
{
'NN_TORCH': [{}],
'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, {'learning_rate': 0.03, 'num_leaves': 128, 'feature_fraction': 0.9, 'min_data_in_leaf': 3, 'ag_args': {'name_suffix': 'Large', 'priority': 0, 'hyperparameter_tune_kwargs': None}}],
'CAT': [{}],
'XGB': [{}],
'FASTAI': [{}],
'RF': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],
'XT': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],
'KNN': [{'weights': 'uniform', 'ag_args': {'name_suffix': 'Unif'}}, {'weights': 'distance', 'ag_args': {'name_suffix': 'Dist'}}],
'<class '__main__.CustomRandomForestModel'>': [{'n_estimators': 300, 'n_jobs': -1, 'random_state': 0, 'max_depth': 26, 'max_features': 0.4459435365634299, 'criterion': 'entropy'}],
}
Custom Model Type Detected: <class '__main__.CustomRandomForestModel'>
Fitting 14 L1 models, fit_strategy="sequential" ...
Fitting model: KNeighborsUnif ...
0.725 = Validation score (accuracy)
0.03s = Training runtime
0.01s = Validation runtime
Fitting model: KNeighborsDist ...
0.71 = Validation score (accuracy)
0.01s = Training runtime
0.01s = Validation runtime
Fitting model: LightGBMXT ...
0.85 = Validation score (accuracy)
0.31s = Training runtime
0.0s = Validation runtime
Fitting model: LightGBM ...
0.84 = Validation score (accuracy)
0.29s = Training runtime
0.0s = Validation runtime
Fitting model: RandomForestGini ...
0.84 = Validation score (accuracy)
0.71s = Training runtime
0.06s = Validation runtime
Fitting model: RandomForestEntr ...
0.835 = Validation score (accuracy)
0.61s = Training runtime
0.06s = Validation runtime
Fitting model: CatBoost ...
0.86 = Validation score (accuracy)
1.87s = Training runtime
0.0s = Validation runtime
Fitting model: ExtraTreesGini ...
0.815 = Validation score (accuracy)
0.65s = Training runtime
0.06s = Validation runtime
Fitting model: ExtraTreesEntr ...
0.82 = Validation score (accuracy)
0.63s = Training runtime
0.06s = Validation runtime
Fitting model: NeuralNetFastAI ...
No improvement since epoch 7: early stopping
0.84 = Validation score (accuracy)
3.11s = Training runtime
0.01s = Validation runtime
Fitting model: XGBoost ...
0.855 = Validation score (accuracy)
0.44s = Training runtime
0.01s = Validation runtime
Fitting model: NeuralNetTorch ...
0.855 = Validation score (accuracy)
3.63s = Training runtime
0.01s = Validation runtime
Fitting model: LightGBMLarge ...
0.795 = Validation score (accuracy)
0.78s = Training runtime
0.0s = Validation runtime
Fitting model: CustomRandomForestModel ...
0.855 = Validation score (accuracy)
0.55s = Training runtime
0.06s = Validation runtime
Fitting model: WeightedEnsemble_L2 ...
Ensemble Weights: {'LightGBMXT': 0.267, 'CustomRandomForestModel': 0.267, 'CatBoost': 0.2, 'RandomForestGini': 0.133, 'ExtraTreesEntr': 0.133}
0.885 = Validation score (accuracy)
0.1s = Training runtime
0.0s = Validation runtime
AutoGluon training complete, total runtime = 14.36s ... Best model: WeightedEnsemble_L2 | Estimated inference throughput: 1094.9 rows/s (200 batch size)
Disabling decision threshold calibration for metric `accuracy` due to having fewer than 10000 rows of validation data for calibration, to avoid overfitting (200 rows).
`accuracy` is generally not improved through threshold calibration. Force calibration via specifying `calibrate_decision_threshold=True`.
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("/home/ci/autogluon/docs/tutorials/tabular/advanced/AutogluonModels/ag-20250508_205630")
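The `Ensemble Weights` line above shows that `CustomRandomForestModel` is treated like any built-in model when the final `WeightedEnsemble_L2` is built. As we understand it, AutoGluon builds this ensemble with greedy ensemble selection in the style of Caruana et al.: models are added one round at a time, with replacement, and each round adds whichever model most improves the validation score of the averaged predictions. The sketch below is a simplified illustration of that idea, not AutoGluon's implementation. The function name `greedy_ensemble_selection`, the use of accuracy, the fixed 15 rounds and the synthetic data are all our own assumptions.

```python
import numpy as np


def greedy_ensemble_selection(val_probs, y_true, n_iterations=15):
    """Simplified Caruana-style ensemble selection (with replacement).

    val_probs: dict of model name -> (n_samples, n_classes) validation probabilities.
    y_true: (n_samples,) integer class labels.
    Returns a dict of model name -> weight, where weights sum to 1.
    """
    names = list(val_probs)
    counts = dict.fromkeys(names, 0)
    first = next(iter(val_probs.values()))
    ensemble_sum = np.zeros(first.shape, dtype=float)
    for step in range(1, n_iterations + 1):
        best_name, best_acc = None, -1.0
        for name in names:
            # Score the ensemble average if this model were added once more.
            candidate = (ensemble_sum + val_probs[name]) / step
            acc = np.mean(candidate.argmax(axis=1) == y_true)
            if acc > best_acc:  # strict '>' keeps the first model on ties
                best_name, best_acc = name, acc
        ensemble_sum += val_probs[best_name]
        counts[best_name] += 1
    # Weight = fraction of selection rounds in which the model was picked.
    return {name: c / n_iterations for name, c in counts.items() if c > 0}


# Hypothetical validation predictions from three models of varying quality.
rng = np.random.default_rng(0)
y_val = rng.integers(0, 2, size=200)


def noisy_probs(noise):
    # Blend the one-hot truth with uniform noise, then renormalize rows.
    p = np.eye(2)[y_val] * (1 - noise) + rng.random((200, 2)) * noise
    return p / p.sum(axis=1, keepdims=True)


val_probs = {"model_a": noisy_probs(0.6), "model_b": noisy_probs(0.75), "model_c": noisy_probs(0.9)}
weights = greedy_ensemble_selection(val_probs, y_val)
print(weights)
```

In this scheme a weight is simply the fraction of rounds in which a model was chosen, so the weights always sum to 1. The weights in the log also sum to 1 (0.267 + 0.267 + 0.2 + 0.133 + 0.133), and up to rounding they are multiples of 1/15. That pattern fits a count-based scheme like this one, although we have not verified how many rounds AutoGluon actually runs.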
| | model | score_test | score_val | eval_metric | pred_time_test | pred_time_val | fit_time | pred_time_test_marginal | pred_time_val_marginal | fit_time_marginal | stack_level | can_infer | fit_order |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | CatBoost | 0.852902 | 0.860 | accuracy | 0.012000 | 0.004439 | 1.870484 | 0.012000 | 0.004439 | 1.870484 | 1 | True | 7 |
1 | WeightedEnsemble_L2 | 0.850957 | 0.885 | accuracy | 0.396561 | 0.182661 | 4.161953 | 0.003787 | 0.000768 | 0.095763 | 2 | True | 15 |
2 | LightGBMXT | 0.850752 | 0.850 | accuracy | 0.022345 | 0.003736 | 0.307004 | 0.022345 | 0.003736 | 0.307004 | 1 | True | 3 |
3 | NeuralNetFastAI | 0.848193 | 0.840 | accuracy | 0.138275 | 0.008939 | 3.112128 | 0.138275 | 0.008939 | 3.112128 | 1 | True | 10 |
4 | XGBoost | 0.846658 | 0.855 | accuracy | 0.048345 | 0.005962 | 0.435826 | 0.048345 | 0.005962 | 0.435826 | 1 | True | 11 |
5 | LightGBM | 0.841335 | 0.840 | accuracy | 0.016278 | 0.003465 | 0.287818 | 0.016278 | 0.003465 | 0.287818 | 1 | True | 4 |
6 | RandomForestGini | 0.840004 | 0.840 | accuracy | 0.125455 | 0.058398 | 0.714421 | 0.125455 | 0.058398 | 0.714421 | 1 | True | 5 |
7 | RandomForestEntr | 0.837240 | 0.835 | accuracy | 0.109810 | 0.057897 | 0.614643 | 0.109810 | 0.057897 | 0.614643 | 1 | True | 6 |
8 | CustomRandomForestModel | 0.834988 | 0.855 | accuracy | 0.118730 | 0.057171 | 0.546502 | 0.118730 | 0.057171 | 0.546502 | 1 | True | 14 |
9 | NeuralNetTorch | 0.833248 | 0.855 | accuracy | 0.058589 | 0.010674 | 3.629005 | 0.058589 | 0.010674 | 3.629005 | 1 | True | 12 |
10 | ExtraTreesGini | 0.831917 | 0.815 | accuracy | 0.110340 | 0.057254 | 0.648913 | 0.110340 | 0.057254 | 0.648913 | 1 | True | 8 |
11 | LightGBMLarge | 0.829461 | 0.795 | accuracy | 0.066144 | 0.004574 | 0.783208 | 0.066144 | 0.004574 | 0.783208 | 1 | True | 13 |
12 | ExtraTreesEntr | 0.829358 | 0.820 | accuracy | 0.114245 | 0.058149 | 0.627778 | 0.114245 | 0.058149 | 0.627778 | 1 | True | 9 |
13 | KNeighborsUnif | 0.744600 | 0.725 | accuracy | 0.025678 | 0.014041 | 0.031765 | 0.025678 | 0.014041 | 0.031765 | 1 | True | 1 |
14 | KNeighborsDist | 0.710922 | 0.710 | accuracy | 0.025847 | 0.014201 | 0.011127 | 0.025847 | 0.014201 | 0.011127 | 1 | True | 2 |
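In the leaderboard, the columns without the `_marginal` suffix are cumulative. According to the `leaderboard` documentation, `fit_time` and `pred_time_*` include all of a model's base models, while the `_marginal` columns cover only the model itself. For `WeightedEnsemble_L2`, the base models are the five models listed in the `Ensemble Weights` log line. The check below uses plain Python and values copied from the table. It recomputes the ensemble's cumulative times and compares them with the throughput estimate printed at the end of training. The variable names are ours.

```python
# Marginal values copied from the leaderboard above.
# name: (fit_time_marginal, pred_time_val_marginal)
base_models = {
    "LightGBMXT": (0.307004, 0.003736),
    "CustomRandomForestModel": (0.546502, 0.057171),
    "CatBoost": (1.870484, 0.004439),
    "RandomForestGini": (0.714421, 0.058398),
    "ExtraTreesEntr": (0.627778, 0.058149),
}
ensemble_fit_marginal = 0.095763       # WeightedEnsemble_L2 fit_time_marginal
ensemble_pred_val_marginal = 0.000768  # WeightedEnsemble_L2 pred_time_val_marginal

# Cumulative time = the ensemble's own marginal time + that of every base model.
fit_time = ensemble_fit_marginal + sum(f for f, _ in base_models.values())
pred_time_val = ensemble_pred_val_marginal + sum(p for _, p in base_models.values())

# The log reports throughput for a batch of 200 rows (the validation set size).
val_batch_size = 200
throughput = val_batch_size / pred_time_val

print(f"fit_time      ~ {fit_time:.6f} s")
print(f"pred_time_val ~ {pred_time_val:.6f} s")
print(f"throughput    ~ {throughput:.1f} rows/s")
```

The recomputed `fit_time` is about 4.161952 s. The table shows 4.161953 s, and the difference comes from rounding of the displayed values. The recomputed `pred_time_val` is 0.182661 s, and 200 / 0.182661 ≈ 1094.9 rows/s matches the log's estimate. The same accounting explains why the ensemble has the largest prediction times in the table: it must run all five of its base models.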
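Summary
That's all it takes to add a custom model to AutoGluon. If you create a custom model, consider submitting a PR so that we can officially add it to AutoGluon!
For more tutorials, see Predicting Columns in a Table - Quick Start and Predicting Columns in a Table - In Depth.
For a tutorial on advanced custom models, see Adding a custom model to AutoGluon (Advanced).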