使用 AutoMM 进行图像-文本语义匹配 - 零样本

Open In Colab Open In SageMaker Studio Lab

图像-文本语义匹配任务是指衡量图像与句子之间的视觉语义相似度。AutoMM 通过利用强大的 CLIP 支持零样本图像-文本匹配。得益于对比损失目标并在数百万图像-文本对上进行训练,CLIP 为视觉和语言以及它们之间的连接学习了很好的嵌入。因此,我们可以使用它来提取嵌入以进行检索和匹配。

CLIP 采用双塔结构,这意味着它有两个编码器:一个用于图像,另一个用于文本。CLIP 模型概览见下图。左侧显示其预训练阶段,右侧显示其零样本预测阶段。通过计算一个图像嵌入与所有文本图像之间的余弦相似度分数,我们选择相似度最高的文本作为预测结果。

鉴于这两个编码器,我们可以提取图像嵌入或文本嵌入。最重要的是,嵌入提取可以离线完成,只有相似度计算需要在线完成。这意味着良好的可伸缩性。CLIP

在本教程中,我们将展示 AutoMM 易于使用的 API 如何将强大的 CLIP 带给您。

准备演示数据

首先,让我们获取一些文本并下载一些图像。这些图像来自 COCO 数据集

from autogluon.multimodal import download

texts = [
    "A cheetah chases prey on across a field.",
    "A man is eating a piece of bread.",
    "The girl is carrying a baby.",
    "There is an airplane over a car.",
    "A man is riding a horse.",
    "Two men pushed carts through the woods.",
    "There is a carriage in the image.",
    "A man is riding a white horse on an enclosed ground.",
    "A monkey is playing drums.",
]

urls = ['http://farm4.staticflickr.com/3179/2872917634_f41e6987a8_z.jpg',
        'http://farm4.staticflickr.com/3629/3608371042_75f9618851_z.jpg',
        'https://farm4.staticflickr.com/3795/9591251800_9c9727e178_z.jpg',
        'http://farm8.staticflickr.com/7188/6848765123_252bfca33d_z.jpg',
        'https://farm6.staticflickr.com/5251/5548123650_1a69ce1e34_z.jpg']

image_paths = [download(url) for url in urls]
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[2], line 1
----> 1 from autogluon.multimodal import download
      3 texts = [
      4     "A cheetah chases prey on across a field.",
      5     "A man is eating a piece of bread.",
   (...)
     12     "A monkey is playing drums.",
     13 ]
     15 urls = ['http://farm4.staticflickr.com/3179/2872917634_f41e6987a8_z.jpg',
     16         'http://farm4.staticflickr.com/3629/3608371042_75f9618851_z.jpg',
     17         'https://farm4.staticflickr.com/3795/9591251800_9c9727e178_z.jpg',
     18         'http://farm8.staticflickr.com/7188/6848765123_252bfca33d_z.jpg',
     19         'https://farm6.staticflickr.com/5251/5548123650_1a69ce1e34_z.jpg']

ImportError: cannot import name 'download' from 'autogluon.multimodal' (/home/ci/autogluon/multimodal/src/autogluon/multimodal/__init__.py)

提取嵌入

在初始化预测器时,我们需要使用 image_text_similarity 作为问题类型。

from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor(problem_type="image_text_similarity")

让我们分别提取图像嵌入和文本嵌入。图像和文本数据将分别通过其对应的编码器。

image_embeddings = predictor.extract_embedding(image_paths, as_tensor=True)
print(image_embeddings.shape)
text_embeddings = predictor.extract_embedding(texts, as_tensor=True)
print(text_embeddings.shape)

然后,您可以使用这些嵌入来完成一系列任务,例如图像检索和文本检索。

使用文本查询进行图像检索

假设我们有一个大型图像数据库(例如,视频片段),现在我们想检索由文本查询定义的图像。如何做到这一点?

很简单。首先,如上所示离线提取所有图像嵌入。然后,提取文本查询的嵌入。最后,计算文本嵌入与所有图像嵌入之间的余弦相似度,并返回最相关的候选结果。

假设我们使用下面的文本作为查询。

print(texts[6])

您可以直接调用我们的实用函数 semantic_search 来搜索语义相似的图像。

from autogluon.multimodal.utils import semantic_search
hits = semantic_search(
        matcher=predictor,
        query_embeddings=text_embeddings[6][None,],
        response_embeddings=image_embeddings,
        top_k=5,
    )
print(hits)

我们可以看到成功找到了包含马车的图像。

from IPython.display import Image, display
pil_img = Image(filename=image_paths[hits[0][0]["response_id"]])
display(pil_img)

使用图像查询进行文本检索

类似地,给定一个文本数据库和一个图像查询,我们可以搜索与图像匹配的文本。例如,让我们搜索以下图像的文本。

pil_img = Image(filename=image_paths[4])
display(pil_img)

我们仍然使用 semantic_search 函数,但切换 query_embeddingsresponse_embeddings 的赋值。

hits = semantic_search(
        matcher=predictor,
        query_embeddings=image_embeddings[4][None,],
        response_embeddings=text_embeddings,
        top_k=5,
    )
print(hits)

我们可以观察到 top-1 文本与查询图像匹配。

texts[hits[0][0]["response_id"]]

预测图像-文本对是否匹配

除了检索,我们可以让预测器判断图像-文本对是否匹配。为此,我们需要使用附加参数 queryresponse 初始化预测器,这些参数表示图像/文本和文本/图像的名称。

predictor = MultiModalPredictor(
            query="abc",
            response="xyz",
            problem_type="image_text_similarity",
        )

给定图像-文本对,我们可以进行预测。

pred = predictor.predict({"abc": [image_paths[4]], "xyz": [texts[3]]})
print(pred)

预测匹配概率

预测匹配概率也很容易。您可以通过对概率应用自定义阈值来进行预测。

proba = predictor.predict_proba({"abc": [image_paths[4]], "xyz": [texts[3]]})
print(proba)

其他示例

您可以前往 AutoMM 示例 探索更多关于 AutoMM 的示例。

自定义

要了解如何自定义 AutoMM,请参考 自定义 AutoMM