MultiModalPredictor.fit¶
- MultiModalPredictor.fit(train_data: DataFrame | str, presets: str | None = None, tuning_data: DataFrame | str | None = None, max_num_tuning_data: int | None = None, id_mappings: Dict[str, Dict] | Dict[str, Series] | None = None, time_limit: int | None = None, save_path: str | None = None, hyperparameters: str | Dict | List[str] | None = None, column_types: dict | None = None, holdout_frac: float | None = None, teacher_predictor: str | MultiModalPredictor = None, seed: int | None = 0, standalone: bool | None = True, hyperparameter_tune_kwargs: dict | None = None, clean_ckpts: bool | None = True, predictions: List[ndarray] | None = None, labels: ndarray | None = None, predictors: List[str | MultiModalPredictor] | None = None)[source]¶
拟合模型以根据数据表的其他列(特征)预测某一列(标签)。
- 参数:
train_data – 包含训练数据的 pd.DataFrame。
presets – 关于模型质量的预设,例如 best_quality、high_quality 和 medium_quality。每种质量都有其对应的 HPO 预设:‘best_quality_hpo’、‘high_quality_hpo’ 和 ‘medium_quality_hpo’。
tuning_data – 包含验证数据的 pd.DataFrame,应与 train_data 具有相同的列。如果 tuning_data = None,fit() 将自动从 train_data 中保留部分随机验证数据。
max_num_tuning_data – 用于调优的最大样本数(用于目标检测)。
id_mappings – ID 到内容的映射(用于语义匹配)。内容可以是文本、图像等。当 pd.DataFrame 包含查询/响应的标识符而不是其内容时使用。
time_limit – fit() 应运行多长时间(壁钟时间,以秒为单位)。如果未指定,fit() 将一直运行直到模型完成训练。请注意,如果 use_ensemble=True,总运行时间将是 time_limit * N,其中 N 是集成模型中的模型数量。
save_path – 保存模型和工件的目录路径。
hyperparameters –
这用于覆盖一些默认配置。例如,可以通过以下格式更改文本和图像骨干网络:
字符串格式 hyperparameters = “model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224”
或字符串列表格式 hyperparameters = [“model.hf_text.checkpoint_name=google/electra-small-discriminator”, “model.timm_image.checkpoint_name=swin_small_patch4_window7_224”]
或字典格式 hyperparameters = {
”model.hf_text.checkpoint_name”: “google/electra-small-discriminator”, “model.timm_image.checkpoint_name”: “swin_small_patch4_window7_224”,
}
column_types –
将列名映射到其数据类型的字典。例如:column_types = {“item_name”: “text”, “image”: “image_path”, “product_description”: “text”, “height”: “numerical”} 可用于具有列“item_name”、“brand”、“product_description”和“height”的表。如果为 None,将自动从数据中推断 column_types。当前支持的类型包括:
”image_path”:此列中的每一行是一个图像路径。
”text”:此列中的每一行包含文本(句子、段落等)。
”numerical”:此列中的每一行包含一个数字。
”categorical”:此列中的每一行属于 K 个类别之一。
holdout_frac – 从 train_data 中保留作为 tuning_data 的比例,用于优化超参数或早停(除非 tuning_data = None,否则忽略)。默认值(如果为 None)根据训练数据中的行数以及是否使用超参数优化来选择。
teacher_predictor – 预训练的教师预测器或其保存路径。如果提供,fit() 可以将其知识蒸馏给学生预测器,即当前预测器。
seed – 用于训练的随机种子(默认 0)。
standalone – 是否保存整个模型用于离线部署。
hyperparameter_tune_kwargs –
超参数调优策略和 kwargs(例如,要运行多少次 HPO 试验)。如果为 None,则不会执行超参数调优。
- num_trials: int
要运行的 HPO 试验次数。需要指定 num_trials 或 time_limit。
- scheduler: Union[str, ray.tune.schedulers.TrialScheduler]
如果传入 str,AutoGluon 将使用一些默认参数为你创建调度器。如果传入 ray.tune.schedulers.TrialScheduler 对象,你需要负责初始化该对象。
- scheduler_init_args: Optional[dict] = None
如果为 scheduler 提供 str,你可以选择为调度器提供自定义 init_args
- searcher: Union[str, ray.tune.search.SearchAlgorithm, ray.tune.search.Searcher]
如果传入 str,AutoGluon 将使用一些默认参数为你创建搜索器。如果传入 ray.tune.schedulers.TrialScheduler 对象,你需要负责初始化该对象。你无需担心搜索器对象的 metric 和 mode。AutoGluon 会自行处理。
- scheduler_init_args: Optional[dict] = None
如果为 searcher 提供 str,你可以选择为搜索器提供自定义 init_args。你无需担心 metric 和 mode。AutoGluon 会自行处理。
clean_ckpts – 是否在训练后清理中间检查点。
- 返回类型:
一个“MultiModalPredictor”对象(即自身)。