MultiModalPredictor.fit

MultiModalPredictor.fit(train_data: DataFrame | str, presets: str | None = None, tuning_data: DataFrame | str | None = None, max_num_tuning_data: int | None = None, id_mappings: Dict[str, Dict] | Dict[str, Series] | None = None, time_limit: int | None = None, save_path: str | None = None, hyperparameters: str | Dict | List[str] | None = None, column_types: dict | None = None, holdout_frac: float | None = None, teacher_predictor: str | MultiModalPredictor = None, seed: int | None = 0, standalone: bool | None = True, hyperparameter_tune_kwargs: dict | None = None, clean_ckpts: bool | None = True, predictions: List[ndarray] | None = None, labels: ndarray | None = None, predictors: List[str | MultiModalPredictor] | None = None)[source]

拟合模型以根据数据表的其他列(特征)预测某一列(标签)。

参数:
  • train_data – 包含训练数据的 pd.DataFrame。

  • presets – 关于模型质量的预设,例如 best_quality、high_quality 和 medium_quality。每种质量都有其对应的 HPO 预设:‘best_quality_hpo’、‘high_quality_hpo’ 和 ‘medium_quality_hpo’。

  • tuning_data – 包含验证数据的 pd.DataFrame,应与 train_data 具有相同的列。如果 tuning_data = Nonefit() 将自动从 train_data 中保留部分随机验证数据。

  • max_num_tuning_data – 用于调优的最大样本数(用于目标检测)。

  • id_mappings – ID 到内容的映射(用于语义匹配)。内容可以是文本、图像等。当 pd.DataFrame 包含查询/响应的标识符而不是其内容时使用。

  • time_limitfit() 应运行多长时间(壁钟时间,以秒为单位)。如果未指定,fit() 将一直运行直到模型完成训练。请注意,如果 use_ensemble=True,总运行时间将是 time_limit * N,其中 N 是集成模型中的模型数量。

  • save_path – 保存模型和工件的目录路径。

  • hyperparameters

    这用于覆盖一些默认配置。例如,可以通过以下格式更改文本和图像骨干网络:

    字符串格式 hyperparameters = “model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224”

    或字符串列表格式 hyperparameters = [“model.hf_text.checkpoint_name=google/electra-small-discriminator”, “model.timm_image.checkpoint_name=swin_small_patch4_window7_224”]

    或字典格式 hyperparameters = {

    ”model.hf_text.checkpoint_name”: “google/electra-small-discriminator”, “model.timm_image.checkpoint_name”: “swin_small_patch4_window7_224”,

    }

  • column_types

    将列名映射到其数据类型的字典。例如:column_types = {“item_name”: “text”, “image”: “image_path”, “product_description”: “text”, “height”: “numerical”} 可用于具有列“item_name”、“brand”、“product_description”和“height”的表。如果为 None,将自动从数据中推断 column_types。当前支持的类型包括:

    • ”image_path”:此列中的每一行是一个图像路径。

    • ”text”:此列中的每一行包含文本(句子、段落等)。

    • ”numerical”:此列中的每一行包含一个数字。

    • ”categorical”:此列中的每一行属于 K 个类别之一。

  • holdout_frac – 从 train_data 中保留作为 tuning_data 的比例,用于优化超参数或早停(除非 tuning_data = None,否则忽略)。默认值(如果为 None)根据训练数据中的行数以及是否使用超参数优化来选择。

  • teacher_predictor – 预训练的教师预测器或其保存路径。如果提供,fit() 可以将其知识蒸馏给学生预测器,即当前预测器。

  • seed – 用于训练的随机种子(默认 0)。

  • standalone – 是否保存整个模型用于离线部署。

  • hyperparameter_tune_kwargs

    超参数调优策略和 kwargs(例如,要运行多少次 HPO 试验)。如果为 None,则不会执行超参数调优。

    num_trials: int

    要运行的 HPO 试验次数。需要指定 num_trialstime_limit

    scheduler: Union[str, ray.tune.schedulers.TrialScheduler]

    如果传入 str,AutoGluon 将使用一些默认参数为你创建调度器。如果传入 ray.tune.schedulers.TrialScheduler 对象,你需要负责初始化该对象。

    scheduler_init_args: Optional[dict] = None

    如果为 scheduler 提供 str,你可以选择为调度器提供自定义 init_args

    searcher: Union[str, ray.tune.search.SearchAlgorithm, ray.tune.search.Searcher]

    如果传入 str,AutoGluon 将使用一些默认参数为你创建搜索器。如果传入 ray.tune.schedulers.TrialScheduler 对象,你需要负责初始化该对象。你无需担心搜索器对象的 metricmode。AutoGluon 会自行处理。

    scheduler_init_args: Optional[dict] = None

    如果为 searcher 提供 str,你可以选择为搜索器提供自定义 init_args。你无需担心 metricmode。AutoGluon 会自行处理。

  • clean_ckpts – 是否在训练后清理中间检查点。

返回类型:

一个“MultiModalPredictor”对象(即自身)。