autogluon.multimodal.MultiModalPredictor

class autogluon.multimodal.MultiModalPredictor(label: str | None = None, problem_type: str | None = None, query: str | List[str] | None = None, response: str | List[str] | None = None, match_label: int | str | None = None, presets: str | None = None, eval_metric: str | Scorer | None = None, hyperparameters: dict | None = None, path: str | None = None, verbosity: int | None = 2, num_classes: int | None = None, classes: list | None = None, warn_if_exist: bool | None = True, enable_progress_bar: bool | None = None, pretrained: bool | None = True, validation_metric: str | None = None, sample_data_path: str | None = None, use_ensemble: bool | None = False, ensemble_size: int | None = 2, ensemble_mode: str | None = 'one_shot')[source]

AutoMM 旨在通过仅仅三行代码简化基础模型在下游应用中的微调。AutoMM 与流行的模型库无缝集成,例如 HuggingFace TransformersTIMMMMDetection,支持多种数据模态,包括图像、文本、表格数据和文档数据,无论单独使用还是组合使用。它支持一系列任务,包括分类、回归、目标检测、命名实体识别、语义匹配和图像分割。

__init__(label: str | None = None, problem_type: str | None = None, query: str | List[str] | None = None, response: str | List[str] | None = None, match_label: int | str | None = None, presets: str | None = None, eval_metric: str | Scorer | None = None, hyperparameters: dict | None = None, path: str | None = None, verbosity: int | None = 2, num_classes: int | None = None, classes: list | None = None, warn_if_exist: bool | None = True, enable_progress_bar: bool | None = None, pretrained: bool | None = True, validation_metric: str | None = None, sample_data_path: str | None = None, use_ensemble: bool | None = False, ensemble_size: int | None = 2, ensemble_mode: str | None = 'one_shot')[source]
参数:
  • label – pd.DataFrame 中包含要预测的目标变量的列的名称。

  • problem_type

    问题类型。我们支持以下标准问题:

    • ’binary’:二元分类

    • ’multiclass’:多类别分类

    • ’regression’:回归

    • ’classification’:分类问题包括“二元分类”和“多类别分类”。

    此外,我们还支持以下高级问题:

    • ’object_detection’:目标检测

    • ’ner’ 或 ‘named_entity_recognition’:命名实体提取

    • ’text_similarity’:文本-文本语义匹配

    • ’image_similarity’:图像-图像语义匹配

    • ’image_text_similarity’:文本-图像语义匹配

    • ’feature_extraction’:特征提取(仅支持推理)

    • ’zero_shot_image_classification’:零样本图像分类(仅支持推理)

    • ’few_shot_classification’:用于图像或文本数据的少样本分类。

    • ’semantic_segmentation’:使用 Segment Anything Model 进行语义分割。

    对于某些问题类型,默认行为是根据预设/超参数加载预训练模型,并且预测器可以进行零样本推理(无需调用 .fit() 即可运行推理)。这些问题类型包括:

    • ’object_detection’

    • ’text_similarity’

    • ’image_similarity’

    • ’image_text_similarity’

    • ’feature_extraction’

    • ’zero_shot_image_classification’

  • query – pd.DataFrame 中包含语义匹配任务中查询数据的列的名称。

  • response – pd.DataFrame 中包含语义匹配任务中响应数据的列的名称。如果未提供标签列,则 pd.DataFrame 行中的查询和响应对被假定为正样本对。

  • match_label – 表示 <query, response> 对被计为“匹配”的标签类别。这在任务属于语义匹配且标签为二元时使用。例如,标签列在重复项检测任务中可能包含 ["duplicate", "not duplicate"]。match_label 应该是 "duplicate",因为它表示两个项匹配。

  • presets – 关于模型质量的预设,例如“best_quality”(最佳质量)、“high_quality”(高质量,默认)和“medium_quality”(中等质量)。每种质量都有其对应的 HPO 预设:“best_quality_hpo”、“high_quality_hpo”和“medium_quality_hpo”。

  • eval_metric – 评估指标名称。如果 eval_metric = None,则根据 problem_type 自动选择。多类别分类默认为“accuracy”,二元分类默认为 roc_auc,回归默认为“root_mean_squared_error”。

  • hyperparameters

    这用于覆盖一些默认配置。例如,可以通过以下格式更改文本和图像骨干网络:

    一个字符串 hyperparameters = “model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224”

    或一个字符串列表 hyperparameters = [“model.hf_text.checkpoint_name=google/electra-small-discriminator”, “model.timm_image.checkpoint_name=swin_small_patch4_window7_224”]

    或一个字典 hyperparameters = {

    ”model.hf_text.checkpoint_name”: “google/electra-small-discriminator”, “model.timm_image.checkpoint_name”: “swin_small_patch4_window7_224”,

    }

  • path – 模型和相关文件应保存到的目录路径。如果未指定,将在工作目录中创建一个名为 “AutogluonAutoMM/ag-[TIMESTAMP]” 的带时间戳的文件夹。注意:要多次调用 fit() 并保存每次训练的所有结果,必须指定不同的 path 位置,或者完全不指定 path

  • verbosity – 详细程度级别范围从 0 到 4,控制打印多少日志信息。级别越高,打印的语句越详细。您可以设置 verbosity = 0 来抑制警告。

  • num_classes – 类别数量(用于目标检测)。如果指定了此参数且其与预训练模型的输出形状不同,模型的头部将被更改为具有 <num_classes> 个输出。

  • classes – 所有类别(用于目标检测)。

  • warn_if_exist – 如果指定的 path 已存在,是否发出警告(默认 True)。

  • enable_progress_bar – 是否显示进度条(默认 True)。如果设置了环境变量 os.environ["AUTOMM_DISABLE_PROGRESS_BAR"],则会禁用。

  • pretrained – 是否使用预训练权重初始化模型(默认 True)。如果为 False,则创建具有随机初始化的模型。

  • validation_metric – 用于在训练期间选择最佳模型和早停的验证指标。如果未提供,将根据问题类型自动选择。

  • sample_data_path – 样本数据的路径,可以从中推断用于目标检测的 num_classes 或 classes。

  • use_ensemble – 训练预测器时是否使用集成(默认 False)。目前,它仅适用于分类或回归任务的多模态数据(图像+文本、图像+表格、文本+表格、图像+文本+表格)。

  • ensemble_size – 集成池中模型数量的倍数(默认 2)。实际集成大小 = ensemble_size * 模型数量

  • ensemble_mode – 进行集成的模式:- one_shot:经典的集成选择 - sequential:每次通过最佳的下一个模型扩展模型库,迭代调用经典的集成选择。

