使用 AutoMM 进行文本到文本语义匹配¶
计算两个句子/段落之间的相似度是自然语言处理(NLP)中的一项常见任务,具有多种实际应用,例如网络搜索、问答、文档去重、抄袭对比、自然语言推理、推荐引擎等。通常,文本相似度模型将两个句子/段落作为输入,并将其转换为向量,然后使用余弦相似度、点积或欧几里得距离计算相似度分数来衡量两个文本片段的相似或不同程度。
准备数据¶
在本教程中,我们将演示如何使用 AutoMM 通过斯坦福自然语言推理(SNLI)语料库进行文本到文本语义匹配。SNLI 语料库包含约 57 万个人工编写的句子对,标记有 蕴含(entailment)、 矛盾(contradiction)和 中立(neutral)。它是评估机器学习方法表示和推理能力的广泛使用的基准。下表包含来自该语料库的三个示例。
前提 |
假设 |
标签 |
---|---|---|
一辆黑色赛车在一群人面前启动。 |
一个人正驾车行驶在一条孤独的路上。 |
矛盾 |
一老一少两个男人在微笑。 |
两个男人微笑着,并对着在地板上玩耍的猫咪大笑。 |
中立 |
一场有多个男性参与的足球比赛。 |
一些男人正在进行一项运动。 |
蕴含 |
在这里,我们将标记为 蕴含(entailment)的句子对视为正例(标记为 1),将标记为 矛盾(contradiction)的句子对视为负例(标记为 0)。具有中立关系的句子对被丢弃。以下代码下载语料库并将其加载到数据框中。
from autogluon.core.utils.loaders import load_pd
import pandas as pd
snli_train = load_pd.load('https://automl-mm-bench.s3.amazonaws.com/snli/snli_train.csv', delimiter="|")
snli_test = load_pd.load('https://automl-mm-bench.s3.amazonaws.com/snli/snli_test.csv', delimiter="|")
snli_train.head()
前提 | 假设 | 标签 | |
---|---|---|---|
0 | 一个骑马的人跳过了一个坍塌的... | 一个人在一家餐馆点了一份煎蛋卷。 | 0 |
1 | 一个骑马的人跳过了一个坍塌的... | 一个人在户外,骑在一匹马上。 | 1 |
2 | 孩子们对着镜头微笑挥手 | 有孩子在场 | 1 |
3 | 孩子们对着镜头微笑挥手 | 孩子们在皱眉 | 0 |
4 | 一个男孩正在中间跳滑板... | 男孩踩着滑板沿着人行道滑下去。 | 0 |
训练模型¶
理想情况下,我们希望获得一个能够为正/负文本对返回高/低分数的模型。传统的文本相似度方法只在词法层面工作,不考虑语义方面,例如使用词频或 tf-idf 向量。使用 AutoMM,我们可以轻松训练一个捕捉句子间语义关系的模型。基本上,它使用 BERT 将每个句子投影到高维向量中,并遵循 sentence transformers 中的设计将匹配问题视为分类问题。使用 AutoMM,您只需指定查询列名、响应列名和标签列名,然后使用训练数据集拟合模型,无需担心实现细节。请注意,标签应该是二元的,并且我们需要指定 match_label
,这意味着两个句子具有相同的语义。在实践中,您的任务可能有不同的标签,例如重复或不重复。您可能需要根据您的具体任务上下文来定义 match_label
。
from autogluon.multimodal import MultiModalPredictor
# Initialize the model
predictor = MultiModalPredictor(
problem_type="text_similarity",
query="premise", # the column name of the first sentence
response="hypothesis", # the column name of the second sentence
label="label", # the label column name
match_label=1, # the label indicating that query and response have the same semantic meanings.
eval_metric='auc', # the evaluation metric
)
# Fit the model
predictor.fit(
train_data=snli_train,
time_limit=180,
)
No path specified. Models will be saved in: "AutogluonModels/ag-20250508_211100"
=================== System Info ===================
AutoGluon Version: 1.3.1b20250508
Python Version: 3.11.9
Operating System: Linux
Platform Machine: x86_64
Platform Version: #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count: 8
Pytorch Version: 2.6.0+cu124
CUDA Version: 12.4
Memory Avail: 28.17 GB / 30.95 GB (91.0%)
Disk Space Avail: 168.56 GB / 255.99 GB (65.8%)
===================================================
AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).
2 unique label values: [np.int64(0), np.int64(1)]
If 'binary' is not the correct problem_type, please manually specify the problem_type parameter during Predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression', 'quantile'])
/home/ci/autogluon/multimodal/src/autogluon/multimodal/optim/metrics/utils.py:185: UserWarning: Metric auc is not supported as the evaluation metric for binary in matching tasks.The evaluation metric is changed to roc_auc by default.
warnings.warn(
AutoMM starts to create your model. ✨✨✨
To track the learning progress, you can open a terminal and launch Tensorboard:
```shell
# Assume you have installed tensorboard
tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/semantic_matching/AutogluonModels/ag-20250508_211100
```
Seed set to 0
GPU Count: 1
GPU Count to be Used: 1
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params | Mode
---------------------------------------------------------------------------
0 | query_model | HFAutoModelForTextPrediction | 33.4 M | train
1 | response_model | HFAutoModelForTextPrediction | 33.4 M | train
2 | validation_metric | BinaryAUROC | 0 | train
3 | loss_func | ContrastiveLoss | 0 | train
4 | miner_func | PairMarginMiner | 0 | train
---------------------------------------------------------------------------
33.4 M Trainable params
0 Non-trainable params
33.4 M Total params
133.440 Total estimated model params size (MB)
241 Modules in train mode
0 Modules in eval mode
Time limit reached. Elapsed time is 0:03:00. Signaling Trainer to stop.
Epoch 0, global step 180: 'val_roc_auc' reached 0.89562 (best 0.89562), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/semantic_matching/AutogluonModels/ag-20250508_211100/epoch=0-step=180.ckpt' as top 3
Start to fuse 1 checkpoints via the greedy soup algorithm.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉
To load the model, use the code below:
```python
from autogluon.multimodal import MultiModalPredictor
predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/semantic_matching/AutogluonModels/ag-20250508_211100")
```
If you are not satisfied with the model, try to increase the training time,
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7f90dd78a090>
在测试数据集上进行评估¶
您可以在测试数据集上评估匹配器,以查看其在 roc_auc 分数上的表现
score = predictor.evaluate(snli_test)
print("evaluation score: ", score)
evaluation score: {'roc_auc': np.float64(0.9104766407123103)}
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
对新的句子对进行预测¶
我们创建一个含义相似的新句子对(预计预测为 \(1\)),并使用训练好的模型进行预测。
pred_data = pd.DataFrame.from_dict({"premise":["The teacher gave his speech to an empty room."],
"hypothesis":["There was almost nobody when the professor was talking."]})
predictions = predictor.predict(pred_data)
print('Predicted entities:', predictions[0])
Predicted entities: 1
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
预测匹配概率¶
我们还可以计算句子对的匹配概率。
probabilities = predictor.predict_proba(pred_data)
print(probabilities)
0 1
0 0.205128 0.794872
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
提取嵌入¶
此外,我们支持分别提取两组句子的嵌入。
embeddings_1 = predictor.extract_embedding({"premise":["The teacher gave his speech to an empty room."]})
print(embeddings_1.shape)
embeddings_2 = predictor.extract_embedding({"hypothesis":["There was almost nobody when the professor was talking."]})
print(embeddings_2.shape)
(1, 384)
(1, 384)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
其他示例¶
您可以前往 AutoMM 示例,探索关于 AutoMM 的其他示例。
定制¶
要了解如何定制 AutoMM,请参阅定制 AutoMM。