autogluon.multimodal.MultiModalPredictor¶
- class autogluon.multimodal.MultiModalPredictor(label: str | None = None, problem_type: str | None = None, query: str | List[str] | None = None, response: str | List[str] | None = None, match_label: int | str | None = None, presets: str | None = None, eval_metric: str | Scorer | None = None, hyperparameters: dict | None = None, path: str | None = None, verbosity: int | None = 2, num_classes: int | None = None, classes: list | None = None, warn_if_exist: bool | None = True, enable_progress_bar: bool | None = None, pretrained: bool | None = True, validation_metric: str | None = None, sample_data_path: str | None = None, use_ensemble: bool | None = False, ensemble_size: int | None = 2, ensemble_mode: str | None = 'one_shot')[source]¶
AutoMM 旨在通过仅仅三行代码简化基础模型在下游应用中的微调。AutoMM 与流行的模型库无缝集成,例如 HuggingFace Transformers、TIMM 和 MMDetection,支持多种数据模态,包括图像、文本、表格数据和文档数据,无论单独使用还是组合使用。它支持一系列任务,包括分类、回归、目标检测、命名实体识别、语义匹配和图像分割。
- __init__(label: str | None = None, problem_type: str | None = None, query: str | List[str] | None = None, response: str | List[str] | None = None, match_label: int | str | None = None, presets: str | None = None, eval_metric: str | Scorer | None = None, hyperparameters: dict | None = None, path: str | None = None, verbosity: int | None = 2, num_classes: int | None = None, classes: list | None = None, warn_if_exist: bool | None = True, enable_progress_bar: bool | None = None, pretrained: bool | None = True, validation_metric: str | None = None, sample_data_path: str | None = None, use_ensemble: bool | None = False, ensemble_size: int | None = 2, ensemble_mode: str | None = 'one_shot')[source]¶
- 参数:
label – pd.DataFrame 中包含要预测的目标变量的列的名称。
problem_type –
问题类型。我们支持以下标准问题:
’binary’:二元分类
’multiclass’:多类别分类
’regression’:回归
’classification’:分类问题包括“二元分类”和“多类别分类”。
此外,我们还支持以下高级问题:
’object_detection’:目标检测
’ner’ 或 ‘named_entity_recognition’:命名实体提取
’text_similarity’:文本-文本语义匹配
’image_similarity’:图像-图像语义匹配
’image_text_similarity’:文本-图像语义匹配
’feature_extraction’:特征提取(仅支持推理)
’zero_shot_image_classification’:零样本图像分类(仅支持推理)
’few_shot_classification’:用于图像或文本数据的少样本分类。
’semantic_segmentation’:使用 Segment Anything Model 进行语义分割。
对于某些问题类型,默认行为是根据预设/超参数加载预训练模型,并且预测器可以进行零样本推理(无需调用 .fit() 即可运行推理)。这些问题类型包括:
’object_detection’
’text_similarity’
’image_similarity’
’image_text_similarity’
’feature_extraction’
’zero_shot_image_classification’
query – pd.DataFrame 中包含语义匹配任务中查询数据的列的名称。
response – pd.DataFrame 中包含语义匹配任务中响应数据的列的名称。如果未提供标签列,则 pd.DataFrame 行中的查询和响应对被假定为正样本对。
match_label – 表示 <query, response> 对被计为“匹配”的标签类别。这在任务属于语义匹配且标签为二元时使用。例如,标签列在重复项检测任务中可能包含 ["duplicate", "not duplicate"]。match_label 应该是 "duplicate",因为它表示两个项匹配。
presets – 关于模型质量的预设,例如“best_quality”(最佳质量)、“high_quality”(高质量,默认)和“medium_quality”(中等质量)。每种质量都有其对应的 HPO 预设:“best_quality_hpo”、“high_quality_hpo”和“medium_quality_hpo”。
eval_metric – 评估指标名称。如果 eval_metric = None,则根据 problem_type 自动选择。多类别分类默认为“accuracy”,二元分类默认为 roc_auc,回归默认为“root_mean_squared_error”。
hyperparameters –
这用于覆盖一些默认配置。例如,可以通过以下格式更改文本和图像骨干网络:
一个字符串 hyperparameters = “model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224”
或一个字符串列表 hyperparameters = [“model.hf_text.checkpoint_name=google/electra-small-discriminator”, “model.timm_image.checkpoint_name=swin_small_patch4_window7_224”]
或一个字典 hyperparameters = {
”model.hf_text.checkpoint_name”: “google/electra-small-discriminator”, “model.timm_image.checkpoint_name”: “swin_small_patch4_window7_224”,
}
path – 模型和相关文件应保存到的目录路径。如果未指定,将在工作目录中创建一个名为 “AutogluonAutoMM/ag-[TIMESTAMP]” 的带时间戳的文件夹。注意:要多次调用 fit() 并保存每次训练的所有结果,必须指定不同的 path 位置,或者完全不指定 path。
verbosity – 详细程度级别范围从 0 到 4,控制打印多少日志信息。级别越高,打印的语句越详细。您可以设置 verbosity = 0 来抑制警告。
num_classes – 类别数量(用于目标检测)。如果指定了此参数且其与预训练模型的输出形状不同,模型的头部将被更改为具有 <num_classes> 个输出。
classes – 所有类别(用于目标检测)。
warn_if_exist – 如果指定的 path 已存在,是否发出警告(默认 True)。
enable_progress_bar – 是否显示进度条(默认 True)。如果设置了环境变量 os.environ["AUTOMM_DISABLE_PROGRESS_BAR"],则会禁用。
pretrained – 是否使用预训练权重初始化模型(默认 True)。如果为 False,则创建具有随机初始化的模型。
validation_metric – 用于在训练期间选择最佳模型和早停的验证指标。如果未提供,将根据问题类型自动选择。
sample_data_path – 样本数据的路径,可以从中推断用于目标检测的 num_classes 或 classes。
use_ensemble – 训练预测器时是否使用集成(默认 False)。目前,它仅适用于分类或回归任务的多模态数据(图像+文本、图像+表格、文本+表格、图像+文本+表格)。
ensemble_size – 集成池中模型数量的倍数(默认 2)。实际集成大小 = ensemble_size * 模型数量
ensemble_mode – 进行集成的模式:- one_shot:经典的集成选择 - sequential:每次通过最佳的下一个模型扩展模型库,迭代调用经典的集成选择。
方法
将模型权重和配置保存到本地目录。
在给定数据集上评估模型。
将此预测器的模型导出为 ONNX 文件。
提取每个样本(即所提供 pd.DataFrame 数据中的一行)的特征。
训练模型,根据数据表(标签)的其他列(特征)预测某一列。
输出 fit() 的训练摘要信息。
从配置中获取 GPU 数量。
列出每种问题类型支持的模型。
从 path 指定的目录加载预测器对象。
优化预测器的模型以进行推理。
预测新数据的标签列值。
预测类别概率而非类别标签。
将此预测器保存到 path 指定的目录中的文件。
在配置中设置 GPU 数量。
设置日志的详细程度。
属性
class_labels
类别标签的原始名称。
classes
目标检测问题类型的对象类别。
column_types
pd.DataFrame 中的列类型。
eval_metric
用于评估预测性能的指标。
label
pd.DataFrame 中包含要预测的目标变量的列的名称。
match_label
在语义匹配任务中,表示 <query, response> 对被计为“匹配”的标签类别。
model_size
返回模型的兆字节大小。
path
存储模型和相关文件的目录路径。
positive_class
将映射为 1 的类别标签的名称。
problem_property
问题的属性,存储问题类型及其相关属性。
problem_type
此预测器经过训练解决的预测问题类型。
query
pd.DataFrame 中包含语义匹配任务中查询数据的列的名称。
response
pd.DataFrame 中包含语义匹配任务中响应数据的列的名称。
total_parameters
模型参数的数量。
trainable_parameters
可训练模型参数的数量,通常指那些 requires_grad=True 的参数。
validation_metric
用于在训练期间选择最佳模型和早停的验证指标。
verbosity
详细程度级别范围从 0 到 4,控制打印信息的多少。