Few Shot Learning with AutoMM


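In this tutorial, we introduce a simple yet effective method for few-shot classification. We show how the high-quality features of a foundation model can be combined with a support vector machine (SVM) to solve few-shot classification tasks. Specifically, we use a pre-trained model to extract features from each sample and then train an SVM on those features. We demonstrate the effectiveness of this "foundation model followed by SVM" approach on both a text classification dataset and an image classification dataset.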
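To make the "features, then SVM" idea concrete, here is a minimal sketch under stated assumptions. It is not AutoMM's implementation, which this tutorial does not show. It swaps in a from-scratch linear SVM (Pegasos-style stochastic subgradient training, wrapped one-vs-rest for multiple classes) so the example needs only NumPy, and it replaces the frozen foundation model with synthetic 2-D "embeddings" clustered by class. The helper names `train_linear_svm` and `OneVsRestSVM` are hypothetical.

```python
import numpy as np

def train_linear_svm(X, y, lam=0.01, epochs=100, seed=0):
    """Binary linear SVM trained with Pegasos-style stochastic subgradient steps.

    X: (n, d) feature matrix; y: array of +1/-1 labels.
    A constant 1 is appended to each row so the last weight acts as a bias
    (in this simplified sketch the bias is regularized as well).
    """
    rng = np.random.default_rng(seed)
    Xb = np.hstack([X, np.ones((X.shape[0], 1))])
    w = np.zeros(Xb.shape[1])
    t = 0
    for _ in range(epochs):
        for i in rng.permutation(len(Xb)):
            t += 1
            eta = 1.0 / (lam * t)              # Pegasos step size
            violated = y[i] * (Xb[i] @ w) < 1  # hinge loss is active for this sample
            w *= 1.0 - eta * lam               # step on the L2 regularization term
            if violated:
                w += eta * y[i] * Xb[i]        # subgradient step on the hinge term
    return w

class OneVsRestSVM:
    """Multiclass wrapper: one binary SVM per class; predict the highest score."""

    def fit(self, X, labels):
        labels = np.asarray(labels)
        self.classes_ = sorted(set(labels.tolist()))
        self.W_ = np.stack([
            train_linear_svm(X, np.where(labels == c, 1.0, -1.0)) for c in self.classes_
        ])
        return self

    def predict(self, X):
        Xb = np.hstack([X, np.ones((X.shape[0], 1))])
        scores = Xb @ self.W_.T  # shape (n_samples, n_classes)
        return [self.classes_[k] for k in scores.argmax(axis=1)]

# Stand-in for frozen foundation-model embeddings: 10 "shots" per class drawn
# around a class-specific center. Real features would come from a pre-trained encoder.
rng = np.random.default_rng(42)
centers = {"CCAT": [3.0, 0.0], "ECAT": [0.0, 3.0], "GCAT": [-3.0, 0.0], "MCAT": [0.0, -3.0]}
X_train = np.vstack([rng.normal(c, 0.5, size=(10, 2)) for c in centers.values()])
y_train = [name for name in centers for _ in range(10)]

clf = OneVsRestSVM().fit(X_train, y_train)
train_acc = np.mean(np.array(clf.predict(X_train)) == np.array(y_train))
print(f"training accuracy: {train_acc:.2f}")
```

The appeal of this design in the few-shot setting is that only the small linear classifier is fitted to the handful of labeled samples, while the encoder that produces the features stays fixed.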

Few Shot Text Classification

Prepare Text Data

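As in many of our tutorials, we prepare every dataset as a pd.DataFrame. Here we use a small MLDoc dataset for demonstration. MLDoc is a text classification dataset with 4 classes, and we downsample the training data to 10 samples per class, i.e., 10 shots. For more details about MLDoc, please see this link.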
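The archive downloaded below already contains the 10-shot split, and the tutorial does not show how that split was produced. For reference, a minimal sketch of per-class downsampling with pandas, assuming a hypothetical full training DataFrame `full_df` with `label` and `text` columns (the helper name `downsample_per_class` is also hypothetical):

```python
import pandas as pd

def downsample_per_class(df, label_col, shots, seed=0):
    """Keep `shots` randomly chosen rows for every class in `label_col`.

    GroupBy.sample needs pandas >= 1.1 and raises if a class has fewer than
    `shots` rows, because it samples without replacement by default.
    """
    return (
        df.groupby(label_col)
        .sample(n=shots, random_state=seed)
        .reset_index(drop=True)
    )

# Hypothetical full training set: 4 classes with 25 rows each.
full_df = pd.DataFrame({
    "label": ["CCAT", "ECAT", "GCAT", "MCAT"] * 25,
    "text": [f"document {i}" for i in range(100)],
})
train_10shot = downsample_per_class(full_df, "label", shots=10)
print(train_10shot["label"].value_counts())
```

Fixing `random_state` keeps the sampled shots reproducible, which matters when comparing methods on such a tiny training set.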

import pandas as pd
import os
from autogluon.core.utils.loaders import load_zip

download_dir = "./ag_automm_tutorial_fs_cls"
zip_file = "https://automl-mm-bench.s3.amazonaws.com/nlp_datasets/MLDoc-10shot-en.zip"
load_zip.unzip(zip_file, unzip_dir=download_dir)
dataset_path = os.path.join(download_dir)
train_df = pd.read_csv(f"{dataset_path}/train.csv", names=["label", "text"])
test_df = pd.read_csv(f"{dataset_path}/test.csv", names=["label", "text"])
print(train_df)
print(test_df)
Downloading ./ag_automm_tutorial_fs_cls/file.zip from https://automl-mm-bench.s3.amazonaws.com/nlp_datasets/MLDoc-10shot-en.zip...
   label                                               text
0   GCAT  b'Secretary-General Kofi Annan expressed conce...
1   CCAT  b'The health of ABB Asea Brown Boveri AG\'s Po...
2   GCAT  b'Nepali Prime Minister Lokendra Bahadur Chand...
3   CCAT  b'Integ Inc said Thursday its net loss widened...
4   GCAT  b'These are the leading stories in the Skopje ...
5   ECAT  b'Fears of a slowdown in India\'s industrial g...
6   MCAT  b'The Australian Treasury will offer a total o...
7   CCAT  b'Malaysia\'s Suria Capital Holdings Bhd and M...
8   MCAT  b'The UK gilt repo market had a quiet session ...
9   CCAT  b"Commonwealth Edison Co's (ComEd) 794 megawat...
10  GCAT  b'Police arrested 47 people on Thursday in a c...
11  GCAT  b"Army troops in the Comoros island of Anjouan...
12  ECAT  b"The House Banking Committee is considering w...
13  GCAT  b'A possible international anti-drug centre in...
14  ECAT  b'Angela Knight, economic secretary to the Bri...
15  GCAT  b'Nearly 300 people were feared dead in floods...
16  MCAT  b'The Oslo stock index fell with other Europea...
17  ECAT  b'Morgan Keegan said it won $18.540 million of...
18  CCAT  b'Britons can bank on the phone, bank on the i...
19  CCAT  b"Standard Chartered Bank and Prudential Secur...
20  CCAT  b"United Water Resources Inc said it and Lyonn...
21  ECAT  b'Tanzania on Thursday unveiled its 1997/98 bu...
22  GCAT  b'U.S. President Bill Clinton will meet Prime ...
23  CCAT  b"Pacific Century Regional Developments Ltd sa...
24  MCAT  b'The Athens bourse ended 0.65 percent lower w...
25  ECAT  b'Sri Lanka broad money supply, or M2, is seen...
26  GCAT  b'Collated results of African Nations Cup prel...
27  GCAT  b'Philippine President Fidel Ramos said on Fri...
28  MCAT  b'Shanghai copper futures ended down on heavy ...
29  CCAT  b"Goldman Sachs & Co said on Monday that David...
30  ECAT  b'Maine\'s revenues were higher than forecast ...
31  CCAT  b'Thai animal feedmillers said on Monday they ...
32  MCAT  b"Worldwide trading volume in emerging markets...
33  ECAT  b'One week ended June 25 daily avgs-millions  ...
34  ECAT  b'Algeria\'s non-energy exports reached $688 m...
35  ECAT  b'U.S. seasonally adjusted retail sales rose 1...
36  MCAT  b'The Indonesian rupiah weakened against the d...
37  MCAT  b'Brazilian stocks ended slightly higher led b...
38  MCAT  b'The price of gold hung around the psychologi...
39  MCAT  b'The won closed stronger versus the dollar on...
     label                                               text
0     CCAT  b'RJR Nabisco Holdings Corp has prevailed over...
1     ECAT  b"Britain's economy grew 0.8 percent in the fo...
2     ECAT  b'Slovenia\'s state Institute of Macroeconomic...
3     CCAT  b"Belgium's second largest bank Credit Communa...
4     GCAT  b'The IRA ordered its guerrillas to observe a ...
...    ...                                                ...
3995  CCAT  b"A consortium comprising Itochu Corp and Hanj...
3996  ECAT  b"The volume of Hong Kong's domestic exports i...
3997  ECAT  b'The Danish finance ministry said on Tuesday ...
3998  GCAT  b'A court is to investigate charges that forme...
3999  MCAT  b"German consumers of feed grains, bread rye a...

[4000 rows x 2 columns]
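The printed `text` column appears to hold Python bytes literals stored as strings (values begin with `b'` or `b"`). The tutorial passes these strings to the predictor unchanged, and whether decoding them affects accuracy is not measured here. If you prefer plain text, a minimal sketch using the standard library's `ast.literal_eval` (the helper name `decode_bytes_literal` is hypothetical):

```python
import ast

def decode_bytes_literal(text: str) -> str:
    """Turn a string such as "b'abc'" (a Python bytes repr) into "abc".

    Strings that are not complete bytes literals are returned unchanged.
    """
    if len(text) >= 3 and text[0] == "b" and text[1] in "'\"" and text[-1] == text[1]:
        try:
            value = ast.literal_eval(text)
        except (ValueError, SyntaxError):
            return text
        if isinstance(value, bytes):
            return value.decode("utf-8", errors="replace")
    return text

print(decode_bytes_literal("b'Secretary-General Kofi Annan expressed concern'"))
```

It could be applied column-wise with `train_df["text"].map(decode_bytes_literal)` (and likewise for `test_df`).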

Train a Few Shot Classifier

To perform few-shot classification, we need to use the few_shot_classification problem type.

from autogluon.multimodal import MultiModalPredictor

predictor_fs_text = MultiModalPredictor(
    problem_type="few_shot_classification",
    label="label",  # column name of the label
    eval_metric="acc",
)
predictor_fs_text.fit(train_df)
scores = predictor_fs_text.evaluate(test_df, metrics=["acc", "f1_macro"])
print(scores)
{'acc': 0.83575, 'f1_macro': 0.8344679316932194}
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_205706"
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       28.39 GB / 30.95 GB (91.7%)
Disk Space Avail:   176.95 GB / 255.99 GB (69.1%)
===================================================

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205706
    ```
INFO: Seed set to 0
/home/ci/autogluon/multimodal/src/autogluon/multimodal/models/utils.py:1148: UserWarning: provided max length: 512 is smaller than sentence-transformers/all-mpnet-base-v2's default: 514
  warnings.warn(
GPU Count: 1
GPU Count to be Used: 1
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205706")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
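The `evaluate` call above reports accuracy (`acc`) and macro-averaged F1 (`f1_macro`). Macro-F1 computes F1 separately for each class and averages the results with equal weight, so a classifier that neglects one of the four classes is penalized even when overall accuracy looks reasonable. A minimal pure-Python sketch of the standard definitions (helper names are hypothetical; the label set is the union of true and predicted labels):

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 scores."""
    classes = sorted(set(y_true) | set(y_pred))
    f1_scores = []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        f1_scores.append(f1)
    return sum(f1_scores) / len(f1_scores)

y_true = ["CCAT", "CCAT", "ECAT", "GCAT"]
y_pred = ["CCAT", "ECAT", "ECAT", "GCAT"]
print(accuracy(y_true, y_pred))  # 0.75
print(macro_f1(y_true, y_pred))  # (2/3 + 2/3 + 1) / 3, about 0.778
```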

Compare to the Default Classifier

Let's train with the default classification problem type and compare its performance with the approach above.

from autogluon.multimodal import MultiModalPredictor

predictor_default_text = MultiModalPredictor(
    label="label",
    problem_type="classification",
    eval_metric="acc",
)
predictor_default_text.fit(train_data=train_df)
scores = predictor_default_text.evaluate(test_df, metrics=["acc", "f1_macro"])
print(scores)
{'acc': 0.58325, 'f1_macro': 0.5327219799909652}
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_205800"
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       27.42 GB / 30.95 GB (88.6%)
Disk Space Avail:   175.72 GB / 255.99 GB (68.6%)
===================================================
AutoGluon infers your prediction problem is: 'multiclass' (because dtype of label-column == object).
	4 unique label values:  ['GCAT', 'CCAT', 'ECAT', 'MCAT']
	If 'multiclass' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205800
    ```
INFO: Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name              | Type                         | Params | Mode 
---------------------------------------------------------------------------
0 | model             | HFAutoModelForTextPrediction | 108 M  | train
1 | validation_metric | MulticlassAccuracy           | 0      | train
2 | loss_func         | CrossEntropyLoss             | 0      | train
---------------------------------------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
435.579   Total estimated model params size (MB)
229       Modules in train mode
0         Modules in eval mode
/home/ci/opt/venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (4) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
INFO: Epoch 0, global step 1: 'val_accuracy' reached 0.37500 (best 0.37500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205800/epoch=0-step=1.ckpt' as top 3
INFO: Epoch 1, global step 2: 'val_accuracy' reached 0.50000 (best 0.50000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205800/epoch=1-step=2.ckpt' as top 3
INFO: Epoch 2, global step 3: 'val_accuracy' reached 0.37500 (best 0.50000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205800/epoch=2-step=3.ckpt' as top 3
INFO: Epoch 3, global step 4: 'val_accuracy' reached 0.62500 (best 0.62500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205800/epoch=3-step=4.ckpt' as top 3
INFO: Epoch 4, global step 5: 'val_accuracy' reached 0.62500 (best 0.62500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205800/epoch=4-step=5.ckpt' as top 3
INFO: Epoch 5, global step 6: 'val_accuracy' reached 0.62500 (best 0.62500), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205800/epoch=5-step=6.ckpt' as top 3
INFO: Epoch 6, global step 7: 'val_accuracy' was not in top 3
INFO: Epoch 7, global step 8: 'val_accuracy' was not in top 3
INFO: Epoch 8, global step 9: 'val_accuracy' was not in top 3
Start to fuse 3 checkpoints via the greedy soup algorithm.
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_205800")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
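Placing the two reported results side by side makes the gap explicit. The numbers below are copied from the two runs above; expect them to vary with hardware, library versions, and seeds. On this 40-sample training set, the few-shot approach beats the fine-tuned default classifier by roughly 25 accuracy points and 30 macro-F1 points.

```python
# Scores printed by the two evaluate() calls above.
few_shot_scores = {"acc": 0.83575, "f1_macro": 0.8344679316932194}
default_scores = {"acc": 0.58325, "f1_macro": 0.5327219799909652}

# Absolute improvement of the few-shot predictor over the default one.
gap = {metric: few_shot_scores[metric] - default_scores[metric] for metric in few_shot_scores}
for metric, delta in gap.items():
    print(f"{metric}: +{delta * 100:.2f} points")
# acc: +25.25 points
# f1_macro: +30.17 points
```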

Few Shot Image Classification

We also provide an example of using MultiModalPredictor for a few-shot image classification task.

Load Dataset

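We use the Stanford Cars dataset for demonstration and downsample the training set to 8 samples per class. Stanford Cars is an image classification dataset with 196 classes. For more information about the dataset, please see here.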
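With 196 classes and 8 images per class, the downsampled training set should contain 196 × 8 = 1568 images, compared with 40 rows for the 4-class, 10-shot MLDoc split. A small sanity check you could run once the training CSV is loaded; the helper name `is_k_shot` is hypothetical, and since the label column name is not visible in this section, the toy labels below stand in for it:

```python
from collections import Counter

def is_k_shot(labels, k):
    """True if every class in `labels` occurs exactly k times."""
    return all(n == k for n in Counter(labels).values())

# Expected size of the 8-shot Stanford Cars training split.
num_classes, shots = 196, 8
expected_rows = num_classes * shots
print(expected_rows)  # 1568

# Toy check; on the real data you would pass the label column of train_df.
print(is_k_shot(["sedan_a", "sedan_a", "coupe_b", "coupe_b"], k=2))  # True
```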

import os
from autogluon.core.utils.loaders import load_zip, load_s3

download_dir = "./ag_automm_tutorial_fs_cls/stanfordcars/"
zip_file = "https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/stanfordcars.zip"
train_csv = "https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/train_8shot.csv"
test_csv = "https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/test.csv"

load_zip.unzip(zip_file, unzip_dir=download_dir)
dataset_path = os.path.join(download_dir)
Downloading ./ag_automm_tutorial_fs_cls/stanfordcars//file.zip from https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/stanfordcars.zip...

Unzipping ./ag_automm_tutorial_fs_cls/stanfordcars//file.zip to ./ag_automm_tutorial_fs_cls/stanfordcars/
!wget https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/train_8shot.csv -O ./ag_automm_tutorial_fs_cls/stanfordcars/train.csv
!wget https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/test.csv -O ./ag_automm_tutorial_fs_cls/stanfordcars/test.csv
--2025-05-08 21:01:19--  https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/train_8shot.csv
Resolving automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)... 52.217.112.217, 16.15.193.45, 52.217.129.169, ...
Connecting to automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)|52.217.112.217|:443... connected.
HTTP request sent, awaiting response...
200 OK
Length: 94879 (93K) [text/csv]
Saving to: ‘./ag_automm_tutorial_fs_cls/stanfordcars/train.csv’

./ag_automm_tutoria 100%[===================>]  92.66K  --.-KB/s    in 0.002s  

2025-05-08 21:01:19 (57.0 MB/s) - ‘./ag_automm_tutorial_fs_cls/stanfordcars/train.csv’ saved [94879/94879]
--2025-05-08 21:01:19--  https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars/test.csv
Resolving automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)... 3.5.30.52, 54.231.140.105, 16.15.177.34, ...
Connecting to automl-mm-bench.s3.amazonaws.com (automl-mm-bench.s3.amazonaws.com)|3.5.30.52|:443... connected.
HTTP request sent, awaiting response...
200 OK
Length: 34472 (34K) [text/csv]
Saving to: ‘./ag_automm_tutorial_fs_cls/stanfordcars/test.csv’

./ag_automm_tutoria 100%[===================>]  33.66K  --.-KB/s    in 0.001s  

2025-05-08 21:01:19 (50.7 MB/s) - ‘./ag_automm_tutorial_fs_cls/stanfordcars/test.csv’ saved [34472/34472]
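The two `!wget` lines are IPython shell escapes, so they only run inside a notebook. A plain Python script needs another way to fetch the files. Below is a minimal standard-library sketch that assumes nothing beyond `urllib.request`. The helper name `download` is hypothetical and not part of AutoGluon. The calls that mirror the two commands above are left as comments because they need network access.

```python
import os
import urllib.request


def download(url, dest):
    """Fetch url into dest, creating the parent directory first (like `wget -O dest url`)."""
    parent = os.path.dirname(dest)
    if parent:
        os.makedirs(parent, exist_ok=True)
    urllib.request.urlretrieve(url, dest)
    return dest


# Usage (requires network access), mirroring the two !wget commands:
# base = "https://automl-mm-bench.s3.amazonaws.com/vision_datasets/stanfordcars"
# download(f"{base}/train_8shot.csv", "./ag_automm_tutorial_fs_cls/stanfordcars/train.csv")
# download(f"{base}/test.csv", "./ag_automm_tutorial_fs_cls/stanfordcars/test.csv")
```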
import os

import pandas as pd

# Annotation metadata that is not needed for classification; only ImageID and LabelName are kept.
unused_columns = [
    "Source",
    "Confidence",
    "XMin",
    "XMax",
    "YMin",
    "YMax",
    "IsOccluded",
    "IsTruncated",
    "IsGroupOf",
    "IsDepiction",
    "IsInside",
]

# Load each split, drop the unused columns, and prefix the image paths with download_dir.
train_df_raw = pd.read_csv(os.path.join(download_dir, "train.csv"))
train_df = train_df_raw.drop(columns=unused_columns)
train_df["ImageID"] = download_dir + train_df["ImageID"].astype(str)

test_df_raw = pd.read_csv(os.path.join(download_dir, "test.csv"))
test_df = test_df_raw.drop(columns=unused_columns)
test_df["ImageID"] = download_dir + test_df["ImageID"].astype(str)

# Sanity check: the first image of each split should exist on disk.
print(os.path.exists(train_df.iloc[0]["ImageID"]))
print(train_df)
print(os.path.exists(test_df.iloc[0]["ImageID"]))
print(test_df)
True
                                                ImageID  LabelName
0     ./ag_automm_tutorial_fs_cls/stanfordcars/train...        147
1     ./ag_automm_tutorial_fs_cls/stanfordcars/train...        120
2     ./ag_automm_tutorial_fs_cls/stanfordcars/train...        147
3     ./ag_automm_tutorial_fs_cls/stanfordcars/train...        167
4     ./ag_automm_tutorial_fs_cls/stanfordcars/train...         73
...                                                 ...        ...
1563  ./ag_automm_tutorial_fs_cls/stanfordcars/train...        116
1564  ./ag_automm_tutorial_fs_cls/stanfordcars/train...         76
1565  ./ag_automm_tutorial_fs_cls/stanfordcars/train...        148
1566  ./ag_automm_tutorial_fs_cls/stanfordcars/train...        189
1567  ./ag_automm_tutorial_fs_cls/stanfordcars/train...        183

[1568 rows x 2 columns]
True
                                               ImageID  LabelName
0    ./ag_automm_tutorial_fs_cls/stanfordcars/test/...          0
1    ./ag_automm_tutorial_fs_cls/stanfordcars/test/...          0
2    ./ag_automm_tutorial_fs_cls/stanfordcars/test/...          0
3    ./ag_automm_tutorial_fs_cls/stanfordcars/test/...          1
4    ./ag_automm_tutorial_fs_cls/stanfordcars/test/...          1
..                                                 ...        ...
583  ./ag_automm_tutorial_fs_cls/stanfordcars/test/...        194
584  ./ag_automm_tutorial_fs_cls/stanfordcars/test/...        194
585  ./ag_automm_tutorial_fs_cls/stanfordcars/test/...        195
586  ./ag_automm_tutorial_fs_cls/stanfordcars/test/...        195
587  ./ag_automm_tutorial_fs_cls/stanfordcars/test/...        195

[588 rows x 2 columns]
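The `os.path.exists` checks above only look at the first row of each split. Checking every row before fitting is cheap and catches a wrong `download_dir` prefix or a partially extracted archive early. Here is a minimal pandas sketch. The helper `missing_image_paths` is hypothetical and not an AutoGluon API.

```python
import os

import pandas as pd


def missing_image_paths(df: pd.DataFrame, column: str = "ImageID") -> list:
    """Return the paths in df[column] that do not exist on disk."""
    exists = df[column].astype(str).map(os.path.exists)
    return df.loc[~exists, column].tolist()


# Usage with the DataFrames built above:
# assert not missing_image_paths(train_df)
# assert not missing_image_paths(test_df)
```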

Train a few-shot classifier

Similarly, we initialize MultiModalPredictor with the problem type few_shot_classification.

from autogluon.multimodal import MultiModalPredictor

predictor_fs_image = MultiModalPredictor(
    problem_type="few_shot_classification",
    label="LabelName",  # column name of the label
    eval_metric="acc",
)
predictor_fs_image.fit(train_df)
scores = predictor_fs_image.evaluate(test_df, metrics=["acc", "f1_macro"])
print(scores)
{'acc': 0.8010204081632653, 'f1_macro': 0.7958697764820214}
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_210119"
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       24.24 GB / 30.95 GB (78.3%)
Disk Space Avail:   171.34 GB / 255.99 GB (66.9%)
===================================================

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210119
    ```
INFO: Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210119")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
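As described at the start of this tutorial, the few-shot approach extracts features with a pretrained foundation model and then trains an SVM on those features. The sketch below illustrates only the SVM stage, under explicit assumptions. The "embeddings" are synthetic Gaussian clusters standing in for real foundation-model features (4 classes with 10 shots each, mirroring the MLDoc setup). The classifier is a minimal one-vs-rest linear SVM written in NumPy, trained by subgradient descent on the L2-regularized hinge loss. The helpers `train_linear_svm_ovr` and `predict_svm` are hypothetical, and AutoMM's actual feature extractor, SVM solver, and hyperparameters may differ.

```python
import numpy as np


def train_linear_svm_ovr(X, y, n_classes, C=1.0, lr=0.1, epochs=300):
    """Hypothetical helper: one-vs-rest linear SVM trained with full-batch
    subgradient descent on 0.5*||w||^2 + C * mean(max(0, 1 - t * (w.x + b)))."""
    n, d = X.shape
    W = np.zeros((n_classes, d))
    b = np.zeros(n_classes)
    for k in range(n_classes):
        t = np.where(y == k, 1.0, -1.0)  # +1 for class k, -1 for every other class
        w, bias = np.zeros(d), 0.0
        for _ in range(epochs):
            margins = t * (X @ w + bias)
            active = margins < 1.0  # samples inside the margin or misclassified
            grad_w = w - C * (t[active][:, None] * X[active]).sum(axis=0) / n
            grad_b = -C * t[active].sum() / n
            w -= lr * grad_w
            bias -= lr * grad_b
        W[k], b[k] = w, bias
    return W, b


def predict_svm(X, W, b):
    """Pick the class whose one-vs-rest decision score is highest."""
    return np.argmax(X @ W.T + b, axis=1)


# Synthetic stand-ins for foundation-model embeddings (an assumption, not real features):
# 4 well-separated class centers in a 16-dimensional space plus Gaussian noise.
rng = np.random.default_rng(0)
n_classes, dim, shots = 4, 16, 10
centers = 5.0 * np.eye(n_classes, dim)


def sample(n_per_class):
    X = np.concatenate([c + 0.5 * rng.standard_normal((n_per_class, dim)) for c in centers])
    y = np.repeat(np.arange(n_classes), n_per_class)
    return X, y


X_train, y_train = sample(shots)  # 10 labeled examples per class
X_test, y_test = sample(50)
W, b = train_linear_svm_ovr(X_train, y_train, n_classes)
accuracy = float((predict_svm(X_test, W, b) == y_test).mean())
print(f"test accuracy: {accuracy:.3f}")
```

With real data, X_train and X_test would hold embeddings produced by a pretrained model instead of synthetic clusters. The point of the recipe is that only the small linear classifier is fit on the few labeled examples, while the feature extractor stays frozen.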

Compare to the default classifier

We can also train a default image classifier and compare it with the few-shot classifier.

from autogluon.multimodal import MultiModalPredictor

predictor_default_image = MultiModalPredictor(
    problem_type="classification",
    label="LabelName",  # column name of the label
    eval_metric="acc",
)
predictor_default_image.fit(train_data=train_df)
scores = predictor_default_image.evaluate(test_df, metrics=["acc", "f1_macro"])
print(scores)
{'acc': 0.5476190476190477, 'f1_macro': 0.5309022893206566}
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_210252"
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       24.22 GB / 30.95 GB (78.3%)
Disk Space Avail:   167.32 GB / 255.99 GB (65.4%)
===================================================
AutoGluon infers your prediction problem is: 'regression' (because dtype of label-column == int and many unique label-values observed).
	Label info (max, min, mean, stddev): (195, 0, 97.5, 56.59764)
	If 'regression' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252
    ```
INFO: Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name              | Type                            | Params | Mode 
------------------------------------------------------------------------------
0 | model             | TimmAutoModelForImagePrediction | 96.3 M | train
1 | validation_metric | MulticlassAccuracy              | 0      | train
2 | loss_func         | CrossEntropyLoss                | 0      | train
------------------------------------------------------------------------------
96.3 M    Trainable params
0         Non-trainable params
96.3 M    Total params
385.132   Total estimated model params size (MB)
863       Modules in train mode
0         Modules in eval mode
INFO: Epoch 0, global step 4: 'val_accuracy' reached 0.00000 (best 0.00000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=0-step=4.ckpt' as top 3
INFO: Epoch 0, global step 9: 'val_accuracy' reached 0.00318 (best 0.00318), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=0-step=9.ckpt' as top 3
INFO: Epoch 1, global step 14: 'val_accuracy' reached 0.01592 (best 0.01592), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=1-step=14.ckpt' as top 3
INFO: Epoch 1, global step 19: 'val_accuracy' reached 0.05096 (best 0.05096), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=1-step=19.ckpt' as top 3
INFO: Epoch 2, global step 24: 'val_accuracy' reached 0.10510 (best 0.10510), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=2-step=24.ckpt' as top 3
INFO: Epoch 2, global step 29: 'val_accuracy' reached 0.12739 (best 0.12739), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=2-step=29.ckpt' as top 3
INFO: Epoch 3, global step 34: 'val_accuracy' reached 0.16561 (best 0.16561), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=3-step=34.ckpt' as top 3
INFO: Epoch 3, global step 39: 'val_accuracy' reached 0.21019 (best 0.21019), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=3-step=39.ckpt' as top 3
INFO: Epoch 4, global step 44: 'val_accuracy' reached 0.26752 (best 0.26752), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=4-step=44.ckpt' as top 3
INFO: Epoch 4, global step 49: 'val_accuracy' reached 0.29936 (best 0.29936), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=4-step=49.ckpt' as top 3
INFO: Epoch 5, global step 54: 'val_accuracy' reached 0.32803 (best 0.32803), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=5-step=54.ckpt' as top 3
INFO: Epoch 5, global step 59: 'val_accuracy' reached 0.34713 (best 0.34713), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=5-step=59.ckpt' as top 3
INFO: Epoch 6, global step 64: 'val_accuracy' reached 0.41720 (best 0.41720), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=6-step=64.ckpt' as top 3
INFO: Epoch 6, global step 69: 'val_accuracy' reached 0.41083 (best 0.41720), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=6-step=69.ckpt' as top 3
INFO: Epoch 7, global step 74: 'val_accuracy' reached 0.44586 (best 0.44586), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=7-step=74.ckpt' as top 3
INFO: Epoch 7, global step 79: 'val_accuracy' reached 0.45860 (best 0.45860), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=7-step=79.ckpt' as top 3
INFO: Epoch 8, global step 84: 'val_accuracy' reached 0.46178 (best 0.46178), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=8-step=84.ckpt' as top 3
INFO: Epoch 8, global step 89: 'val_accuracy' reached 0.47452 (best 0.47452), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=8-step=89.ckpt' as top 3
INFO: Epoch 9, global step 94: 'val_accuracy' reached 0.49045 (best 0.49045), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=9-step=94.ckpt' as top 3
INFO: Epoch 9, global step 99: 'val_accuracy' reached 0.50000 (best 0.50000), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=9-step=99.ckpt' as top 3
INFO: Epoch 10, global step 104: 'val_accuracy' reached 0.52866 (best 0.52866), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=10-step=104.ckpt' as top 3
INFO: Epoch 10, global step 109: 'val_accuracy' reached 0.51274 (best 0.52866), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=10-step=109.ckpt' as top 3
INFO: Epoch 11, global step 114: 'val_accuracy' reached 0.53503 (best 0.53503), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=11-step=114.ckpt' as top 3
INFO: Epoch 11, global step 119: 'val_accuracy' reached 0.53822 (best 0.53822), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=11-step=119.ckpt' as top 3
INFO: Epoch 12, global step 124: 'val_accuracy' reached 0.54140 (best 0.54140), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=12-step=124.ckpt' as top 3
INFO: Epoch 12, global step 129: 'val_accuracy' was not in top 3
INFO: Epoch 13, global step 134: 'val_accuracy' was not in top 3
INFO: Epoch 13, global step 139: 'val_accuracy' reached 0.54777 (best 0.54777), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=13-step=139.ckpt' as top 3
INFO: Epoch 14, global step 144: 'val_accuracy' was not in top 3
INFO: Epoch 14, global step 149: 'val_accuracy' was not in top 3
INFO: Epoch 15, global step 154: 'val_accuracy' was not in top 3
INFO: Epoch 15, global step 159: 'val_accuracy' was not in top 3
INFO: Epoch 16, global step 164: 'val_accuracy' reached 0.54777 (best 0.54777), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=16-step=164.ckpt' as top 3
INFO: Epoch 16, global step 169: 'val_accuracy' reached 0.54777 (best 0.54777), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252/epoch=16-step=169.ckpt' as top 3
INFO: Epoch 17, global step 174: 'val_accuracy' was not in top 3
INFO: Epoch 17, global step 179: 'val_accuracy' was not in top 3
INFO: Epoch 18, global step 184: 'val_accuracy' was not in top 3
INFO: Epoch 18, global step 189: 'val_accuracy' was not in top 3
Start to fuse 3 checkpoints via the greedy soup algorithm.
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/advanced_topics/AutogluonModels/ag-20250508_210252")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.

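As you can see, few_shot_classification also performs much better than the default classification on image classification: on this test set it reaches an accuracy of 0.801 versus 0.548, and a macro F1 of 0.796 versus 0.531.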
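To put the comparison into concrete numbers, the sketch below recomputes the absolute and relative gaps from the two score dictionaries printed above. The only other input is the test-set size of 588 rows from the test_df printout. Because accuracy is the fraction of correctly classified test images, it also converts back to whole counts.

```python
# Scores printed by the two evaluate() calls above.
few_shot_scores = {"acc": 0.8010204081632653, "f1_macro": 0.7958697764820214}
default_scores = {"acc": 0.5476190476190477, "f1_macro": 0.5309022893206566}
n_test = 588  # rows in test_df

for metric in ("acc", "f1_macro"):
    gap = few_shot_scores[metric] - default_scores[metric]
    ratio = few_shot_scores[metric] / default_scores[metric]
    print(f"{metric}: +{gap:.3f} absolute, x{ratio:.2f} relative")

# Accuracy times the number of test images gives the number of correct predictions.
correct_few_shot = round(few_shot_scores["acc"] * n_test)
correct_default = round(default_scores["acc"] * n_test)
print(f"correct predictions: {correct_few_shot} vs {correct_default} out of {n_test}")
```

Out of 588 test images, this works out to 471 correct predictions for the few-shot classifier and 322 for the default one.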

Customization

To learn how to customize AutoMM, please refer to Customize AutoMM.