使用 AutoMM 进行图像-文本语义匹配 - 零样本¶
图像-文本语义匹配任务是指衡量图像与句子之间的视觉语义相似度。AutoMM 通过利用强大的 CLIP 支持零样本图像-文本匹配。得益于对比损失目标并在数百万图像-文本对上进行训练,CLIP 为视觉和语言以及它们之间的连接学习了很好的嵌入。因此,我们可以使用它来提取嵌入以进行检索和匹配。
CLIP 采用双塔结构,这意味着它有两个编码器:一个用于图像,另一个用于文本。CLIP 模型概览见下图。左侧显示其预训练阶段,右侧显示其零样本预测阶段。通过计算一个图像嵌入与所有文本图像之间的余弦相似度分数,我们选择相似度最高的文本作为预测结果。
鉴于这两个编码器,我们可以提取图像嵌入或文本嵌入。最重要的是,嵌入提取可以离线完成,只有相似度计算需要在线完成。这意味着良好的可伸缩性。
在本教程中,我们将展示 AutoMM 易于使用的 API 如何将强大的 CLIP 带给您。
准备演示数据¶
首先,让我们获取一些文本并下载一些图像。这些图像来自 COCO 数据集。
from autogluon.multimodal import download
texts = [
"A cheetah chases prey on across a field.",
"A man is eating a piece of bread.",
"The girl is carrying a baby.",
"There is an airplane over a car.",
"A man is riding a horse.",
"Two men pushed carts through the woods.",
"There is a carriage in the image.",
"A man is riding a white horse on an enclosed ground.",
"A monkey is playing drums.",
]
urls = ['http://farm4.staticflickr.com/3179/2872917634_f41e6987a8_z.jpg',
'http://farm4.staticflickr.com/3629/3608371042_75f9618851_z.jpg',
'https://farm4.staticflickr.com/3795/9591251800_9c9727e178_z.jpg',
'http://farm8.staticflickr.com/7188/6848765123_252bfca33d_z.jpg',
'https://farm6.staticflickr.com/5251/5548123650_1a69ce1e34_z.jpg']
image_paths = [download(url) for url in urls]
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
Cell In[2], line 1
----> 1 from autogluon.multimodal import download
3 texts = [
4 "A cheetah chases prey on across a field.",
5 "A man is eating a piece of bread.",
(...)
12 "A monkey is playing drums.",
13 ]
15 urls = ['http://farm4.staticflickr.com/3179/2872917634_f41e6987a8_z.jpg',
16 'http://farm4.staticflickr.com/3629/3608371042_75f9618851_z.jpg',
17 'https://farm4.staticflickr.com/3795/9591251800_9c9727e178_z.jpg',
18 'http://farm8.staticflickr.com/7188/6848765123_252bfca33d_z.jpg',
19 'https://farm6.staticflickr.com/5251/5548123650_1a69ce1e34_z.jpg']
ImportError: cannot import name 'download' from 'autogluon.multimodal' (/home/ci/autogluon/multimodal/src/autogluon/multimodal/__init__.py)
提取嵌入¶
在初始化预测器时,我们需要使用 image_text_similarity
作为问题类型。
from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor(problem_type="image_text_similarity")
让我们分别提取图像嵌入和文本嵌入。图像和文本数据将分别通过其对应的编码器。
image_embeddings = predictor.extract_embedding(image_paths, as_tensor=True)
print(image_embeddings.shape)
text_embeddings = predictor.extract_embedding(texts, as_tensor=True)
print(text_embeddings.shape)
然后,您可以使用这些嵌入来完成一系列任务,例如图像检索和文本检索。
使用文本查询进行图像检索¶
假设我们有一个大型图像数据库(例如,视频片段),现在我们想检索由文本查询定义的图像。如何做到这一点?
很简单。首先,如上所示离线提取所有图像嵌入。然后,提取文本查询的嵌入。最后,计算文本嵌入与所有图像嵌入之间的余弦相似度,并返回最相关的候选结果。
假设我们使用下面的文本作为查询。
print(texts[6])
您可以直接调用我们的实用函数 semantic_search
来搜索语义相似的图像。
from autogluon.multimodal.utils import semantic_search
hits = semantic_search(
matcher=predictor,
query_embeddings=text_embeddings[6][None,],
response_embeddings=image_embeddings,
top_k=5,
)
print(hits)
我们可以看到成功找到了包含马车的图像。
from IPython.display import Image, display
pil_img = Image(filename=image_paths[hits[0][0]["response_id"]])
display(pil_img)
使用图像查询进行文本检索¶
类似地,给定一个文本数据库和一个图像查询,我们可以搜索与图像匹配的文本。例如,让我们搜索以下图像的文本。
pil_img = Image(filename=image_paths[4])
display(pil_img)
我们仍然使用 semantic_search
函数,但切换 query_embeddings
和 response_embeddings
的赋值。
hits = semantic_search(
matcher=predictor,
query_embeddings=image_embeddings[4][None,],
response_embeddings=text_embeddings,
top_k=5,
)
print(hits)
我们可以观察到 top-1 文本与查询图像匹配。
texts[hits[0][0]["response_id"]]
预测图像-文本对是否匹配¶
除了检索,我们可以让预测器判断图像-文本对是否匹配。为此,我们需要使用附加参数 query
和 response
初始化预测器,这些参数表示图像/文本和文本/图像的名称。
predictor = MultiModalPredictor(
query="abc",
response="xyz",
problem_type="image_text_similarity",
)
给定图像-文本对,我们可以进行预测。
pred = predictor.predict({"abc": [image_paths[4]], "xyz": [texts[3]]})
print(pred)
预测匹配概率¶
预测匹配概率也很容易。您可以通过对概率应用自定义阈值来进行预测。
proba = predictor.predict_proba({"abc": [image_paths[4]], "xyz": [texts[3]]})
print(proba)
其他示例¶
您可以前往 AutoMM 示例 探索更多关于 AutoMM 的示例。
自定义¶
要了解如何自定义 AutoMM,请参考 自定义 AutoMM。