方法

dump_model

将模型权重和配置保存到本地目录。

evaluate

在给定数据集上评估模型。

export_onnx

将此预测器的模型导出为 ONNX 文件。

extract_embedding

提取每个样本(即所提供 pd.DataFrame 数据中的一行)的特征。

fit

训练模型,根据数据表(标签)的其他列(特征)预测某一列。

fit_summary

输出 fit() 的训练摘要信息。

get_num_gpus

从配置中获取 GPU 数量。

list_supported_models

列出每种问题类型支持的模型。

load

从 path 指定的目录加载预测器对象。

optimize_for_inference

优化预测器的模型以进行推理。

predict

预测新数据的标签列值。

predict_proba

预测类别概率而非类别标签。

save

将此预测器保存到 path 指定的目录中的文件。

set_num_gpus

在配置中设置 GPU 数量。

set_verbosity

设置日志的详细程度。

属性

class_labels

类别标签的原始名称。

classes

目标检测问题类型的对象类别。

column_types

pd.DataFrame 中的列类型。

eval_metric

用于评估预测性能的指标。

label

pd.DataFrame 中包含要预测的目标变量的列的名称。

match_label

在语义匹配任务中,表示 <query, response> 对被计为“匹配”的标签类别。

model_size

返回模型的兆字节大小。

path

存储模型和相关文件的目录路径。

positive_class

将映射为 1 的类别标签的名称。

problem_property

问题的属性,存储问题类型及其相关属性。

problem_type

此预测器经过训练解决的预测问题类型。

query

pd.DataFrame 中包含语义匹配任务中查询数据的列的名称。

response

pd.DataFrame 中包含语义匹配任务中响应数据的列的名称。

total_parameters

模型参数的数量。

trainable_parameters

可训练模型参数的数量,通常指那些 requires_grad=True 的参数。

validation_metric

用于在训练期间选择最佳模型和早停的验证指标。

verbosity

详细程度级别范围从 0 到 4,控制打印信息的多少。