Adding a custom time series forecasting model


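This tutorial describes how to add a custom forecasting model that can be trained, hyperparameter-tuned, and ensembled alongside the default forecasting models.

As an example, we will implement an AutoGluon wrapper for the NHITS model from the NeuralForecast library.

This tutorial consists of the following sections:

  1. Implementing the model wrapper.

  2. Loading and preprocessing the dataset that will be used for model development.

  3. Using the custom model in standalone mode.

  4. Using the custom model inside the TimeSeriesPredictor.

Warning

This tutorial is designed for advanced AutoGluon users.

Custom model implementations rely heavily on the private API of AutoGluon, which might change over time. For this reason, you might need to update your custom model implementation as you upgrade to new versions of AutoGluon.

First, we install the NeuralForecast library, which contains the implementation of the custom model used in this tutorial.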

pip install -q neuralforecast==2.0
Note: you may need to restart the kernel to use updated packages.

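Implementing the custom model

To implement a custom model, we need to create a subclass of the AbstractTimeSeriesModel class. This subclass must implement two methods: _fit and _predict. For models that require custom preprocessing logic (for example, to handle missing values), we also need to implement the preprocess method.

Please have a look at the following code and read the comments to understand the different components of the custom model wrapper.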

import logging
import pprint
from typing import Optional, Tuple

import pandas as pd

from autogluon.timeseries import TimeSeriesDataFrame
from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
from autogluon.timeseries.utils.warning_filters import warning_filter

# Optional - disable annoying PyTorch-Lightning loggers
for logger_name in [
    "lightning.pytorch.utilities.rank_zero",
    "pytorch_lightning.accelerators.cuda",
    "lightning_fabric.utilities.seed",
]:
    logging.getLogger(logger_name).setLevel(logging.ERROR)


class NHITSModel(AbstractTimeSeriesModel):
    """AutoGluon-compatible wrapper for the NHITS model from NeuralForecast."""

    # Set these attributes to ensure that AutoGluon passes correct features to the model
    _supports_known_covariates: bool = True
    _supports_past_covariates: bool = True
    _supports_static_features: bool = True

    def preprocess(
        self,
        data: TimeSeriesDataFrame,
        known_covariates: Optional[TimeSeriesDataFrame] = None,
        is_train: bool = False,
        **kwargs,
    ) -> Tuple[TimeSeriesDataFrame, Optional[TimeSeriesDataFrame]]:
        """Method that implements model-specific preprocessing logic.

        This method is called on all data that is passed to `_fit` and `_predict` methods.
        """
        # NeuralForecast cannot handle missing values represented by NaN. Therefore, we
        # need to impute them before the data is passed to the model. First, we
        # forward-fill and backward-fill all time series
        data = data.fill_missing_values()
        # Some time series might consist completely of missing values, so the previous
        # line has no effect on them. We fill them with 0.0
        data = data.fill_missing_values(method="constant", value=0.0)
        # Some models (e.g., Chronos) can natively handle NaNs - for them we don't need
        # to define a custom preprocessing logic
        return data, known_covariates

    def _get_default_hyperparameters(self) -> dict:
        """Default hyperparameters that will be provided to the inner model, i.e., the
        NHITS implementation in neuralforecast. """
        import torch
        from neuralforecast.losses.pytorch import MQLoss

        default_hyperparameters = dict(
            loss=MQLoss(quantiles=self.quantile_levels),
            input_size=2 * self.prediction_length,
            scaler_type="standard",
            enable_progress_bar=False,
            enable_model_summary=False,
            logger=False,
            accelerator="cpu",
            # The model wrapper should handle any time series length - even time series
            # with 1 observation
            start_padding_enabled=True,
            # NeuralForecast requires that names of the past/future/static covariates are
            # passed as model arguments. AutoGluon models have access to this information
            # using the `metadata` attribute that is set automatically at model creation.
            #
            # Note that NeuralForecast does not support categorical covariates, so we
            # only use the real-valued covariates here. To use categorical features in
            # your wrapper, you need to either use techniques like one-hot-encoding, or
            # rely on models that natively handle categorical features.
            futr_exog_list=self.covariate_metadata.known_covariates_real,
            hist_exog_list=self.covariate_metadata.past_covariates_real,
            stat_exog_list=self.covariate_metadata.static_features_real,
        )

        if torch.cuda.is_available():
            default_hyperparameters["accelerator"] = "gpu"
            default_hyperparameters["devices"] = 1

        return default_hyperparameters

    def _fit(
        self,
        train_data: TimeSeriesDataFrame,
        val_data: Optional[TimeSeriesDataFrame] = None,
        time_limit: Optional[float] = None,
        **kwargs,
    ) -> None:
        """Fit the model on the available training data."""
        print("Entering the `_fit` method")

        # We lazily import other libraries inside the _fit method. This reduces the
        # import time for autogluon and ensures that even if one model has some problems
        # with dependencies, the training process won't crash
        from neuralforecast import NeuralForecast
        from neuralforecast.models import NHITS

        # It's important to ensure that the model respects the time_limit during `fit`.
        # Since NeuralForecast is based on PyTorch-Lightning, this can be easily enforced
        # using the `max_time` argument to `pl.Trainer`. For other model types such as
        # ARIMA implementing the time_limit logic may require a lot of work.
        hyperparameter_overrides = {}
        if time_limit is not None:
            hyperparameter_overrides = {"max_time": {"seconds": time_limit}}

        # The method `get_hyperparameters()` returns the model hyperparameters in
        # `_get_default_hyperparameters` overridden with the hyperparameters provided by the user in
        # `predictor.fit(..., hyperparameters={NHITSModel: {}})`. We override these with other
        # hyperparameters available at training time.
        model_params = self.get_hyperparameters() | hyperparameter_overrides
        print(f"Hyperparameters:\n{pprint.pformat(model_params, sort_dicts=False)}")

        model = NHITS(h=self.prediction_length, **model_params)
        self.nf = NeuralForecast(models=[model], freq=self.freq)

        # Convert data into a format expected by the model. NeuralForecast expects time
        # series data in pandas.DataFrame format that is quite similar to AutoGluon, so
        # the transformation is very easy.
        #
        # Note that the `preprocess` method was already applied to train_data and val_data.
        train_df, static_df = self._to_neuralforecast_format(train_data)
        self.nf.fit(
            train_df,
            static_df=static_df,
            id_col="item_id",
            time_col="timestamp",
            target_col=self.target,
        )
        print("Exiting the `_fit` method")

    def _to_neuralforecast_format(self, data: TimeSeriesDataFrame) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]:
        """Convert a TimeSeriesDataFrame to the format expected by NeuralForecast."""
        df = data.to_data_frame().reset_index()
        # Drop the categorical covariates to avoid NeuralForecast errors
        df = df.drop(columns=self.covariate_metadata.covariates_cat)
        static_df = data.static_features
        if len(self.covariate_metadata.static_features_real) > 0:
            static_df = static_df.reset_index()
            static_df = static_df.drop(columns=self.covariate_metadata.static_features_cat)
        return df, static_df

    def _predict(
        self,
        data: TimeSeriesDataFrame,
        known_covariates: Optional[TimeSeriesDataFrame] = None,
        **kwargs,
    ) -> TimeSeriesDataFrame:
        """Predict future target given the historical time series data and the future values of known_covariates."""
        print("Entering the `_predict` method")

        from neuralforecast.losses.pytorch import quantiles_to_outputs

        df, static_df = self._to_neuralforecast_format(data)
        if len(self.covariate_metadata.known_covariates_real) > 0:
            futr_df, _ = self._to_neuralforecast_format(known_covariates)
        else:
            futr_df = None

        with warning_filter():
            predictions = self.nf.predict(df, static_df=static_df, futr_df=futr_df)

        # predictions must be a TimeSeriesDataFrame with columns
        # ["mean"] + [str(q) for q in self.quantile_levels]
        model_name = str(self.nf.models[0])
        rename_columns = {
            f"{model_name}{suffix}": str(quantile)
            for quantile, suffix in zip(*quantiles_to_outputs(self.quantile_levels))
        }
        predictions = predictions.rename(columns=rename_columns)
        predictions["mean"] = predictions["0.5"]
        predictions = TimeSeriesDataFrame(predictions)
        return predictions
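
The way `_fit` assembles the final hyperparameters can be illustrated with plain dictionaries. The sketch below is not AutoGluon code: `resolve_hyperparameters` is a hypothetical helper that only mimics the precedence described in the comments above, where default values are overridden by user-provided values, which are in turn overridden by values known only at training time (here, `max_time` derived from `time_limit`).

```python
from typing import Any, Dict, Optional


def resolve_hyperparameters(
    defaults: Dict[str, Any],
    user_provided: Dict[str, Any],
    time_limit: Optional[float] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: merge hyperparameters with increasing priority."""
    overrides: Dict[str, Any] = {}
    if time_limit is not None:
        # Same structure as the `max_time` override built inside `_fit`
        overrides["max_time"] = {"seconds": time_limit}
    # With the `|` operator (Python 3.9+), keys from the right operand win
    return defaults | user_provided | overrides


params = resolve_hyperparameters(
    defaults={"input_size": 14, "scaler_type": "standard"},
    user_provided={"input_size": 28},
    time_limit=20.0,
)
print(params)
```

In this toy example the user-provided `input_size` replaces the default, `scaler_type` keeps its default value, and `max_time` is added last so that it cannot be overridden by either of the other two dictionaries.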

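For convenience, here is an overview of the main constraints on the inputs and outputs of the different methods.

  • Input data received by the _fit and _predict methods satisfies the following:

    • the index is sorted by (item_id, timestamp)

    • the timestamps of observations have a regular frequency corresponding to self.freq

    • the column self.target contains the target values of the time series

    • the target column might contain missing values represented by NaN

    • the data might contain covariates (including static features) with the schema described in self.covariate_metadata

      • real-valued covariates have dtype float32

      • categorical covariates have dtype category

      • covariates do not contain any missing values

    • static features, if present, are available as data.static_features

  • Predictions returned by _predict must satisfy the following:

    • predictions are returned as a TimeSeriesDataFrame object

    • predictions contain the columns ["mean"] + [str(q) for q in self.quantile_levels], holding the point forecast and the quantile forecasts, respectively

    • the index of predictions contains exactly self.prediction_length future time steps for each time series present in data

    • the frequency of the prediction timestamps matches self.freq

    • the index of predictions is sorted by (item_id, timestamp)

    • predictions contain no missing values represented by NaN and no gaps

  • The runtime of the _fit method should not exceed time_limit seconds, if time_limit is provided.

  • None of the methods should modify the data in place. If modifications are needed, create a copy of the data first.

  • All methods should work even if some time series consist entirely of NaNs or have only a single observation.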
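
During development it can help to check several of the output requirements listed above automatically. The following is a minimal sketch using plain pandas; `find_prediction_problems` is a hypothetical helper, not part of AutoGluon, and it does not verify the timestamp frequency or the absence of gaps. Because a TimeSeriesDataFrame is a pandas DataFrame subclass, a regular DataFrame with an (item_id, timestamp) index stands in for it here.

```python
from typing import List

import pandas as pd


def find_prediction_problems(
    predictions: pd.DataFrame,
    item_ids: List[str],
    prediction_length: int,
    quantile_levels: List[float],
) -> List[str]:
    """Hypothetical helper: list the violated output requirements (empty if none)."""
    problems = []
    expected_columns = ["mean"] + [str(q) for q in quantile_levels]
    missing = [c for c in expected_columns if c not in predictions.columns]
    if missing:
        problems.append(f"missing columns: {missing}")
    if list(predictions.index.names) != ["item_id", "timestamp"]:
        problems.append("index must have levels (item_id, timestamp)")
        return problems
    if not predictions.index.is_monotonic_increasing:
        problems.append("index is not sorted by (item_id, timestamp)")
    steps_per_item = predictions.groupby(level="item_id", sort=False).size()
    for item_id in item_ids:
        if steps_per_item.get(item_id, 0) != prediction_length:
            problems.append(f"item {item_id!r} does not have {prediction_length} steps")
    if predictions.isna().any().any():
        problems.append("predictions contain NaN")
    return problems


# Toy predictions for two items and two weekly time steps
index = pd.MultiIndex.from_product(
    [["A", "B"], pd.date_range("2018-06-18", periods=2, freq="W-MON")],
    names=["item_id", "timestamp"],
)
good = pd.DataFrame({"mean": 1.0, "0.1": 0.5, "0.5": 1.0, "0.9": 1.5}, index=index)
print(find_prediction_problems(good, ["A", "B"], 2, [0.1, 0.5, 0.9]))
```

An empty list means that none of the checked requirements is violated. Returning a list of messages rather than raising on the first failure makes it easier to see all problems of a new wrapper at once.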


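We will now use this wrapper in two modes:

  1. Standalone mode (outside of the TimeSeriesPredictor).

    • This mode should be used for development and debugging. In this case, we need to take care of preprocessing and model configuration manually.

  2. Inside the TimeSeriesPredictor.

    • This mode makes it easy to combine and compare the custom model with other models available in AutoGluon. The main purpose of writing a custom model wrapper is to use it in this mode.

Loading and preprocessing the data

First, we load the Grocery Sales dataset that we will use for development and evaluation.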

from autogluon.timeseries import TimeSeriesDataFrame

raw_data = TimeSeriesDataFrame.from_path(
    "https://autogluon.s3.amazonaws.com/datasets/timeseries/grocery_sales/test.csv",
    static_features_path="https://autogluon.s3.amazonaws.com/datasets/timeseries/grocery_sales/static.csv",
)
raw_data.head()
                     scaled_price  promotion_email  promotion_homepage  unit_sales
item_id  timestamp
1062_101 2018-01-01      0.879130              0.0                 0.0       636.0
         2018-01-08      0.994517              0.0                 0.0       123.0
         2018-01-15      1.005513              0.0                 0.0       391.0
         2018-01-22      1.000000              0.0                 0.0       339.0
         2018-01-29      0.883309              0.0                 0.0       661.0
raw_data.static_features.head()
          product_code  product_category  product_subcategory  location_code
item_id
1062_101          1062         Beverages          Mango Juice            101
1062_102          1062         Beverages          Mango Juice            102
1062_104          1062         Beverages          Mango Juice            104
1062_106          1062         Beverages          Mango Juice            106
1062_108          1062         Beverages          Mango Juice            108
print("Types of the columns in raw data:")
print(raw_data.dtypes)
print("\nTypes of the columns in raw static features:")
print(raw_data.static_features.dtypes)

print("\nNumber of missing values per column:")
print(raw_data.isna().sum())
Types of the columns in raw data:
scaled_price          float64
promotion_email       float64
promotion_homepage    float64
unit_sales            float64
dtype: object

Types of the columns in raw static features:
product_code            int64
product_category       object
product_subcategory    object
location_code           int64
dtype: object

Number of missing values per column:
scaled_price          714
promotion_email       714
promotion_homepage    714
unit_sales            714
dtype: int64

Defining the forecasting task

prediction_length = 7  # number of future steps to predict
target = "unit_sales"  # target column
known_covariates_names = ["promotion_email", "promotion_homepage"]  # covariates known in the future

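Before we use the model in standalone mode, we need to apply the general AutoGluon preprocessing to the data.

The TimeSeriesFeatureGenerator covers preprocessing steps such as normalizing the data types and imputing missing values in the covariates.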

from autogluon.timeseries.utils.features import TimeSeriesFeatureGenerator

feature_generator = TimeSeriesFeatureGenerator(target=target, known_covariates_names=known_covariates_names)
data = feature_generator.fit_transform(raw_data)
print("Types of the columns in preprocessed data:")
print(data.dtypes)
print("\nTypes of the columns in preprocessed static features:")
print(data.static_features.dtypes)

print("\nNumber of missing values per column:")
print(data.isna().sum())
Types of the columns in preprocessed data:
unit_sales            float64
promotion_email       float32
promotion_homepage    float32
scaled_price          float32
dtype: object

Types of the columns in preprocessed static features:
product_category       category
product_subcategory    category
product_code            float32
location_code           float32
dtype: object

Number of missing values per column:
unit_sales            714
promotion_email         0
promotion_homepage      0
scaled_price            0
dtype: int64
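
The output above shows that the target column unit_sales still contains 714 missing values after the general preprocessing, since the feature generator only imputes the covariates. These remaining values are what the wrapper's preprocess method takes care of. The following pandas sketch mimics the two-step imputation described in the comments of preprocess (forward-fill and backward-fill each series, then fill all-NaN series with 0.0). `impute_target` is a hypothetical helper for illustration only; it is not how `fill_missing_values` is implemented.

```python
import numpy as np
import pandas as pd


def impute_target(target: pd.Series) -> pd.Series:
    """Illustrative helper: ffill, then bfill within each item, then fill with 0.0."""
    filled = target.groupby(level="item_id").ffill()
    filled = filled.groupby(level="item_id").bfill()
    # Series that consist only of NaN are untouched by the two steps above
    return filled.fillna(0.0)


index = pd.MultiIndex.from_product(
    [["A", "B"], pd.date_range("2018-01-01", periods=3, freq="W-MON")],
    names=["item_id", "timestamp"],
)
# Item "A" has one observed value, item "B" has none
target = pd.Series([np.nan, 2.0, np.nan, np.nan, np.nan, np.nan], index=index, name="unit_sales")
imputed = impute_target(target)
print(imputed)
```

Grouping by item_id keeps values from leaking across different time series, and the helper returns a new Series instead of modifying its input, in line with the requirement that methods must not change the data in place.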

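Using the custom model in standalone mode

Using the model in standalone mode is useful for debugging our implementation. Once we have made sure that all methods work as expected, we will use the model inside the TimeSeriesPredictor.

Training

We are now ready to train the custom model on the preprocessed data.

When using the model in standalone mode, we need to configure its parameters manually.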

model = NHITSModel(
    prediction_length=prediction_length,
    target=target,
    covariate_metadata=feature_generator.covariate_metadata,
    freq=data.freq,
    quantile_levels=[0.1, 0.5, 0.9],
)
model.fit(train_data=data, time_limit=20)
Entering the `_fit` method
Hyperparameters:
{'loss': MQLoss(),
 'input_size': 14,
 'scaler_type': 'standard',
 'enable_progress_bar': False,
 'enable_model_summary': False,
 'logger': False,
 'accelerator': 'gpu',
 'start_padding_enabled': True,
 'futr_exog_list': ['promotion_email', 'promotion_homepage'],
 'hist_exog_list': ['scaled_price'],
 'stat_exog_list': ['product_code', 'location_code'],
 'devices': 1,
 'max_time': {'seconds': 13.536751490100324}}
Exiting the `_fit` method
NHITS

Predicting and scoring

past_data, known_covariates = data.get_model_inputs_for_scoring(
    prediction_length=prediction_length,
    known_covariates_names=known_covariates_names,
)
predictions = model.predict(past_data, known_covariates)
predictions.head()
Entering the `_predict` method
                            0.1         0.5         0.9        mean
item_id  timestamp
1062_101 2018-06-18  216.337616  340.998993  484.123535  340.998993
         2018-06-25  204.024643  352.983673  524.839233  352.983673
         2018-07-02  209.871796  364.116669  548.581787  364.116669
         2018-07-09  205.984467  364.956635  551.808472  364.956635
         2018-07-16  208.200226  364.006592  548.709839  364.006592
model.score(data)
Entering the `_predict` method
np.float64(-0.3316095466210152)
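
The score above is negative. The predictor logs later in this tutorial state that the sign of the WQL metric is flipped so that higher is better. To give an intuition for what such a score measures, here is a small pure-Python sketch of a weighted quantile loss. The normalization used (pinball loss scaled by 2, summed over all values, divided by the sum of absolute targets and the number of quantiles) is a common formulation and an assumption here; the exact computation in AutoGluon may differ, and it is also an assumption that the standalone model scores with WQL by default.

```python
from typing import Dict, List


def pinball_loss(y: float, forecast: float, q: float) -> float:
    """Quantile (pinball) loss, scaled by 2 so that q=0.5 equals the absolute error."""
    if y >= forecast:
        return 2 * q * (y - forecast)
    return 2 * (1 - q) * (forecast - y)


def weighted_quantile_loss(y_true: List[float], quantile_forecasts: Dict[float, List[float]]) -> float:
    """Illustrative WQL: total pinball loss normalized by sum(|y|) and the number of quantiles."""
    total = sum(
        pinball_loss(y, f, q)
        for q, forecasts in quantile_forecasts.items()
        for y, f in zip(y_true, forecasts)
    )
    scale = sum(abs(y) for y in y_true)
    return total / (len(quantile_forecasts) * scale)


# Toy example with two observations and three quantile levels
y_true = [10.0, 20.0]
quantile_forecasts = {
    0.1: [8.0, 15.0],
    0.5: [10.0, 22.0],
    0.9: [14.0, 25.0],
}
wql = weighted_quantile_loss(y_true, quantile_forecasts)
score = -wql  # sign flipped so that higher is better
print(f"WQL = {wql:.4f}, score = {score:.4f}")
```

A lower loss means that the quantile forecasts are better calibrated and sharper, so after the sign flip a score closer to zero is better.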

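Using the custom model inside the TimeSeriesPredictor

After we have made sure that the custom model works in standalone mode, we can pass it to the TimeSeriesPredictor alongside other models.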

from autogluon.timeseries import TimeSeriesPredictor

train_data, test_data = raw_data.train_test_split(prediction_length)

predictor = TimeSeriesPredictor(
    prediction_length=prediction_length,
    target=target,
    known_covariates_names=known_covariates_names,
)

predictor.fit(
    train_data,
    hyperparameters={
        "Naive": {},
        "Chronos": {"model_path": "bolt_small"},
        "ETS": {},
        NHITSModel: {},
    },
    time_limit=120,
)
Beginning AutoGluon training... Time limit = 120s
AutoGluon will save models to '/home/ci/autogluon/docs/tutorials/timeseries/advanced/AutogluonModels/ag-20250508_204919'
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
GPU Count:          1
Memory Avail:       27.81 GB / 30.95 GB (89.8%)
Disk Space Avail:   212.13 GB / 255.99 GB (82.9%)
===================================================

Fitting with arguments:
{'enable_ensemble': True,
 'eval_metric': WQL,
 'hyperparameters': {<class '__main__.NHITSModel'>: {},
                     'Chronos': {'model_path': 'bolt_small'},
                     'ETS': {},
                     'Naive': {}},
 'known_covariates_names': ['promotion_email', 'promotion_homepage'],
 'num_val_windows': 1,
 'prediction_length': 7,
 'quantile_levels': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
 'random_seed': 123,
 'refit_every_n_windows': 1,
 'refit_full': False,
 'skip_model_selection': False,
 'target': 'unit_sales',
 'time_limit': 120,
 'verbosity': 2}
Inferred time series frequency: 'W-MON'
Provided train_data has 7656 rows (NaN fraction=6.8%), 319 time series. Median time series length is 24 (min=24, max=24).

Provided data contains following columns:
	target: 'unit_sales'
	known_covariates:
		categorical:        []
		continuous (float): ['promotion_email', 'promotion_homepage']
	past_covariates:
		categorical:        []
		continuous (float): ['scaled_price']
	static_features:
		categorical:        ['product_category', 'product_subcategory']
		continuous (float): ['product_code', 'location_code']

To learn how to fix incorrectly inferred types, please see documentation for TimeSeriesPredictor.fit

AutoGluon will gauge predictive performance using evaluation metric: 'WQL'
	This metric's sign has been flipped to adhere to being higher_is_better. The metric score can be multiplied by -1 to get the metric value.
===================================================

Starting training. Start time is 2025-05-08 20:49:19
Models that will be trained: ['Naive', 'ETS', 'Chronos[bolt_small]', 'NHITS']
Training timeseries model Naive. Training for up to 24.0s of the 119.9s of remaining time.
	-0.5412       = Validation score (-WQL)
	0.03    s     = Training runtime
	1.86    s     = Validation (prediction) runtime
Training timeseries model ETS. Training for up to 29.5s of the 118.0s of remaining time.
	-0.7039       = Validation score (-WQL)
	0.04    s     = Training runtime
	0.85    s     = Validation (prediction) runtime
Training timeseries model Chronos[bolt_small]. Training for up to 39.0s of the 117.1s of remaining time.
	-0.3320       = Validation score (-WQL)
	0.59    s     = Training runtime
	1.41    s     = Validation (prediction) runtime
Training timeseries model NHITS. Training for up to 57.6s of the 115.1s of remaining time.
	-0.4681       = Validation score (-WQL)
	19.03   s     = Training runtime
	0.10    s     = Validation (prediction) runtime
Fitting simple weighted ensemble.
	Ensemble weights: {'Chronos[bolt_small]': np.float64(0.97), 'NHITS': np.float64(0.03)}
/home/ci/autogluon/timeseries/src/autogluon/timeseries/metrics/abstract.py:101: FutureWarning: Passing `prediction_length` to `TimeSeriesScorer.__call__` is deprecated and will be removed in v2.0. Please set the `eval_metric.prediction_length` attribute instead.
  warnings.warn(
	-0.3320       = Validation score (-WQL)
	0.52    s     = Training runtime
	1.51    s     = Validation (prediction) runtime
Training complete. Models trained: ['Naive', 'ETS', 'Chronos[bolt_small]', 'NHITS', 'WeightedEnsemble']
Total runtime: 24.50 s
Best model: Chronos[bolt_small]
Best model score: -0.3320
Entering the `_fit` method
Hyperparameters:
{'loss': MQLoss(),
 'input_size': 14,
 'scaler_type': 'standard',
 'enable_progress_bar': False,
 'enable_model_summary': False,
 'logger': False,
 'accelerator': 'gpu',
 'start_padding_enabled': True,
 'futr_exog_list': ['promotion_email', 'promotion_homepage'],
 'hist_exog_list': ['scaled_price'],
 'stat_exog_list': ['product_code', 'location_code'],
 'devices': 1,
 'max_time': {'seconds': 51.80114328768496}}
Exiting the `_fit` method
Entering the `_predict` method
<autogluon.timeseries.predictor.TimeSeriesPredictor at 0x7f7e582a6650>

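Note that when we use the custom model inside the predictor, we do not need to worry about:

  • manually configuring the model (setting freq, prediction_length)

  • preprocessing the data using TimeSeriesFeatureGenerator

  • setting the time limits

The TimeSeriesPredictor takes care of all of the above automatically.

We can also easily compare our custom model with the other models trained by the predictor.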

predictor.leaderboard(test_data)
Entering the `_predict` method
Additional data provided, testing on additional data. Resulting leaderboard will be sorted according to test score (`score_test`).
/home/ci/autogluon/timeseries/src/autogluon/timeseries/metrics/abstract.py:101: FutureWarning: Passing `prediction_length` to `TimeSeriesScorer.__call__` is deprecated and will be removed in v2.0. Please set the `eval_metric.prediction_length` attribute instead.
  warnings.warn(
                 model  score_test  score_val  pred_time_test  pred_time_val  fit_time_marginal  fit_order
0     WeightedEnsemble   -0.314278  -0.332025        0.666137       1.507256           0.515570          5
1  Chronos[bolt_small]   -0.315786  -0.331984        0.531369       1.405814           0.590614          3
2                NHITS   -0.397008  -0.468127        0.132401       0.101442          19.026050          4
3                  ETS   -0.459021  -0.703868        0.250249       0.845771           0.036806          2
4                Naive   -0.512205  -0.541231        0.188219       1.860332           0.033384          1

We can also take advantage of other predictor functionality, such as feature_importance.

predictor.feature_importance(test_data, model="NHITS")
Entering the `_predict` method
Computing feature importance
/home/ci/autogluon/timeseries/src/autogluon/timeseries/metrics/abstract.py:101: FutureWarning: Passing `prediction_length` to `TimeSeriesScorer.__call__` is deprecated and will be removed in v2.0. Please set the `eval_metric.prediction_length` attribute instead.
  warnings.warn(
                     importance     stdev    n    p99_low   p99_high
product_category       0.000000  0.000000  5.0   0.000000   0.000000
product_subcategory    0.000000  0.000000  5.0   0.000000   0.000000
product_code          -0.000775  0.010627  5.0  -0.022656   0.021106
location_code          0.000092  0.000285  5.0  -0.000494   0.000679
promotion_email        0.007981  0.009743  5.0  -0.012079   0.028041
promotion_homepage     0.005051  0.007467  5.0  -0.010324   0.020426
scaled_price          -0.000088  0.000898  5.0  -0.001937   0.001761

As expected, the features product_category and product_subcategory have zero importance because our implementation ignores categorical features.
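To see why an ignored feature scores exactly zero rather than merely close to zero, here is a minimal sketch of permutation-style importance. It assumes the scores above come from a permutation-based method; the helper names (`permutation_importance`, `toy_predict`) and the toy data are hypothetical, and AutoGluon's actual computation may differ in metric and shuffling details.

```python
import random


def mean_abs_error(y_true, y_pred):
    """Mean absolute error between two equally long sequences."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def permutation_importance(predict_fn, rows, y_true, feature, n_repeats=5, seed=0):
    """Average increase in error after shuffling one feature column (hypothetical helper)."""
    rng = random.Random(seed)
    baseline = mean_abs_error(y_true, [predict_fn(r) for r in rows])
    increases = []
    for _ in range(n_repeats):
        # Shuffle the values of a single feature across rows, keep everything else fixed
        shuffled_values = [r[feature] for r in rows]
        rng.shuffle(shuffled_values)
        shuffled_rows = [{**r, feature: v} for r, v in zip(rows, shuffled_values)]
        error = mean_abs_error(y_true, [predict_fn(r) for r in shuffled_rows])
        increases.append(error - baseline)
    return sum(increases) / n_repeats


def toy_predict(row):
    """Toy model that reads `promotion` and never reads `category`."""
    return 10.0 + 5.0 * row["promotion"]


rows = [{"promotion": i % 2, "category": "ABC"[i % 3]} for i in range(12)]
y_true = [toy_predict(r) for r in rows]

importance_category = permutation_importance(toy_predict, rows, y_true, "category")
importance_promotion = permutation_importance(toy_predict, rows, y_true, "promotion")
print("category:", importance_category)
print("promotion:", importance_promotion)
```

Shuffling a column that the model never reads leaves every prediction unchanged, so the error difference is zero in every repetition. A column that the model does read produces a positive score in this toy setup.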


Here is how we can train multiple versions of the custom model with different hyperparameter configurations.

predictor = TimeSeriesPredictor(
    prediction_length=prediction_length,
    target=target,
    known_covariates_names=known_covariates_names,
)
predictor.fit(
    train_data,
    hyperparameters={
        NHITSModel: [
            {},  # default hyperparameters
            {"input_size": 20},  # custom input_size
            {"scaler_type": "robust"},  # custom scaler_type
        ]
    },
    time_limit=60,
)
Entering the `_fit` method
Hyperparameters:
{'loss': MQLoss(),
 'input_size': 14,
 'scaler_type': 'standard',
 'enable_progress_bar': False,
 'enable_model_summary': False,
 'logger': False,
 'accelerator': 'gpu',
 'start_padding_enabled': True,
 'futr_exog_list': ['promotion_email', 'promotion_homepage'],
 'hist_exog_list': ['scaled_price'],
 'stat_exog_list': ['product_code', 'location_code'],
 'devices': 1,
 'max_time': {'seconds': 13.4735950648168}}
Exiting the `_fit` method
Entering the `_predict` method
Entering the `_fit` method
Hyperparameters:
{'loss': MQLoss(),
 'input_size': 20,
 'scaler_type': 'standard',
 'enable_progress_bar': False,
 'enable_model_summary': False,
 'logger': False,
 'accelerator': 'gpu',
 'start_padding_enabled': True,
 'futr_exog_list': ['promotion_email', 'promotion_homepage'],
 'hist_exog_list': ['scaled_price'],
 'stat_exog_list': ['product_code', 'location_code'],
 'devices': 1,
 'max_time': {'seconds': 13.842486143193675}}
Exiting the `_fit` method
Entering the `_predict` method
Entering the `_fit` method
Hyperparameters:
{'loss': MQLoss(),
 'input_size': 14,
 'scaler_type': 'robust',
 'enable_progress_bar': False,
 'enable_model_summary': False,
 'logger': False,
 'accelerator': 'gpu',
 'start_padding_enabled': True,
 'futr_exog_list': ['promotion_email', 'promotion_homepage'],
 'hist_exog_list': ['scaled_price'],
 'stat_exog_list': ['product_code', 'location_code'],
 'devices': 1,
 'max_time': {'seconds': 14.42103596553934}}
Exiting the `_fit` method
Entering the `_predict` method
Beginning AutoGluon training... Time limit = 60s
AutoGluon will save models to '/home/ci/autogluon/docs/tutorials/timeseries/advanced/AutogluonModels/ag-20250508_204948'
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
GPU Count:          1
Memory Avail:       26.97 GB / 30.95 GB (87.2%)
Disk Space Avail:   211.94 GB / 255.99 GB (82.8%)
===================================================

Fitting with arguments:
{'enable_ensemble': True,
 'eval_metric': WQL,
 'hyperparameters': {<class '__main__.NHITSModel'>: [{},
                                                     {'input_size': 20},
                                                     {'scaler_type': 'robust'}]},
 'known_covariates_names': ['promotion_email', 'promotion_homepage'],
 'num_val_windows': 1,
 'prediction_length': 7,
 'quantile_levels': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
 'random_seed': 123,
 'refit_every_n_windows': 1,
 'refit_full': False,
 'skip_model_selection': False,
 'target': 'unit_sales',
 'time_limit': 60,
 'verbosity': 2}
Inferred time series frequency: 'W-MON'
Provided train_data has 7656 rows (NaN fraction=6.8%), 319 time series. Median time series length is 24 (min=24, max=24).

Provided data contains following columns:
	target: 'unit_sales'
	known_covariates:
		categorical:        []
		continuous (float): ['promotion_email', 'promotion_homepage']
	past_covariates:
		categorical:        []
		continuous (float): ['scaled_price']
	static_features:
		categorical:        ['product_category', 'product_subcategory']
		continuous (float): ['product_code', 'location_code']

To learn how to fix incorrectly inferred types, please see documentation for TimeSeriesPredictor.fit

AutoGluon will gauge predictive performance using evaluation metric: 'WQL'
	This metric's sign has been flipped to adhere to being higher_is_better. The metric score can be multiplied by -1 to get the metric value.
===================================================

Starting training. Start time is 2025-05-08 20:49:48
Models that will be trained: ['NHITS', 'NHITS_2', 'NHITS_3']
Training timeseries model NHITS. Training for up to 15.0s of the 59.9s of remaining time.
	-0.4648       = Validation score (-WQL)
	13.62   s     = Training runtime
	0.10    s     = Validation (prediction) runtime
Training timeseries model NHITS_2. Training for up to 15.4s of the 46.2s of remaining time.
	-0.5210       = Validation score (-WQL)
	13.98   s     = Training runtime
	0.10    s     = Validation (prediction) runtime
Training timeseries model NHITS_3. Training for up to 16.0s of the 32.1s of remaining time.
	-0.3959       = Validation score (-WQL)
	14.57   s     = Training runtime
	0.10    s     = Validation (prediction) runtime
Fitting simple weighted ensemble.
	Ensemble weights: {'NHITS_3': np.float64(1.0)}
/home/ci/autogluon/timeseries/src/autogluon/timeseries/metrics/abstract.py:101: FutureWarning: Passing `prediction_length` to `TimeSeriesScorer.__call__` is deprecated and will be removed in v2.0. Please set the `eval_metric.prediction_length` attribute instead.
  warnings.warn(
	-0.3959       = Validation score (-WQL)
	0.39    s     = Training runtime
	0.10    s     = Validation (prediction) runtime
Training complete. Models trained: ['NHITS', 'NHITS_2', 'NHITS_3', 'WeightedEnsemble']
Total runtime: 42.99 s
Best model: NHITS_3
Best model score: -0.3959
<autogluon.timeseries.predictor.TimeSeriesPredictor at 0x7f7e581c4c10>
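The log shows three models named NHITS, NHITS_2 and NHITS_3, each reporting the default values except for the one hyperparameter that was overridden. The sketch below reproduces that pattern in plain Python. The helper `expand_configs` is hypothetical and only mirrors what the log suggests; it is not AutoGluon's internal code. The default values are copied from the printed hyperparameters.

```python
# Defaults as printed in the training log (only the two overridden keys are shown here)
DEFAULTS = {"input_size": 14, "scaler_type": "standard"}


def expand_configs(base_name, overrides_list, defaults):
    """Hypothetical helper: merge each override dict into the defaults and name the result."""
    configs = {}
    for i, overrides in enumerate(overrides_list, start=1):
        # First configuration keeps the base name, later ones get a numeric suffix
        name = base_name if i == 1 else f"{base_name}_{i}"
        # Keys in `overrides` take precedence over the defaults
        configs[name] = {**defaults, **overrides}
    return configs


configs = expand_configs(
    "NHITS",
    [{}, {"input_size": 20}, {"scaler_type": "robust"}],
    DEFAULTS,
)
for name, params in configs.items():
    print(name, params)
```

An empty dictionary therefore means "use all defaults", and each additional dictionary only needs to list the values that differ.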
predictor.leaderboard(test_data)
Entering the `_predict` method
Entering the `_predict` method
Entering the `_predict` method
Additional data provided, testing on additional data. Resulting leaderboard will be sorted according to test score (`score_test`).
/home/ci/autogluon/timeseries/src/autogluon/timeseries/metrics/abstract.py:101: FutureWarning: Passing `prediction_length` to `TimeSeriesScorer.__call__` is deprecated and will be removed in v2.0. Please set the `eval_metric.prediction_length` attribute instead.
  warnings.warn(
              model  score_test  score_val  pred_time_test  pred_time_val  fit_time_marginal  fit_order
0             NHITS   -0.384574  -0.464803        0.118471       0.097582          13.620753          1
1  WeightedEnsemble   -0.422423  -0.395856        0.125363       0.101789           0.394304          4
2           NHITS_3   -0.422423  -0.395856        0.124666       0.101789          14.568996          3
3           NHITS_2   -0.465447  -0.521006        0.120018       0.099501          13.980385          2
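The scores in the leaderboard are negative because, as the training log notes, the WQL metric is sign-flipped so that higher is better. The following is a rough sketch of one common definition of the weighted quantile loss, based on the pinball loss. The function names and the toy numbers are illustrative assumptions, and the exact normalization used by AutoGluon may differ.

```python
def pinball_loss(y, forecast, q):
    """Quantile (pinball) loss for a single observation at quantile level q."""
    diff = y - forecast
    # Under-forecasts are weighted by q, over-forecasts by (1 - q)
    return q * diff if diff >= 0 else (q - 1) * diff


def weighted_quantile_loss(y_true, quantile_forecasts):
    """Sketch of WQL: pinball losses averaged over quantiles, scaled by sum(|y|).

    `quantile_forecasts` maps a quantile level to a list of forecasts aligned with `y_true`.
    """
    total = 0.0
    for q, forecasts in quantile_forecasts.items():
        total += sum(2 * pinball_loss(y, f, q) for y, f in zip(y_true, forecasts))
    scale = sum(abs(y) for y in y_true)
    return total / (len(quantile_forecasts) * scale)


# Toy data, not taken from the tutorial dataset
y_true = [10.0, 20.0, 30.0]
forecasts = {
    0.1: [8.0, 15.0, 25.0],
    0.5: [10.0, 20.0, 30.0],
    0.9: [12.0, 25.0, 35.0],
}
wql = weighted_quantile_loss(y_true, forecasts)
score = -wql  # sign-flipped so that higher is better, as in the leaderboard
print(wql, score)
```

Under this convention, a score closer to zero indicates a better forecast. That is why NHITS, with a score_test of -0.384574, ranks first on the test data.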

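Wrapping up

That's all it takes to add a custom forecasting model to AutoGluon. If you create a custom model, consider submitting a PR so that we can add it officially to AutoGluon!

For more tutorials, refer to Forecasting Time Series - Quick Start and Forecasting Time Series - In Depth.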