Multimodal Data Tables: Tabular, Text, and Image


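Tip: Before reading this tutorial, it is recommended to have a basic understanding of the TabularPredictor API covered in Predicting Columns in a Table - Quick Start.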

In this tutorial, we will train a multimodal ensemble using data that contains image, text, and tabular features.

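Note: A GPU is required for this tutorial in order to train the image and text models. Additionally, the GPU build of Torch must be installed, matching your CUDA version.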

The PetFinder Dataset

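We will use the PetFinder dataset. The PetFinder dataset provides the information that appears on shelter animals' adoption profiles, with the goal of predicting each animal's adoption rate. The end goal is for rescue shelters to use the predicted adoption rate to identify animals whose profiles could be improved so that they can find a home.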

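Each animal's adoption profile contains a variety of information, such as pictures of the animal, a text description of the animal, and various tabular features such as age, breed, name, and color.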

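To get started, we first need to download the dataset. Datasets that contain images require more than a CSV file, so the dataset is packaged as a zip file in S3. We will first download it and unzip the contents: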

download_dir = './ag_petfinder_tutorial'
zip_file = 'https://automl-mm-bench.s3.amazonaws.com/petfinder_kaggle.zip'
from autogluon.core.utils.loaders import load_zip
load_zip.unzip(zip_file, unzip_dir=download_dir)
Downloading ./ag_petfinder_tutorial/file.zip from https://automl-mm-bench.s3.amazonaws.com/petfinder_kaggle.zip...
100%|██████████| 2.00G/2.00G [00:51<00:00, 38.8MiB/s]
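load_zip.unzip is an AutoGluon helper. If you would rather reproduce this step with only the Python standard library (for example, to cache the 2 GB archive yourself), a minimal sketch follows. The function names download_and_unzip and extract_zip are hypothetical, and this is not how AutoGluon's loader is implemented; it only mimics the resulting layout, in which the archive is kept as file.zip next to the extracted directory.

```python
import os
import urllib.request
import zipfile


def extract_zip(zip_path, unzip_dir):
    """Extract every member of `zip_path` into `unzip_dir`."""
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(unzip_dir)


def download_and_unzip(url, unzip_dir):
    """Download the archive at `url` to <unzip_dir>/file.zip and extract it there.

    Returns the sorted top-level entries of `unzip_dir` after extraction.
    """
    os.makedirs(unzip_dir, exist_ok=True)
    zip_path = os.path.join(unzip_dir, 'file.zip')
    # Reuse an archive left over from a previous run instead of re-downloading it.
    if not os.path.exists(zip_path):
        urllib.request.urlretrieve(url, zip_path)
    extract_zip(zip_path, unzip_dir)
    return sorted(os.listdir(unzip_dir))
```

With the tutorial's variables this would be called as download_and_unzip(zip_file, download_dir). One caveat of the reuse shortcut: an interrupted download leaves a truncated file.zip behind, so delete it before retrying.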

Now that the data is downloaded and unzipped, let's take a look at the contents:

import os
os.listdir(download_dir)
['file.zip', 'petfinder_processed']

'file.zip' is the original zip file we downloaded, and 'petfinder_processed' is a directory containing the dataset files.

dataset_path = download_dir + '/petfinder_processed'
os.listdir(dataset_path)
['train.csv', 'train_images', 'test.csv', 'test_images', 'dev.csv']

Here we can see the train, test, and dev CSV files, as well as two directories, 'test_images' and 'train_images', which contain the JPG image files.
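Before going further, it can be useful to check how many images each image directory actually holds. The helper below is a hypothetical addition (not part of AutoGluon) that counts the .jpg files in 'train_images' and 'test_images'; with the tutorial's variables you would call count_images(dataset_path).

```python
import os


def count_images(dataset_path, subdirs=('train_images', 'test_images')):
    """Count the .jpg files in each image sub-directory of the dataset."""
    counts = {}
    for sub in subdirs:
        folder = os.path.join(dataset_path, sub)
        # Ignore anything that is not a JPG (e.g. stray metadata files)
        counts[sub] = sum(1 for name in os.listdir(folder)
                          if name.lower().endswith('.jpg'))
    return counts
```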

Note: We will use the dev data as our test data, because dev contains ground-truth labels, which lets us display scores via predictor.leaderboard.
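Scoring requires ground truth: a metric such as accuracy compares predictions against the true labels, which is why the labeled dev.csv is used here. AutoGluon computes its own metrics; the library-independent sketch below only illustrates what multiclass accuracy means, namely the fraction of predictions that exactly match the label. The labels and predictions in it are made up.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that exactly match the ground-truth labels."""
    if len(y_true) != len(y_pred):
        raise ValueError('y_true and y_pred must have the same length')
    if len(y_true) == 0:
        raise ValueError('cannot score an empty set of labels')
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)


# Hypothetical AdoptionSpeed labels and predictions for five dev rows
y_true = [4, 3, 4, 2, 0]
y_pred = [4, 2, 4, 2, 1]
print(accuracy(y_true, y_pred))  # 0.6
```

With a fitted predictor, a comparable number could be obtained with accuracy(test_data[label].tolist(), predictor.predict(test_data).tolist()).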

Let's take a look at the first 10 files in the 'train_images' directory:

os.listdir(dataset_path + '/train_images')[:10]
['d765ae877-1.jpg',
 '756025f7c-2.jpg',
 'e1a2d9477-4.jpg',
 '6d18707ee-2.jpg',
 '96607bca0-5.jpg',
 'fde58f7fa-10.jpg',
 'be7b65c23-3.jpg',
 'dd36ab692-3.jpg',
 '2d8db1c19-2.jpg',
 '53037f091-2.jpg']

As expected, these are the images we will train on alongside the other features.
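The file names appear to follow a <PetID>-<photo number>.jpg pattern (for example 'fde58f7fa-10.jpg'). Assuming that pattern holds for every file, the sketch below parses names and groups each pet's photos. Note that the photos are sorted numerically, because a plain string sort would put '-10' before '-2'. Both helpers are hypothetical and not part of AutoGluon.

```python
import os
import re
from collections import defaultdict

# Assumed naming scheme: <PetID>-<photo number>.jpg, e.g. 'fde58f7fa-10.jpg'
_IMAGE_NAME = re.compile(r'^(?P<pet_id>[^-]+)-(?P<index>\d+)\.jpg$')


def parse_image_name(filename):
    """Split an image file name (or path) into (pet_id, photo_number)."""
    match = _IMAGE_NAME.match(os.path.basename(filename))
    if match is None:
        raise ValueError(f'unexpected image file name: {filename!r}')
    return match.group('pet_id'), int(match.group('index'))


def group_by_pet(filenames):
    """Map each PetID to its image file names, ordered by photo number."""
    groups = defaultdict(list)
    for name in filenames:
        pet_id, index = parse_image_name(name)
        groups[pet_id].append((index, name))
    # Sort numerically so that '-2' comes before '-10'
    return {pet_id: [name for _, name in sorted(items)]
            for pet_id, items in groups.items()}
```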

Next, we will load the train and dev CSV files:

import pandas as pd

train_data = pd.read_csv(f'{dataset_path}/train.csv', index_col=0)
test_data = pd.read_csv(f'{dataset_path}/dev.csv', index_col=0)
train_data.head(3)
Type Name Age Breed1 Breed2 Gender Color1 Color2 Color3 MaturitySize ... Quantity Fee State RescuerID VideoAmt Description PetID PhotoAmt AdoptionSpeed Images
10721 1 Elbi 2 307 307 2 5 0 0 3 ... 1 0 41336 e9a86209c54f589ba72c345364cf01aa 0 I'm looking for people to adopt my dog e4b90955c 4.0 4 train_images/e4b90955c-1.jpg;train_images/e4b9...
13114 2 Darling 4 266 0 1 1 0 0 2 ... 1 0 41401 01f954cdf61526daf3fbeb8a074be742 0 Darling was born at the back lane of Jalan Alo... a0c1384d1 5.0 3 train_images/a0c1384d1-1.jpg;train_images/a0c1...
13194 1 Wolf 3 307 0 1 1 2 0 2 ... 1 0 41332 6e19409f2847326ce3b6d0cec7e42f81 0 I found Wolf about a month ago stuck in a drai... cf357f057 7.0 4 train_images/cf357f057-1.jpg;train_images/cf35...

3 rows × 25 columns

Looking at the first 3 examples, we can see a variety of tabular features, a text description ('Description'), and an image path ('Images').

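For the PetFinder dataset, we will try to predict the speed of adoption of the animal ('AdoptionSpeed'), which is grouped into 5 categories. This means that we are dealing with a multiclass classification problem.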
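With 5 classes, an accuracy score is easiest to interpret next to a trivial baseline: always predicting the most frequent class. The helpers below are hypothetical (not part of AutoGluon). With the tutorial's data you would pass train_data[label]; the eight labels in the example are made up.

```python
from collections import Counter


def class_distribution(labels):
    """Share of each class label, ordered by label."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {cls: counts[cls] / total for cls in sorted(counts)}


def majority_baseline_accuracy(labels):
    """Accuracy obtained by always predicting the most frequent class."""
    counts = Counter(labels)
    return max(counts.values()) / sum(counts.values())


# Hypothetical AdoptionSpeed values for eight animals
labels = [4, 3, 4, 2, 4, 1, 0, 2]
print(majority_baseline_accuracy(labels))  # 0.375
```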

label = 'AdoptionSpeed'
image_col = 'Images'

Preparing the Image Column

Let's take a look at what a value in the image column looks like:

train_data[image_col].iloc[0]
'train_images/e4b90955c-1.jpg;train_images/e4b90955c-2.jpg;train_images/e4b90955c-3.jpg;train_images/e4b90955c-4.jpg'

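Currently, AutoGluon only supports one image per row. Since the PetFinder dataset contains one or more images per row, we first need to preprocess the image column so that each row contains only its first image.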

train_data[image_col] = train_data[image_col].apply(lambda ele: ele.split(';')[0])
test_data[image_col] = test_data[image_col].apply(lambda ele: ele.split(';')[0])

train_data[image_col].iloc[0]
'train_images/e4b90955c-1.jpg'
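The lambda above assumes every row holds a string. If some rows had no image at all (we assume such rows would show up as NaN, for instance pets whose PhotoAmt is 0), ele.split would raise an error. A tolerant variant, written as a hypothetical helper:

```python
def first_image(value, sep=';'):
    """Return the first path in a `sep`-separated string, or None if there is none."""
    if not isinstance(value, str):
        # Missing values typically appear as float('nan') in a pandas object column
        return None
    first = value.split(sep)[0].strip()
    return first or None
```

It would be applied as train_data[image_col] = train_data[image_col].apply(first_image). Rows that end up as None should then be dropped (for example with train_data.dropna(subset=[image_col])) before path_expander is applied, since that function expects a string.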

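AutoGluon loads images based on the file paths provided in the image column.

Here we update the paths so that they point to the correct location on disk: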

def path_expander(path, base_folder):
    # Turn each ';'-separated relative path into an absolute path under base_folder
    path_l = path.split(';')
    return ';'.join([os.path.abspath(os.path.join(base_folder, p)) for p in path_l])

train_data[image_col] = train_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))
test_data[image_col] = test_data[image_col].apply(lambda ele: path_expander(ele, base_folder=dataset_path))

train_data[image_col].iloc[0]
'/home/ci/autogluon/docs/tutorials/tabular/ag_petfinder_tutorial/petfinder_processed/train_images/e4b90955c-1.jpg'
train_data.head(3)
Type Name Age Breed1 Breed2 Gender Color1 Color2 Color3 MaturitySize ... Quantity Fee State RescuerID VideoAmt Description PetID PhotoAmt AdoptionSpeed Images
10721 1 Elbi 2 307 307 2 5 0 0 3 ... 1 0 41336 e9a86209c54f589ba72c345364cf01aa 0 I'm looking for people to adopt my dog e4b90955c 4.0 4 /home/ci/autogluon/docs/tutorials/tabular/ag_p...
13114 2 Darling 4 266 0 1 1 0 0 2 ... 1 0 41401 01f954cdf61526daf3fbeb8a074be742 0 Darling was born at the back lane of Jalan Alo... a0c1384d1 5.0 3 /home/ci/autogluon/docs/tutorials/tabular/ag_p...
13194 1 Wolf 3 307 0 1 1 2 0 2 ... 1 0 41332 6e19409f2847326ce3b6d0cec7e42f81 0 I found Wolf about a month ago stuck in a drai... cf357f057 7.0 4 /home/ci/autogluon/docs/tutorials/tabular/ag_p...

3 rows × 25 columns
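Because AutoGluon reads each image from the path stored in this column, a wrong base_folder may go unnoticed until AutoGluon actually tries to load the images. A cheap pre-flight check is sketched below; find_missing_images is a hypothetical helper, not an AutoGluon function.

```python
import os


def find_missing_images(paths):
    """Return the entries of `paths` that are not strings pointing to an existing file."""
    return [p for p in paths if not isinstance(p, str) or not os.path.isfile(p)]
```

For example, missing = find_missing_images(train_data[image_col]); if the list is non-empty, fix the paths before calling fit.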

Analyzing an Example Row

Now that we have preprocessed the image column, let's take a look at an example row of data and display its text description and picture.

example_row = train_data.iloc[1]

example_row
Type                                                             2
Name                                                       Darling
Age                                                              4
Breed1                                                         266
Breed2                                                           0
Gender                                                           1
Color1                                                           1
Color2                                                           0
Color3                                                           0
MaturitySize                                                     2
FurLength                                                        1
Vaccinated                                                       2
Dewormed                                                         2
Sterilized                                                       2
Health                                                           1
Quantity                                                         1
Fee                                                              0
State                                                        41401
RescuerID                         01f954cdf61526daf3fbeb8a074be742
VideoAmt                                                         0
Description      Darling was born at the back lane of Jalan Alo...
PetID                                                    a0c1384d1
PhotoAmt                                                       5.0
AdoptionSpeed                                                    3
Images           /home/ci/autogluon/docs/tutorials/tabular/ag_p...
Name: 13114, dtype: object
example_row['Description']
'Darling was born at the back lane of Jalan Alor and was foster by a feeder. All his siblings had died of accident. His mother and grandmother had just been spayed. Darling make a great condo/apartment cat. He love to play a lot. He would make a great companion for someone looking for a cat to love.'
example_image = example_row['Images']

from IPython.display import Image, display
pil_img = Image(filename=example_image)
display(pil_img)
(Image output: the first photo of the example row)

The PetFinder dataset is fairly large. For the purposes of this tutorial, we will sample 500 rows for training.

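Training on large multimodal datasets can be very computationally intensive, especially when using the best_quality preset in AutoGluon. When prototyping, it is recommended to sample your data to get an idea of which models are worth training, and then gradually train with more data and longer time limits, as you would with any other machine learning algorithm.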

train_data = train_data.sample(500, random_state=0)
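train_data.sample(500, random_state=0) draws rows uniformly at random, so in a small sample the five AdoptionSpeed classes can end up over- or under-represented by chance. If you want the sample to keep the class proportions of the full data, one option is stratified sampling. The sketch below is a hypothetical standard-library helper that returns row positions; because each class's share is rounded, the total can differ from n by a row or two.

```python
import random
from collections import defaultdict


def stratified_sample(labels, n, seed=0):
    """Pick about `n` row positions so that each class keeps its share of `labels`."""
    by_class = defaultdict(list)
    for pos, y in enumerate(labels):
        by_class[y].append(pos)
    rng = random.Random(seed)
    total = len(labels)
    chosen = []
    for y in sorted(by_class):
        positions = by_class[y]
        # Proportional allocation, capped at the number of rows in the class
        k = min(len(positions), round(n * len(positions) / total))
        chosen.extend(rng.sample(positions, k))
    return sorted(chosen)
```

With the tutorial's data this could be used as train_data.iloc[stratified_sample(train_data[label].tolist(), 500)]. It selects different rows than the tutorial, so the scores in the training log would not be reproduced exactly.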

Constructing the FeatureMetadata

Next, let's see which feature types AutoGluon infers by constructing a FeatureMetadata object from the training data:

from autogluon.tabular import FeatureMetadata
feature_metadata = FeatureMetadata.from_df(train_data)

print(feature_metadata)
('float', [])        :  1 | ['PhotoAmt']
('int', [])          : 19 | ['Type', 'Age', 'Breed1', 'Breed2', 'Gender', ...]
('object', [])       :  4 | ['Name', 'RescuerID', 'PetID', 'Images']
('object', ['text']) :  1 | ['Description']

Notice that FeatureMetadata automatically identified the 'Description' column as text, so we don't need to specify that manually.

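To leverage the images, we need to tell AutoGluon which column contains the image paths. We can do this by specifying a FeatureMetadata object and adding the 'image_path' special type to the image column. We will later pass this custom FeatureMetadata to TabularPredictor.fit.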

feature_metadata = feature_metadata.add_special_types({image_col: ['image_path']})

print(feature_metadata)
('float', [])              :  1 | ['PhotoAmt']
('int', [])                : 19 | ['Type', 'Age', 'Breed1', 'Breed2', 'Gender', ...]
('object', [])             :  3 | ['Name', 'RescuerID', 'PetID']
('object', ['image_path']) :  1 | ['Images']
('object', ['text'])       :  1 | ['Description']
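To make the bookkeeping in these two printouts concrete, here is a toy model in plain Python. It is an illustration, not AutoGluon's implementation. Each column maps to a raw dtype plus a tuple of special types, and adding 'image_path' moves 'Images' out of the plain ('object', []) group into a group of its own. The tutorial reassigns the result of add_special_types, which suggests the method returns a new object, so the toy version also leaves its input unchanged.

```python
from collections import defaultdict


def add_special_types(type_map, special):
    """Return a new {column: (raw_dtype, special_types)} map with extra special types added."""
    updated = dict(type_map)  # the input map is left unchanged
    for column, extra in special.items():
        raw_dtype, existing = updated[column]
        added = tuple(t for t in extra if t not in existing)
        updated[column] = (raw_dtype, tuple(existing) + added)
    return updated


def group_by_type(type_map):
    """Group column names by (raw_dtype, special_types), like the printed summary."""
    groups = defaultdict(list)
    for column, key in type_map.items():
        groups[key].append(column)
    return dict(groups)
```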

Specifying the Hyperparameters

Next, we need to specify the models we want to train. This is done via the hyperparameters argument of TabularPredictor.fit.

AutoGluon has a predefined config that works well for multimodal datasets, called 'multimodal'. We can access it as follows:

from autogluon.tabular.configs.hyperparameter_configs import get_hyperparameter_config
hyperparameters = get_hyperparameter_config('multimodal')

hyperparameters
{'NN_TORCH': {},
 'GBM': [{},
  {'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}},
  {'learning_rate': 0.03,
   'num_leaves': 128,
   'feature_fraction': 0.9,
   'min_data_in_leaf': 3,
   'ag_args': {'name_suffix': 'Large',
    'priority': 0,
    'hyperparameter_tune_kwargs': None}}],
 'CAT': {},
 'XGB': {},
 'AG_AUTOMM': {}}

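This hyperparameter config will train a variety of tabular models and, through the AG_AUTOMM entry, fine-tune an Electra BERT text model and a ResNet image model.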
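The config is a plain dict keyed by model type, so it can be edited before being passed to fit. For example, since the image and text models need a GPU, a CPU-only dry run could drop the AG_AUTOMM entry, which we assume is the key that trains the multimodal image/text network. The helper without_models is hypothetical; it deep-copies the dict so that the original config is left untouched.

```python
import copy

# The 'multimodal' config, copied from the output above
multimodal_hp = {
    'NN_TORCH': {},
    'GBM': [{},
            {'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}},
            {'learning_rate': 0.03, 'num_leaves': 128, 'feature_fraction': 0.9,
             'min_data_in_leaf': 3,
             'ag_args': {'name_suffix': 'Large', 'priority': 0,
                         'hyperparameter_tune_kwargs': None}}],
    'CAT': {},
    'XGB': {},
    'AG_AUTOMM': {},
}


def without_models(hyperparameters, *model_keys):
    """Return a deep copy of `hyperparameters` with the given model keys removed."""
    trimmed = copy.deepcopy(hyperparameters)
    for key in model_keys:
        trimmed.pop(key, None)
    return trimmed


cpu_hp = without_models(multimodal_hp, 'AG_AUTOMM')
print(sorted(cpu_hp))  # ['CAT', 'GBM', 'NN_TORCH', 'XGB']
```

The trimmed dict would then be passed as hyperparameters=cpu_hp; presumably the image column would go unused in that run.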

Fitting with TabularPredictor

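We will now train a TabularPredictor on the dataset, using the feature metadata and hyperparameters we defined earlier. This TabularPredictor will leverage tabular, text, and image features all at once.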

from autogluon.tabular import TabularPredictor
predictor = TabularPredictor(label=label).fit(
    train_data=train_data,
    hyperparameters=hyperparameters,
    feature_metadata=feature_metadata,
    time_limit=900,
)
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_220305"
Verbosity: 2 (Standard Logging)
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Memory Avail:       28.66 GB / 30.95 GB (92.6%)
Disk Space Avail:   206.28 GB / 255.99 GB (80.6%)
===================================================
No presets specified! To achieve strong results with AutoGluon, it is recommended to use the available presets. Defaulting to `'medium'`...
	Recommended Presets (For more details refer to https://autogluon.cn/stable/tutorials/tabular/tabular-essentials.html#presets):
	presets='experimental' : New in v1.2: Pre-trained foundation model + parallel fits. The absolute best accuracy without consideration for inference speed. Does not support GPU.
	presets='best'         : Maximize accuracy. Recommended for most users. Use in competitions and benchmarks.
	presets='high'         : Strong accuracy with fast inference speed.
	presets='good'         : Good accuracy with very fast inference speed.
	presets='medium'       : Fast training time, ideal for initial prototyping.
Beginning AutoGluon training ... Time limit = 900s
AutoGluon will save models to "/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250508_220305"
Train Data Rows:    500
Train Data Columns: 24
Label Column:       AdoptionSpeed
AutoGluon infers your prediction problem is: 'multiclass' (because dtype of label-column == int, but few unique label-values observed).
	5 unique label values:  [np.int64(2), np.int64(3), np.int64(4), np.int64(0), np.int64(1)]
	If 'multiclass' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
Problem Type:       multiclass
Preprocessing data ...
Train Data Class Count: 5
Using Feature Generators to preprocess the data ...
Fitting AutoMLPipelineFeatureGenerator...
	Available Memory:                    29343.56 MB
	Train Data (Original)  Memory Usage: 0.45 MB (0.0% of available memory)
	Stage 1 Generators:
		Fitting AsTypeFeatureGenerator...
			Note: Converting 1 features to boolean dtype as they only contain 2 unique values.
	Stage 2 Generators:
		Fitting FillNaFeatureGenerator...
	Stage 3 Generators:
		Fitting IdentityFeatureGenerator...
		Fitting IdentityFeatureGenerator...
			Fitting RenameFeatureGenerator...
		Fitting CategoryFeatureGenerator...
			Fitting CategoryMemoryMinimizeFeatureGenerator...
		Fitting TextSpecialFeatureGenerator...
			Fitting BinnedFeatureGenerator...
			Fitting DropDuplicatesFeatureGenerator...
		Fitting TextNgramFeatureGenerator...
			Fitting CountVectorizer for text features: ['Description']
			CountVectorizer fit with vocabulary size = 170
		Fitting IdentityFeatureGenerator...
		Fitting IsNanFeatureGenerator...
	Stage 4 Generators:
		Fitting DropUniqueFeatureGenerator...
	Stage 5 Generators:
		Fitting DropDuplicatesFeatureGenerator...
	Unused Original Features (Count: 1): ['PetID']
		These features were not used to generate any of the output features. Add a feature generator compatible with these features to utilize them.
		Features can also be unused if they carry very little information, such as being categorical but having almost entirely unique values or being duplicates of other features.
		These features do not need to be present at inference time.
		('object', []) : 1 | ['PetID']
	Types of features in original data (raw dtype, special dtypes):
		('float', [])              :  1 | ['PhotoAmt']
		('int', [])                : 18 | ['Type', 'Age', 'Breed1', 'Breed2', 'Gender', ...]
		('object', [])             :  2 | ['Name', 'RescuerID']
		('object', ['image_path']) :  1 | ['Images']
		('object', ['text'])       :  1 | ['Description']
	Types of features in processed data (raw dtype, special dtypes):
		('category', [])                    :   2 | ['Name', 'RescuerID']
		('category', ['text_as_category'])  :   1 | ['Description']
		('float', [])                       :   1 | ['PhotoAmt']
		('int', [])                         :  17 | ['Age', 'Breed1', 'Breed2', 'Gender', 'Color1', ...]
		('int', ['binned', 'text_special']) :  24 | ['Description.char_count', 'Description.word_count', 'Description.capital_ratio', 'Description.lower_ratio', 'Description.digit_ratio', ...]
		('int', ['bool'])                   :   1 | ['Type']
		('int', ['text_ngram'])             : 171 | ['__nlp__.about', '__nlp__.active', '__nlp__.active and', '__nlp__.adopt', '__nlp__.adopted', ...]
		('object', ['image_path'])          :   1 | ['Images']
		('object', ['text'])                :   1 | ['Description_raw_text']
	1.6s = Fit runtime
	23 features in original data used to generate 219 features in processed data.
	Train Data (Processed) Memory Usage: 0.52 MB (0.0% of available memory)
Data preprocessing and feature engineering runtime = 1.67s ...
AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'
	To change this, specify the eval_metric parameter of Predictor()
Automatically generating train/validation split with holdout_frac=0.2, Train Rows: 400, Val Rows: 100
User-specified model hyperparameters to be fit:
{
	'NN_TORCH': [{}],
	'GBM': [{}, {'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {'learning_rate': 0.03, 'num_leaves': 128, 'feature_fraction': 0.9, 'min_data_in_leaf': 3, 'ag_args': {'name_suffix': 'Large', 'priority': 0, 'hyperparameter_tune_kwargs': None}}],
	'CAT': [{}],
	'XGB': [{}],
	'AG_AUTOMM': [{}],
}
Fitting 7 L1 models, fit_strategy="sequential" ...
Fitting model: LightGBM ... Training model for up to 898.33s of the 898.33s of remaining time.
	0.35	 = Validation score   (accuracy)
	0.79s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: LightGBMXT ... Training model for up to 897.53s of the 897.53s of remaining time.
	0.34	 = Validation score   (accuracy)
	0.61s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: CatBoost ... Training model for up to 896.91s of the 896.90s of remaining time.
	0.31	 = Validation score   (accuracy)
	2.21s	 = Training   runtime
	0.01s	 = Validation runtime
Fitting model: XGBoost ... Training model for up to 894.68s of the 894.68s of remaining time.
	0.34	 = Validation score   (accuracy)
	1.17s	 = Training   runtime
	0.01s	 = Validation runtime
Fitting model: NeuralNetTorch ... Training model for up to 893.48s of the 893.48s of remaining time.
	0.34	 = Validation score   (accuracy)
	4.34s	 = Training   runtime
	0.02s	 = Validation runtime
Fitting model: LightGBMLarge ... Training model for up to 889.12s of the 889.11s of remaining time.
	0.36	 = Validation score   (accuracy)
	1.91s	 = Training   runtime
	0.0s	 = Validation runtime
Fitting model: MultiModalPredictor ... Training model for up to 887.20s of the 887.20s of remaining time.
INFO: Seed set to 0
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name              | Type                | Params | Mode 
------------------------------------------------------------------
0 | model             | MultimodalFusionMLP | 207 M  | train
1 | validation_metric | MulticlassAccuracy  | 0      | train
2 | loss_func         | CrossEntropyLoss    | 0      | train
------------------------------------------------------------------
207 M     Trainable params
0         Non-trainable params
207 M     Total params
828.189   Total estimated model params size (MB)
1168      Modules in train mode
0         Modules in eval mode
INFO: Epoch 0, global step 1: 'val_accuracy' reached 0.26000 (best 0.26000), saving model to '/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250508_220305/models/MultiModalPredictor/automm_model/epoch=0-step=1.ckpt' as top 3
INFO: Epoch 0, global step 4: 'val_accuracy' reached 0.29000 (best 0.29000), saving model to '/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250508_220305/models/MultiModalPredictor/automm_model/epoch=0-step=4.ckpt' as top 3
INFO: Epoch 1, global step 5: 'val_accuracy' reached 0.29000 (best 0.29000), saving model to '/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250508_220305/models/MultiModalPredictor/automm_model/epoch=1-step=5.ckpt' as top 3
INFO: Epoch 1, global step 8: 'val_accuracy' reached 0.31000 (best 0.31000), saving model to '/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250508_220305/models/MultiModalPredictor/automm_model/epoch=1-step=8.ckpt' as top 3
INFO: Epoch 2, global step 9: 'val_accuracy' reached 0.31000 (best 0.31000), saving model to '/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250508_220305/models/MultiModalPredictor/automm_model/epoch=2-step=9.ckpt' as top 3
INFO: Epoch 2, global step 12: 'val_accuracy' was not in top 3
INFO: Epoch 3, global step 13: 'val_accuracy' reached 0.31000 (best 0.31000), saving model to '/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250508_220305/models/MultiModalPredictor/automm_model/epoch=3-step=13.ckpt' as top 3
INFO: Epoch 3, global step 16: 'val_accuracy' was not in top 3
INFO: Epoch 4, global step 17: 'val_accuracy' was not in top 3
INFO: Epoch 4, global step 20: 'val_accuracy' was not in top 3
INFO: Epoch 5, global step 21: 'val_accuracy' was not in top 3
INFO: Epoch 5, global step 24: 'val_accuracy' was not in top 3
INFO: Epoch 6, global step 25: 'val_accuracy' was not in top 3
INFO: Epoch 6, global step 28: 'val_accuracy' was not in top 3
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
	0.31	 = Validation score   (accuracy)
	306.4s	 = Training   runtime
	3.01s	 = Validation runtime
Fitting model: WeightedEnsemble_L2 ... Training model for up to 360.00s of the 574.83s of remaining time.
	Ensemble Weights: {'LightGBMLarge': 0.615, 'XGBoost': 0.308, 'MultiModalPredictor': 0.077}
	0.38	 = Validation score   (accuracy)
	0.06s	 = Training   runtime
	0.0s	 = Validation runtime
AutoGluon training complete, total runtime = 325.26s ... Best model: WeightedEnsemble_L2 | Estimated inference throughput: 33.2 rows/s (100 batch size)
TabularPredictor saved. To load, use: predictor = TabularPredictor.load("/home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250508_220305")
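The feature summary at the start of the log shows that AutoGluon derives numeric statistics from the raw `Description` text. This is the `('int', ['binned', 'text_special'])` group, with names such as `Description.char_count`, `Description.word_count`, `Description.capital_ratio`, `Description.lower_ratio` and `Description.digit_ratio`. The minimal sketch below computes statistics with those names for one string. The helper `text_special_features` is hypothetical, and the formulas are assumptions that AutoGluon's own generator may define differently.

```python
def text_special_features(text: str, prefix: str = "Description") -> dict:
    """Per-string statistics named like AutoGluon's text_special features.

    The formulas here are illustrative assumptions, not AutoGluon's code.
    """
    char_count = len(text)
    word_count = len(text.split())
    denom = max(char_count, 1)  # avoid division by zero on empty strings
    return {
        f"{prefix}.char_count": char_count,
        f"{prefix}.word_count": word_count,
        f"{prefix}.capital_ratio": sum(c.isupper() for c in text) / denom,
        f"{prefix}.lower_ratio": sum(c.islower() for c in text) / denom,
        f"{prefix}.digit_ratio": sum(c.isdigit() for c in text) / denom,
    }


feats = text_special_features("Max is 2 years old")
for name, value in feats.items():
    print(name, value)
```

The `binned` tag in the log indicates that these raw values are then discretized into bins before the tabular models see them.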
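The log also reports 171 `('int', ['text_ngram'])` columns named `__nlp__.<term>`, including bigrams such as `__nlp__.active and`. These are bag-of-n-grams counts over `Description`. The pure-Python sketch below builds unigram and bigram counts with a vocabulary capped by document frequency. The helpers, the tokenizer and `max_features` are assumptions for illustration. AutoGluon's n-gram generator uses its own vectorizer (to our understanding built on scikit-learn's `CountVectorizer`), with its own vocabulary selection and limits.

```python
import re
from collections import Counter


def tokenize(text: str) -> list[str]:
    # Lowercase and keep runs of letters/digits (a simplifying assumption).
    return re.findall(r"[a-z0-9]+", text.lower())


def ngrams(tokens: list[str], n_max: int = 2) -> list[str]:
    # All contiguous n-grams for n = 1..n_max, joined by single spaces.
    return [" ".join(tokens[i:i + n])
            for n in range(1, n_max + 1)
            for i in range(len(tokens) - n + 1)]


def fit_vocab(docs: list[str], max_features: int, n_max: int = 2) -> list[str]:
    # Document frequency: number of documents containing each n-gram.
    df = Counter()
    for doc in docs:
        df.update(set(ngrams(tokenize(doc), n_max)))
    # Keep the most frequent n-grams, breaking ties alphabetically.
    ranked = sorted(df.items(), key=lambda kv: (-kv[1], kv[0]))
    return sorted(term for term, _ in ranked[:max_features])


def transform(docs: list[str], vocab: list[str], n_max: int = 2) -> list[dict]:
    rows = []
    for doc in docs:
        counts = Counter(ngrams(tokenize(doc), n_max))
        # One column per vocabulary term, using the __nlp__. prefix from the log.
        rows.append({f"__nlp__.{t}": counts.get(t, 0) for t in vocab})
    return rows


docs = ["Very active and playful", "Active and friendly", "Adopt me, I am friendly"]
vocab = fit_vocab(docs, max_features=3)
rows = transform(docs, vocab)
print(vocab)
print(rows[0])
```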
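The line `holdout_frac=0.2, Train Rows: 400, Val Rows: 100` means that 20% of the 500 training rows were held out to score and ensemble the models. The sketch below shows the arithmetic with a plain random split. `holdout_split` is a hypothetical helper, and AutoGluon's own splitter may differ, for example by stratifying on the label for classification.

```python
import random


def holdout_split(n_rows: int, holdout_frac: float, seed: int = 0):
    """Randomly split row indices into (train, validation) lists."""
    n_val = round(n_rows * holdout_frac)
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)  # deterministic shuffle for a given seed
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = holdout_split(500, holdout_frac=0.2)
print(len(train_idx), len(val_idx))  # 400 100
```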
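The `MultiModalPredictor` summary reports `207 M` total parameters and `828.189` MB estimated size. Taking 1 MB as $10^6$ bytes, these figures are consistent with 4 bytes (32-bit floats) per parameter. This presumably holds because automatic mixed precision keeps the weights themselves in float32 even though computation runs in 16-bit:

```latex
\frac{828.189 \times 10^{6}\ \text{bytes}}{4\ \text{bytes/parameter}} \approx 207.05 \times 10^{6}\ \text{parameters}
```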
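During `MultiModalPredictor` training, the `'val_accuracy' reached ... saving model ... as top 3` and `was not in top 3` lines come from a checkpoint callback that keeps the three best checkpoints by validation accuracy. The sketch below mimics that bookkeeping with a min-heap; `TopKTracker` is hypothetical, not PyTorch Lightning's code. It assumes a new score must be strictly better than the worst kept score to enter a full top-3. That rule is consistent with the log: once the kept checkpoints (steps 8, 9 and 13) all score 0.31, no later step is saved. The scores in the example are hypothetical. They were chosen only to reproduce the saved/not-saved pattern of the first steps in the log.

```python
import heapq


class TopKTracker:
    """Keep the k best (score, step) pairs; a higher score is better."""

    def __init__(self, k: int = 3):
        self.k = k
        self._heap: list[tuple[float, int]] = []  # min-heap: worst kept score at index 0

    def update(self, step: int, score: float) -> bool:
        """Return True if this step would be saved as a top-k checkpoint."""
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, (score, step))
            return True
        # Assumption: strict improvement over the worst kept score is required.
        if score > self._heap[0][0]:
            heapq.heapreplace(self._heap, (score, step))
            return True
        return False

    def best(self) -> float:
        return max(score for score, _ in self._heap)

    def kept_steps(self) -> list[int]:
        return sorted(step for _, step in self._heap)


# Hypothetical (step, val_accuracy) pairs.
history = [(1, 0.26), (4, 0.29), (5, 0.29), (8, 0.31),
           (9, 0.31), (12, 0.28), (13, 0.31), (16, 0.30)]
tracker = TopKTracker(k=3)
flags = [tracker.update(step, score) for step, score in history]
print(flags)
print(tracker.kept_steps(), tracker.best())
```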
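The `WeightedEnsemble_L2` weights in the log (`0.615`, `0.308`, `0.077`) equal 8/13, 4/13 and 1/13 to three decimals. That pattern is what greedy ensemble selection with replacement (Caruana et al.) produces. Models are added one at a time, repeats allowed, each time choosing the one that most improves the validation metric of the averaged prediction. A model's weight is then the fraction of picks it received. To our understanding, AutoGluon's weighted ensemble is based on this procedure. The sketch below is a simplified stand-in, not AutoGluon's implementation: it uses a fixed number of rounds, accuracy on averaged probabilities, and no early stopping. The model names and probabilities are hypothetical toy data.

```python
from collections import Counter

import numpy as np


def ensemble_selection(val_probs: dict, y_val: np.ndarray, n_rounds: int = 10) -> dict:
    """Greedy forward selection with replacement (simplified Caruana et al.).

    val_probs maps model name -> (n_samples, n_classes) validation probabilities.
    Returns model name -> weight (fraction of rounds in which it was picked).
    """
    names = list(val_probs)
    picks = []
    running_sum = np.zeros_like(next(iter(val_probs.values())), dtype=float)
    for _ in range(n_rounds):
        best_name, best_acc = None, -1.0
        for name in names:
            # Average of the members picked so far plus this candidate.
            candidate = (running_sum + val_probs[name]) / (len(picks) + 1)
            acc = float(np.mean(candidate.argmax(axis=1) == y_val))
            if acc > best_acc:  # ties go to the earlier model in `names`
                best_name, best_acc = name, acc
        picks.append(best_name)
        running_sum += val_probs[best_name]
    counts = Counter(picks)
    return {name: counts[name] / n_rounds for name in names if counts[name] > 0}


y_val = np.array([0, 1, 0, 1])
val_probs = {
    "model_a": np.array([[0.9, 0.1], [0.4, 0.6], [0.8, 0.2], [0.6, 0.4]]),  # wrong on row 3
    "model_b": np.array([[0.6, 0.4], [0.1, 0.9], [0.4, 0.6], [0.1, 0.9]]),  # wrong on row 2
}
weights = ensemble_selection(val_probs, y_val, n_rounds=4)
print(weights)  # {'model_a': 0.75, 'model_b': 0.25}
```

Each model alone scores 0.75, but their average is correct on all four rows, so the second pick is `model_b` even though it is no better on its own. With 13 picks, selection counts of 8, 4 and 1 would give exactly the weights printed in the log.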

After fitting the predictor, we can view the leaderboard to see how each model performs on the test data.

leaderboard = predictor.leaderboard(test_data)
Load pretrained checkpoint: /home/ci/autogluon/docs/tutorials/tabular/AutogluonModels/ag-20250508_220305/models/MultiModalPredictor/automm_model/model.ckpt
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
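`predictor.leaderboard(test_data)` returns a pandas DataFrame with one row per model, including columns such as `model`, `score_test` and `score_val`. The sketch below uses only the validation accuracies printed in the training log above. It rebuilds them into a small DataFrame and ranks the models. On a real predictor you would instead work with the returned `leaderboard` DataFrame directly, for example `leaderboard[['model', 'score_test', 'score_val']]`.

```python
import pandas as pd

# Validation accuracies copied from the training log above.
val_scores = {
    "LightGBM": 0.35,
    "LightGBMXT": 0.34,
    "CatBoost": 0.31,
    "XGBoost": 0.34,
    "NeuralNetTorch": 0.34,
    "LightGBMLarge": 0.36,
    "MultiModalPredictor": 0.31,
    "WeightedEnsemble_L2": 0.38,
}
lb = pd.DataFrame({"model": list(val_scores), "score_val": list(val_scores.values())})
# Rank by validation score, best first; a stable sort keeps log order on ties.
lb = lb.sort_values("score_val", ascending=False, kind="stable").reset_index(drop=True)
print(lb)
```

On validation accuracy alone, the weighted ensemble (0.38) edges out its best single member, `LightGBMLarge` (0.36). The `MultiModalPredictor` scores lower on its own (0.31) but still receives a small ensemble weight.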

That's all it takes to train on image, text, and tabular data together with AutoGluon!

For more tutorials, refer to Predicting Columns in a Table - Quick Start and Predicting Columns in a Table - In Depth.