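Adding a custom time series forecasting model¶
This tutorial describes how to add a custom forecasting model that can be trained, hyperparameter-tuned, and ensembled alongside the default forecasting models.
As an example, we will implement an AutoGluon wrapper for the NHITS model from the NeuralForecast library.
This tutorial consists of the following sections:
- Implementing the model wrapper.
- Loading and preprocessing the dataset that will be used for model development.
- Using the custom model in standalone mode.
- Using the custom model inside the `TimeSeriesPredictor`.
Warning
This tutorial is designed for advanced AutoGluon users.
Custom model implementations rely heavily on the private API of AutoGluon, which might change over time. For this reason, you might need to update your custom model implementation as you upgrade to newer versions of AutoGluon.
First, we install the NeuralForecast library, which contains the implementation of the custom model used in this tutorial.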
pip install -q neuralforecast==2.0
Note: you may need to restart the kernel to use updated packages.
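Implement the custom model¶
To implement a custom model, we need to create a subclass of the `AbstractTimeSeriesModel` class. This subclass must implement two methods: `_fit` and `_predict`. For models that require custom preprocessing logic (e.g., handling missing values), we also need to implement the `preprocess` method.
Please have a look at the following code and read the comments to understand the different components of the custom model wrapper.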
import logging
import pprint
from typing import Optional, Tuple

import pandas as pd

from autogluon.timeseries import TimeSeriesDataFrame
from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
from autogluon.timeseries.utils.warning_filters import warning_filter

# Optional - disable annoying PyTorch-Lightning loggers
for logger_name in [
    "lightning.pytorch.utilities.rank_zero",
    "pytorch_lightning.accelerators.cuda",
    "lightning_fabric.utilities.seed",
]:
    logging.getLogger(logger_name).setLevel(logging.ERROR)

class NHITSModel(AbstractTimeSeriesModel):
    """AutoGluon-compatible wrapper for the NHITS model from NeuralForecast."""

    # Set these attributes to ensure that AutoGluon passes correct features to the model
    _supports_known_covariates: bool = True
    _supports_past_covariates: bool = True
    _supports_static_features: bool = True

    def preprocess(
        self,
        data: TimeSeriesDataFrame,
        known_covariates: Optional[TimeSeriesDataFrame] = None,
        is_train: bool = False,
        **kwargs,
    ) -> Tuple[TimeSeriesDataFrame, Optional[TimeSeriesDataFrame]]:
        """Method that implements model-specific preprocessing logic.

        This method is called on all data that is passed to `_fit` and `_predict` methods.
        """
        # NeuralForecast cannot handle missing values represented by NaN. Therefore, we
        # need to impute them before the data is passed to the model. First, we
        # forward-fill and backward-fill all time series
        data = data.fill_missing_values()
        # Some time series might consist completely of missing values, so the previous
        # line has no effect on them. We fill them with 0.0
        data = data.fill_missing_values(method="constant", value=0.0)
        # Some models (e.g., Chronos) can natively handle NaNs - for them we don't need
        # to define a custom preprocessing logic
        return data, known_covariates

    def _get_default_hyperparameters(self) -> dict:
        """Default hyperparameters that will be provided to the inner model, i.e., the
        NHITS implementation in neuralforecast."""
        import torch
        from neuralforecast.losses.pytorch import MQLoss

        default_hyperparameters = dict(
            loss=MQLoss(quantiles=self.quantile_levels),
            input_size=2 * self.prediction_length,
            scaler_type="standard",
            enable_progress_bar=False,
            enable_model_summary=False,
            logger=False,
            accelerator="cpu",
            # The model wrapper should handle any time series length - even time series
            # with 1 observation
            start_padding_enabled=True,
            # NeuralForecast requires that names of the past/future/static covariates are
            # passed as model arguments. AutoGluon models have access to this information
            # using the `metadata` attribute that is set automatically at model creation.
            #
            # Note that NeuralForecast does not support categorical covariates, so we
            # only use the real-valued covariates here. To use categorical features in
            # your wrapper, you need to either use techniques like one-hot-encoding, or
            # rely on models that natively handle categorical features.
            futr_exog_list=self.covariate_metadata.known_covariates_real,
            hist_exog_list=self.covariate_metadata.past_covariates_real,
            stat_exog_list=self.covariate_metadata.static_features_real,
        )
        if torch.cuda.is_available():
            default_hyperparameters["accelerator"] = "gpu"
            default_hyperparameters["devices"] = 1
        return default_hyperparameters

    def _fit(
        self,
        train_data: TimeSeriesDataFrame,
        val_data: Optional[TimeSeriesDataFrame] = None,
        time_limit: Optional[float] = None,
        **kwargs,
    ) -> None:
        """Fit the model on the available training data."""
        print("Entering the `_fit` method")

        # We lazily import other libraries inside the _fit method. This reduces the
        # import time for autogluon and ensures that even if one model has some problems
        # with dependencies, the training process won't crash
        from neuralforecast import NeuralForecast
        from neuralforecast.models import NHITS

        # It's important to ensure that the model respects the time_limit during `fit`.
        # Since NeuralForecast is based on PyTorch-Lightning, this can be easily enforced
        # using the `max_time` argument to `pl.Trainer`. For other model types such as
        # ARIMA implementing the time_limit logic may require a lot of work.
        hyperparameter_overrides = {}
        if time_limit is not None:
            hyperparameter_overrides = {"max_time": {"seconds": time_limit}}

        # The method `get_hyperparameters()` returns the model hyperparameters in
        # `_get_default_hyperparameters` overridden with the hyperparameters provided by the user in
        # `predictor.fit(..., hyperparameters={NHITSModel: {}})`. We override these with other
        # hyperparameters available at training time.
        model_params = self.get_hyperparameters() | hyperparameter_overrides
        print(f"Hyperparameters:\n{pprint.pformat(model_params, sort_dicts=False)}")

        model = NHITS(h=self.prediction_length, **model_params)
        self.nf = NeuralForecast(models=[model], freq=self.freq)

        # Convert data into a format expected by the model. NeuralForecast expects time
        # series data in pandas.DataFrame format that is quite similar to AutoGluon, so
        # the transformation is very easy.
        #
        # Note that the `preprocess` method was already applied to train_data and val_data.
        train_df, static_df = self._to_neuralforecast_format(train_data)
        self.nf.fit(
            train_df,
            static_df=static_df,
            id_col="item_id",
            time_col="timestamp",
            target_col=self.target,
        )
        print("Exiting the `_fit` method")

    def _to_neuralforecast_format(self, data: TimeSeriesDataFrame) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
        """Convert a TimeSeriesDataFrame to the format expected by NeuralForecast."""
        df = data.to_data_frame().reset_index()
        # Drop the categorical covariates to avoid NeuralForecast errors
        df = df.drop(columns=self.covariate_metadata.covariates_cat)
        static_df = data.static_features
        if len(self.covariate_metadata.static_features_real) > 0:
            static_df = static_df.reset_index()
            static_df = static_df.drop(columns=self.covariate_metadata.static_features_cat)
        return df, static_df

    def _predict(
        self,
        data: TimeSeriesDataFrame,
        known_covariates: Optional[TimeSeriesDataFrame] = None,
        **kwargs,
    ) -> TimeSeriesDataFrame:
        """Predict future target given the historical time series data and the future values of known_covariates."""
        print("Entering the `_predict` method")

        from neuralforecast.losses.pytorch import quantiles_to_outputs

        df, static_df = self._to_neuralforecast_format(data)
        if len(self.covariate_metadata.known_covariates_real) > 0:
            futr_df, _ = self._to_neuralforecast_format(known_covariates)
        else:
            futr_df = None

        with warning_filter():
            predictions = self.nf.predict(df, static_df=static_df, futr_df=futr_df)

        # predictions must be a TimeSeriesDataFrame with columns
        # ["mean"] + [str(q) for q in self.quantile_levels]
        model_name = str(self.nf.models[0])
        rename_columns = {
            f"{model_name}{suffix}": str(quantile)
            for quantile, suffix in zip(*quantiles_to_outputs(self.quantile_levels))
        }
        predictions = predictions.rename(columns=rename_columns)
        predictions["mean"] = predictions["0.5"]
        predictions = TimeSeriesDataFrame(predictions)
        return predictions
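The comment in `_fit` points out that honoring `time_limit` is easy here because the underlying trainer accepts a `max_time` argument, while other model types need more work. As a rough illustration of what that work can look like, here is a minimal sketch for a model trained in an explicit loop. The helper `train_with_time_limit` and its `train_one_epoch` callback are hypothetical names introduced for this example and are not part of AutoGluon or NeuralForecast.

```python
import time
from typing import Callable, Optional


def train_with_time_limit(
    train_one_epoch: Callable[[], object],
    max_epochs: int,
    time_limit: Optional[float] = None,
) -> int:
    """Run up to `max_epochs` epochs, stopping early once the time budget is spent.

    Returns the number of completed epochs.
    """
    start = time.monotonic()
    completed = 0
    longest_epoch = 0.0
    for _ in range(max_epochs):
        epoch_start = time.monotonic()
        if time_limit is not None:
            remaining = time_limit - (epoch_start - start)
            # Stop if the budget is spent or the next epoch is unlikely to fit into it
            if remaining <= 0 or remaining < longest_epoch:
                break
        train_one_epoch()
        completed += 1
        longest_epoch = max(longest_epoch, time.monotonic() - epoch_start)
    return completed
```

The check uses the longest epoch seen so far as a conservative estimate of the next epoch's duration, so the loop tends to stop slightly early rather than overshoot the budget.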
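For convenience, here is an overview of the main constraints on the inputs and outputs of the different methods.
- Input data received by the `_fit` and `_predict` methods satisfies:
  - the index is sorted by `(item_id, timestamp)`
  - timestamps of observations have a regular frequency corresponding to `self.freq`
  - column `self.target` contains the target values of the time series
  - the target column might contain missing values represented by `NaN`
  - data may contain covariates (including static features) with the schema described in `self.covariate_metadata`
    - real-valued covariates have dtype `float32`
    - categorical covariates have dtype `category`
    - covariates do not contain any missing values
  - static features, if present, are available as `data.static_features`
- Predictions returned by `_predict` must satisfy:
  - predictions are returned as a `TimeSeriesDataFrame` object
  - predictions contain columns `["mean"] + [str(q) for q in self.quantile_levels]` containing the point and quantile forecasts, respectively
  - the index of predictions contains exactly `self.prediction_length` future time steps of each time series present in `data`
  - the frequency of the prediction timestamps matches `self.freq`
  - the index of predictions is sorted by `(item_id, timestamp)`
  - predictions contain no missing values represented by `NaN` and no gaps
- The runtime of the `_fit` method should not exceed `time_limit` seconds, if `time_limit` is provided.
- None of the methods should modify the data in-place. If modifications are needed, create a copy of the data first.
- All methods should work even if some time series consist of all NaNs, or only have a single observation.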
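During development it can help to check the output constraints programmatically. The following is a minimal sketch that works on plain Python rows instead of a `TimeSeriesDataFrame`, so it runs without AutoGluon installed. The function `check_prediction_rows` is a hypothetical helper written for this example; it covers the column, ordering, horizon, and NaN constraints, and does not check the timestamp frequency or gaps.

```python
import math
from collections import Counter
from typing import Any, Dict, List, Sequence


def check_prediction_rows(
    rows: List[Dict[str, Any]],
    item_ids: Sequence[str],
    prediction_length: int,
    quantile_levels: Sequence[float],
) -> None:
    """Raise AssertionError if `rows` violates the output constraints.

    Each row is a dict with keys "item_id", "timestamp" and the forecast columns.
    """
    expected_columns = ["mean"] + [str(q) for q in quantile_levels]
    for row in rows:
        missing = [c for c in expected_columns if c not in row]
        assert not missing, f"missing forecast columns: {missing}"
        for c in expected_columns:
            assert not math.isnan(float(row[c])), f"NaN in column {c!r}"

    # The index must be sorted by (item_id, timestamp)
    keys = [(row["item_id"], row["timestamp"]) for row in rows]
    assert keys == sorted(keys), "rows are not sorted by (item_id, timestamp)"

    # Exactly `prediction_length` future steps for every input series
    counts = Counter(row["item_id"] for row in rows)
    assert set(counts) == set(item_ids), "forecast items differ from input items"
    for item_id, n in counts.items():
        assert n == prediction_length, f"{item_id}: {n} steps, expected {prediction_length}"
```

With a real forecast, the rows could presumably be obtained with `predictions.reset_index().to_dict("records")`, since a `TimeSeriesDataFrame` behaves like a pandas DataFrame.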
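We will now use this wrapper in two modes:
- Standalone mode (outside of the `TimeSeriesPredictor`). This mode should be used for development and debugging. In this case, we need to take care of preprocessing and model configuration manually.
- Inside the `TimeSeriesPredictor`. This mode makes it easy to combine and compare the custom model with other models available in AutoGluon. The main purpose of writing a custom model wrapper is to use it in this mode.
Load and preprocess the data¶
First, we load the Grocery Sales dataset that we will use for development and evaluation.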
from autogluon.timeseries import TimeSeriesDataFrame
raw_data = TimeSeriesDataFrame.from_path(
    "https://autogluon.s3.amazonaws.com/datasets/timeseries/grocery_sales/test.csv",
    static_features_path="https://autogluon.s3.amazonaws.com/datasets/timeseries/grocery_sales/static.csv",
)
raw_data.head()
item_id | timestamp | scaled_price | promotion_email | promotion_homepage | unit_sales
---|---|---|---|---|---
1062_101 | 2018-01-01 | 0.879130 | 0.0 | 0.0 | 636.0
1062_101 | 2018-01-08 | 0.994517 | 0.0 | 0.0 | 123.0
1062_101 | 2018-01-15 | 1.005513 | 0.0 | 0.0 | 391.0
1062_101 | 2018-01-22 | 1.000000 | 0.0 | 0.0 | 339.0
1062_101 | 2018-01-29 | 0.883309 | 0.0 | 0.0 | 661.0
raw_data.static_features.head()
item_id | product_code | product_category | product_subcategory | location_code
---|---|---|---|---
1062_101 | 1062 | Beverages | Fruit Juice Mango | 101
1062_102 | 1062 | Beverages | Fruit Juice Mango | 102
1062_104 | 1062 | Beverages | Fruit Juice Mango | 104
1062_106 | 1062 | Beverages | Fruit Juice Mango | 106
1062_108 | 1062 | Beverages | Fruit Juice Mango | 108
print("Types of the columns in raw data:")
print(raw_data.dtypes)
print("\nTypes of the columns in raw static features:")
print(raw_data.static_features.dtypes)
print("\nNumber of missing values per column:")
print(raw_data.isna().sum())
Types of the columns in raw data:
scaled_price float64
promotion_email float64
promotion_homepage float64
unit_sales float64
dtype: object
Types of the columns in raw static features:
product_code int64
product_category object
product_subcategory object
location_code int64
dtype: object
Number of missing values per column:
scaled_price 714
promotion_email 714
promotion_homepage 714
unit_sales 714
dtype: int64
Define the forecasting task
prediction_length = 7 # number of future steps to predict
target = "unit_sales" # target column
known_covariates_names = ["promotion_email", "promotion_homepage"] # covariates known in the future
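Before we use the model in standalone mode, we need to apply the general AutoGluon preprocessing to the data. The `TimeSeriesFeatureGenerator` captures preprocessing steps such as normalizing the data types and imputing the missing values in the covariates.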
from autogluon.timeseries.utils.features import TimeSeriesFeatureGenerator
feature_generator = TimeSeriesFeatureGenerator(target=target, known_covariates_names=known_covariates_names)
data = feature_generator.fit_transform(raw_data)
print("Types of the columns in preprocessed data:")
print(data.dtypes)
print("\nTypes of the columns in preprocessed static features:")
print(data.static_features.dtypes)
print("\nNumber of missing values per column:")
print(data.isna().sum())
Types of the columns in preprocessed data:
unit_sales float64
promotion_email float32
promotion_homepage float32
scaled_price float32
dtype: object
Types of the columns in preprocessed static features:
product_category category
product_subcategory category
product_code float32
location_code float32
dtype: object
Number of missing values per column:
unit_sales 714
promotion_email 0
promotion_homepage 0
scaled_price 0
dtype: int64
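The output shows that the covariates no longer contain missing values, while the target column `unit_sales` still has 714 of them. Those are handled by the `preprocess` method of the wrapper, whose comments describe a forward fill and backward fill followed by a constant fallback. The sketch below reproduces that logic for a single series in plain Python. It is a simplified illustration rather than AutoGluon's implementation, and `fill_missing` is a hypothetical helper.

```python
import math
from typing import List


def fill_missing(values: List[float], constant: float = 0.0) -> List[float]:
    """Forward-fill, then backward-fill, then fall back to `constant`."""
    filled = list(values)  # copy, so that the input is not modified in place
    # Forward fill: carry the last observed value forward
    last = math.nan
    for i, v in enumerate(filled):
        if math.isnan(v):
            filled[i] = last
        else:
            last = v
    # Backward fill: handle NaNs at the start of the series
    nxt = math.nan
    for i in range(len(filled) - 1, -1, -1):
        if math.isnan(filled[i]):
            filled[i] = nxt
        else:
            nxt = filled[i]
    # A series without any observations is still all-NaN at this point
    return [constant if math.isnan(v) else v for v in filled]
```

For example, `fill_missing([nan, 1.0, nan, 3.0, nan])` yields `[1.0, 1.0, 1.0, 3.0, 3.0]`, and a series that consists only of NaNs is filled with the constant.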
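Using the custom model in standalone mode¶
Using the model in standalone mode is useful for debugging our implementation. Once we make sure that all methods work as expected, we will use the model inside the `TimeSeriesPredictor`.
Training¶
We are now ready to train the custom model on the preprocessed data.
When using the model in standalone mode, we need to configure its parameters manually.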
model = NHITSModel(
    prediction_length=prediction_length,
    target=target,
    covariate_metadata=feature_generator.covariate_metadata,
    freq=data.freq,
    quantile_levels=[0.1, 0.5, 0.9],
)
model.fit(train_data=data, time_limit=20)
Entering the `_fit` method
Hyperparameters:
{'loss': MQLoss(),
'input_size': 14,
'scaler_type': 'standard',
'enable_progress_bar': False,
'enable_model_summary': False,
'logger': False,
'accelerator': 'gpu',
'start_padding_enabled': True,
'futr_exog_list': ['promotion_email', 'promotion_homepage'],
'hist_exog_list': ['scaled_price'],
'stat_exog_list': ['product_code', 'location_code'],
'devices': 1,
'max_time': {'seconds': 13.536751490100324}}
Exiting the `_fit` method
NHITS
Predicting and scoring¶
past_data, known_covariates = data.get_model_inputs_for_scoring(
    prediction_length=prediction_length,
    known_covariates_names=known_covariates_names,
)
predictions = model.predict(past_data, known_covariates)
predictions.head()
Entering the `_predict` method
item_id | timestamp | 0.1 | 0.5 | 0.9 | mean
---|---|---|---|---|---
1062_101 | 2018-06-18 | 216.337616 | 340.998993 | 484.123535 | 340.998993
1062_101 | 2018-06-25 | 204.024643 | 352.983673 | 524.839233 | 352.983673
1062_101 | 2018-07-02 | 209.871796 | 364.116669 | 548.581787 | 364.116669
1062_101 | 2018-07-09 | 205.984467 | 364.956635 | 551.808472 | 364.956635
1062_101 | 2018-07-16 | 208.200226 | 364.006592 | 548.709839 | 364.006592
model.score(data)
Entering the `_predict` method
np.float64(-0.3316095466210152)
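The score is negative because AutoGluon reports metrics in a "higher is better" form. Quantile forecasts such as the `0.1`, `0.5` and `0.9` columns above are commonly evaluated with the quantile (pinball) loss, on which metrics such as WQL are built. The sketch below shows only this building block for one quantile level; the exact weighting and normalization that AutoGluon applies are not reproduced here, and `pinball_loss` is a hypothetical helper.

```python
from typing import Sequence


def pinball_loss(y_true: Sequence[float], y_pred: Sequence[float], q: float) -> float:
    """Average quantile (pinball) loss of forecasts `y_pred` for quantile level `q`."""
    total = 0.0
    for y, f in zip(y_true, y_pred):
        diff = y - f
        # Under-prediction is weighted by q, over-prediction by (1 - q)
        total += q * diff if diff >= 0 else (q - 1) * diff
    return total / len(y_true)
```

For a high quantile level such as 0.9, predicting below the observed value costs far more than predicting above it by the same amount, which pushes the 0.9 forecast upwards.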
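Using the custom model inside the `TimeSeriesPredictor`¶
After we made sure that our custom model works in standalone mode, we can pass it to the `TimeSeriesPredictor` alongside other models.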
from autogluon.timeseries import TimeSeriesPredictor
train_data, test_data = raw_data.train_test_split(prediction_length)
predictor = TimeSeriesPredictor(
    prediction_length=prediction_length,
    target=target,
    known_covariates_names=known_covariates_names,
)
predictor.fit(
    train_data,
    hyperparameters={
        "Naive": {},
        "Chronos": {"model_path": "bolt_small"},
        "ETS": {},
        NHITSModel: {},
    },
    time_limit=120,
)
Beginning AutoGluon training... Time limit = 120s
AutoGluon will save models to '/home/ci/autogluon/docs/tutorials/timeseries/advanced/AutogluonModels/ag-20250508_204919'
=================== System Info ===================
AutoGluon Version: 1.3.1b20250508
Python Version: 3.11.9
Operating System: Linux
Platform Machine: x86_64
Platform Version: #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count: 8
GPU Count: 1
Memory Avail: 27.81 GB / 30.95 GB (89.8%)
Disk Space Avail: 212.13 GB / 255.99 GB (82.9%)
===================================================
Fitting with arguments:
{'enable_ensemble': True,
'eval_metric': WQL,
'hyperparameters': {<class '__main__.NHITSModel'>: {},
'Chronos': {'model_path': 'bolt_small'},
'ETS': {},
'Naive': {}},
'known_covariates_names': ['promotion_email', 'promotion_homepage'],
'num_val_windows': 1,
'prediction_length': 7,
'quantile_levels': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
'random_seed': 123,
'refit_every_n_windows': 1,
'refit_full': False,
'skip_model_selection': False,
'target': 'unit_sales',
'time_limit': 120,
'verbosity': 2}
Inferred time series frequency: 'W-MON'
Provided train_data has 7656 rows (NaN fraction=6.8%), 319 time series. Median time series length is 24 (min=24, max=24).
Provided data contains following columns:
target: 'unit_sales'
known_covariates:
categorical: []
continuous (float): ['promotion_email', 'promotion_homepage']
past_covariates:
categorical: []
continuous (float): ['scaled_price']
static_features:
categorical: ['product_category', 'product_subcategory']
continuous (float): ['product_code', 'location_code']
To learn how to fix incorrectly inferred types, please see documentation for TimeSeriesPredictor.fit
AutoGluon will gauge predictive performance using evaluation metric: 'WQL'
This metric's sign has been flipped to adhere to being higher_is_better. The metric score can be multiplied by -1 to get the metric value.
===================================================
Starting training. Start time is 2025-05-08 20:49:19
Models that will be trained: ['Naive', 'ETS', 'Chronos[bolt_small]', 'NHITS']
Training timeseries model Naive. Training for up to 24.0s of the 119.9s of remaining time.
-0.5412 = Validation score (-WQL)
0.03 s = Training runtime
1.86 s = Validation (prediction) runtime
Training timeseries model ETS. Training for up to 29.5s of the 118.0s of remaining time.
-0.7039 = Validation score (-WQL)
0.04 s = Training runtime
0.85 s = Validation (prediction) runtime
Training timeseries model Chronos[bolt_small]. Training for up to 39.0s of the 117.1s of remaining time.
-0.3320 = Validation score (-WQL)
0.59 s = Training runtime
1.41 s = Validation (prediction) runtime
Training timeseries model NHITS. Training for up to 57.6s of the 115.1s of remaining time.
-0.4681 = Validation score (-WQL)
19.03 s = Training runtime
0.10 s = Validation (prediction) runtime
Fitting simple weighted ensemble.
Ensemble weights: {'Chronos[bolt_small]': np.float64(0.97), 'NHITS': np.float64(0.03)}
/home/ci/autogluon/timeseries/src/autogluon/timeseries/metrics/abstract.py:101: FutureWarning: Passing `prediction_length` to `TimeSeriesScorer.__call__` is deprecated and will be removed in v2.0. Please set the `eval_metric.prediction_length` attribute instead.
warnings.warn(
-0.3320 = Validation score (-WQL)
0.52 s = Training runtime
1.51 s = Validation (prediction) runtime
Training complete. Models trained: ['Naive', 'ETS', 'Chronos[bolt_small]', 'NHITS', 'WeightedEnsemble']
Total runtime: 24.50 s
Best model: Chronos[bolt_small]
Best model score: -0.3320
Entering the `_fit` method
Hyperparameters:
{'loss': MQLoss(),
'input_size': 14,
'scaler_type': 'standard',
'enable_progress_bar': False,
'enable_model_summary': False,
'logger': False,
'accelerator': 'gpu',
'start_padding_enabled': True,
'futr_exog_list': ['promotion_email', 'promotion_homepage'],
'hist_exog_list': ['scaled_price'],
'stat_exog_list': ['product_code', 'location_code'],
'devices': 1,
'max_time': {'seconds': 51.80114328768496}}
Exiting the `_fit` method
Entering the `_predict` method
<autogluon.timeseries.predictor.TimeSeriesPredictor at 0x7f7e582a6650>
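Note that when we use the custom model inside the predictor, we don't need to worry about:
- manually configuring the model (setting `freq`, `prediction_length`)
- preprocessing the data using `TimeSeriesFeatureGenerator`
- setting the time limits
The `TimeSeriesPredictor` automatically takes care of all the above aspects.
We can also easily compare our custom model with the other models trained by the predictor.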
predictor.leaderboard(test_data)
Entering the `_predict` method
Additional data provided, testing on additional data. Resulting leaderboard will be sorted according to test score (`score_test`).
 | model | score_test | score_val | pred_time_test | pred_time_val | fit_time_marginal | fit_order
---|---|---|---|---|---|---|---
0 | WeightedEnsemble | -0.314278 | -0.332025 | 0.666137 | 1.507256 | 0.515570 | 5
1 | Chronos[bolt_small] | -0.315786 | -0.331984 | 0.531369 | 1.405814 | 0.590614 | 3
2 | NHITS | -0.397008 | -0.468127 | 0.132401 | 0.101442 | 19.026050 | 4
3 | ETS | -0.459021 | -0.703868 | 0.250249 | 0.845771 | 0.036806 | 2
4 | Naive | -0.512205 | -0.541231 | 0.188219 | 1.860332 | 0.033384 | 1
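The `WeightedEnsemble` row combines the custom model with the built-in ones. The training log reports ensemble weights of 0.97 for `Chronos[bolt_small]` and 0.03 for `NHITS`. Assuming the ensemble is a plain weighted average of the member forecasts, as the log line "Fitting simple weighted ensemble" suggests, the combination step can be sketched as follows. The function `weighted_ensemble` is a hypothetical helper, and the forecast values in the usage lines are toy numbers.

```python
from typing import Dict, List


def weighted_ensemble(
    forecasts: Dict[str, List[float]],
    weights: Dict[str, float],
) -> List[float]:
    """Combine per-model forecasts as a weighted average, step by step."""
    horizon = len(next(iter(forecasts.values())))
    combined = [0.0] * horizon
    for model_name, weight in weights.items():
        for t, value in enumerate(forecasts[model_name]):
            combined[t] += weight * value
    return combined


# Toy forecasts for two future steps (illustrative values only)
toy_forecasts = {"Chronos[bolt_small]": [340.0, 350.0], "NHITS": [300.0, 400.0]}
toy_weights = {"Chronos[bolt_small]": 0.97, "NHITS": 0.03}
combined = weighted_ensemble(toy_forecasts, toy_weights)
```

Because only models with a non-zero weight contribute, a custom model influences the ensemble only if it improves the validation score of the combination.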
We can also take advantage of other predictor functionality such as `feature_importance`.
predictor.feature_importance(test_data, model="NHITS")
Entering the `_predict` method
Computing feature importance
 | importance | stdev | n | p99_low | p99_high |
---|---|---|---|---|---|
product_category | 0.000000 | 0.000000 | 5.0 | 0.000000 | 0.000000 |
product_subcategory | 0.000000 | 0.000000 | 5.0 | 0.000000 | 0.000000 |
product_code | -0.000775 | 0.010627 | 5.0 | -0.022656 | 0.021106 |
location_code | 0.000092 | 0.000285 | 5.0 | -0.000494 | 0.000679 |
promotion_email | 0.007981 | 0.009743 | 5.0 | -0.012079 | 0.028041 |
promotion_homepage | 0.005051 | 0.007467 | 5.0 | -0.010324 | 0.020426 |
scaled_price | -0.000088 | 0.000898 | 5.0 | -0.001937 | 0.001761 |
As expected, the features product_category and product_subcategory have zero importance, because our implementation ignores categorical features.
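To see why an ignored feature scores exactly zero, here is a minimal sketch of permutation-style importance, assuming the values above were computed by shuffling one feature at a time and measuring the change in error. The toy model, the data, and the helper names (`predict`, `permutation_importance`) are hypothetical and are not part of AutoGluon.

```python
import random

rng = random.Random(0)

# Toy data: the target depends only on feature "a".
a = [rng.gauss(0, 1) for _ in range(200)]
b = [rng.gauss(0, 1) for _ in range(200)]
X = {"a": a, "b": b}
y = [2.0 * v for v in a]


def predict(features):
    """Hypothetical model that never reads feature 'b'."""
    return [2.0 * v for v in features["a"]]


def mean_abs_error(y_true, y_pred):
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def permutation_importance(name, n_repeats=5):
    """Average increase in error after shuffling a single feature."""
    baseline = mean_abs_error(y, predict(X))
    increases = []
    for _ in range(n_repeats):
        shuffled = dict(X)
        column = list(X[name])
        rng.shuffle(column)
        shuffled[name] = column
        increases.append(mean_abs_error(y, predict(shuffled)) - baseline)
    return sum(increases) / n_repeats


importance_a = permutation_importance("a")
importance_b = permutation_importance("b")
print(f"a: {importance_a:.4f}, b: {importance_b:.4f}")
```

Shuffling a feature that the model never reads cannot change its predictions, so the importance of `b` is exactly zero in every repetition, which also makes its standard deviation zero.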
Here is how you can train multiple versions of the custom model with different hyperparameter configurations:
predictor = TimeSeriesPredictor(
    prediction_length=prediction_length,
    target=target,
    known_covariates_names=known_covariates_names,
)
predictor.fit(
    train_data,
    hyperparameters={
        NHITSModel: [
            {},  # default hyperparameters
            {"input_size": 20},  # custom input_size
            {"scaler_type": "robust"},  # custom scaler_type
        ]
    },
    time_limit=60,
)
Entering the `_fit` method
Hyperparameters:
{'loss': MQLoss(),
'input_size': 14,
'scaler_type': 'standard',
'enable_progress_bar': False,
'enable_model_summary': False,
'logger': False,
'accelerator': 'gpu',
'start_padding_enabled': True,
'futr_exog_list': ['promotion_email', 'promotion_homepage'],
'hist_exog_list': ['scaled_price'],
'stat_exog_list': ['product_code', 'location_code'],
'devices': 1,
'max_time': {'seconds': 13.4735950648168}}
Exiting the `_fit` method
Entering the `_predict` method
Entering the `_fit` method
Hyperparameters:
{'loss': MQLoss(),
'input_size': 20,
'scaler_type': 'standard',
'enable_progress_bar': False,
'enable_model_summary': False,
'logger': False,
'accelerator': 'gpu',
'start_padding_enabled': True,
'futr_exog_list': ['promotion_email', 'promotion_homepage'],
'hist_exog_list': ['scaled_price'],
'stat_exog_list': ['product_code', 'location_code'],
'devices': 1,
'max_time': {'seconds': 13.842486143193675}}
Exiting the `_fit` method
Entering the `_predict` method
Entering the `_fit` method
Hyperparameters:
{'loss': MQLoss(),
'input_size': 14,
'scaler_type': 'robust',
'enable_progress_bar': False,
'enable_model_summary': False,
'logger': False,
'accelerator': 'gpu',
'start_padding_enabled': True,
'futr_exog_list': ['promotion_email', 'promotion_homepage'],
'hist_exog_list': ['scaled_price'],
'stat_exog_list': ['product_code', 'location_code'],
'devices': 1,
'max_time': {'seconds': 14.42103596553934}}
Exiting the `_fit` method
Entering the `_predict` method
Beginning AutoGluon training... Time limit = 60s
AutoGluon will save models to '/home/ci/autogluon/docs/tutorials/timeseries/advanced/AutogluonModels/ag-20250508_204948'
=================== System Info ===================
AutoGluon Version: 1.3.1b20250508
Python Version: 3.11.9
Operating System: Linux
Platform Machine: x86_64
Platform Version: #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count: 8
GPU Count: 1
Memory Avail: 26.97 GB / 30.95 GB (87.2%)
Disk Space Avail: 211.94 GB / 255.99 GB (82.8%)
===================================================
Fitting with arguments:
{'enable_ensemble': True,
'eval_metric': WQL,
'hyperparameters': {<class '__main__.NHITSModel'>: [{},
{'input_size': 20},
{'scaler_type': 'robust'}]},
'known_covariates_names': ['promotion_email', 'promotion_homepage'],
'num_val_windows': 1,
'prediction_length': 7,
'quantile_levels': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
'random_seed': 123,
'refit_every_n_windows': 1,
'refit_full': False,
'skip_model_selection': False,
'target': 'unit_sales',
'time_limit': 60,
'verbosity': 2}
Inferred time series frequency: 'W-MON'
Provided train_data has 7656 rows (NaN fraction=6.8%), 319 time series. Median time series length is 24 (min=24, max=24).
Provided data contains following columns:
target: 'unit_sales'
known_covariates:
categorical: []
continuous (float): ['promotion_email', 'promotion_homepage']
past_covariates:
categorical: []
continuous (float): ['scaled_price']
static_features:
categorical: ['product_category', 'product_subcategory']
continuous (float): ['product_code', 'location_code']
To learn how to fix incorrectly inferred types, please see documentation for TimeSeriesPredictor.fit
AutoGluon will gauge predictive performance using evaluation metric: 'WQL'
This metric's sign has been flipped to adhere to being higher_is_better. The metric score can be multiplied by -1 to get the metric value.
===================================================
Starting training. Start time is 2025-05-08 20:49:48
Models that will be trained: ['NHITS', 'NHITS_2', 'NHITS_3']
Training timeseries model NHITS. Training for up to 15.0s of the 59.9s of remaining time.
-0.4648 = Validation score (-WQL)
13.62 s = Training runtime
0.10 s = Validation (prediction) runtime
Training timeseries model NHITS_2. Training for up to 15.4s of the 46.2s of remaining time.
-0.5210 = Validation score (-WQL)
13.98 s = Training runtime
0.10 s = Validation (prediction) runtime
Training timeseries model NHITS_3. Training for up to 16.0s of the 32.1s of remaining time.
-0.3959 = Validation score (-WQL)
14.57 s = Training runtime
0.10 s = Validation (prediction) runtime
Fitting simple weighted ensemble.
Ensemble weights: {'NHITS_3': np.float64(1.0)}
-0.3959 = Validation score (-WQL)
0.39 s = Training runtime
0.10 s = Validation (prediction) runtime
Training complete. Models trained: ['NHITS', 'NHITS_2', 'NHITS_3', 'WeightedEnsemble']
Total runtime: 42.99 s
Best model: NHITS_3
Best model score: -0.3959
<autogluon.timeseries.predictor.TimeSeriesPredictor at 0x7f7e581c4c10>
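Apart from `max_time`, the three hyperparameter dumps printed by `_fit` above differ only in the keys that were overridden. The following sketch reproduces that behaviour with a plain dictionary merge. It assumes the wrapper applies user-provided values on top of its defaults; the `resolve` helper is hypothetical, and only two of the logged defaults are included.

```python
# Default values as they appear in the first logged configuration (subset only).
default_hyperparameters = {"input_size": 14, "scaler_type": "standard"}

# The three configurations passed to predictor.fit via `hyperparameters`.
user_configs = [{}, {"input_size": 20}, {"scaler_type": "robust"}]


def resolve(defaults, overrides):
    """Return a new dict in which user overrides take precedence over defaults."""
    return {**defaults, **overrides}


resolved = [resolve(default_hyperparameters, cfg) for cfg in user_configs]
for cfg in resolved:
    print(cfg)
```

Each entry in the list produces one trained model, which is why the log reports `NHITS`, `NHITS_2` and `NHITS_3`.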
predictor.leaderboard(test_data)
Entering the `_predict` method
Entering the `_predict` method
Entering the `_predict` method
Additional data provided, testing on additional data. Resulting leaderboard will be sorted according to test score (`score_test`).
 | model | score_test | score_val | pred_time_test | pred_time_val | fit_time_marginal | fit_order |
---|---|---|---|---|---|---|---|
0 | NHITS | -0.384574 | -0.464803 | 0.118471 | 0.097582 | 13.620753 | 1 |
1 | WeightedEnsemble | -0.422423 | -0.395856 | 0.125363 | 0.101789 | 0.394304 | 4 |
2 | NHITS_3 | -0.422423 | -0.395856 | 0.124666 | 0.101789 | 14.568996 | 3 |
3 | NHITS_2 | -0.465447 | -0.521006 | 0.120018 | 0.099501 | 13.980385 | 2 |
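The scores in the leaderboard are reported in "higher is better" form, so, as the training log notes, multiplying them by -1 recovers the actual WQL values. A small sketch using the numbers copied from the table above (plain Python, no AutoGluon calls; the variable names are only illustrative):

```python
# (model, score_test, score_val) copied from the leaderboard above.
leaderboard = [
    ("NHITS", -0.384574, -0.464803),
    ("WeightedEnsemble", -0.422423, -0.395856),
    ("NHITS_3", -0.422423, -0.395856),
    ("NHITS_2", -0.465447, -0.521006),
]

# Flip the sign to recover the WQL values, where lower is better.
wql = {
    model: {"test": -score_test, "val": -score_val}
    for model, score_test, score_val in leaderboard
}

lowest_val = min(scores["val"] for scores in wql.values())
best_on_val = sorted(m for m, s in wql.items() if s["val"] == lowest_val)
best_on_test = min(wql, key=lambda m: wql[m]["test"])

print("Lowest validation WQL:", lowest_val, best_on_val)
print("Lowest test WQL:", wql[best_on_test]["test"], best_on_test)
```

In this run, the model selected on validation data (`NHITS_3`, which the ensemble reproduces with weight 1.0) is not the one with the lowest WQL on the test data (`NHITS`).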
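Wrapping up¶
That's all it takes to add a custom forecasting model to AutoGluon. If you create a custom model, please consider submitting a PR so that we can add it officially to AutoGluon!
For more tutorials, refer to Forecasting Time Series - Quick Start and Forecasting Time Series - In Depth.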