AutoMM for Chinese Named Entity Recognition - Quick Start


In this tutorial, we demonstrate how to use AutoMM for Chinese named entity recognition, using e-commerce data from TaoBao.com, one of the most popular online marketplaces in China. The dataset was collected and labeled by Jie et al., and the text column mainly contains product descriptions. The picture below shows an example of a Taobao product description.

Taobao product description. A rabbit toy for lunar new year decoration.

Load the Data

We have preprocessed the dataset so that it can be used directly with AutoMM.

import autogluon.multimodal
from autogluon.core.utils.loaders import load_pd
from autogluon.multimodal.utils import visualize_ner
train_data = load_pd.load('https://automl-mm-bench.s3.amazonaws.com/ner/taobao-ner/chinese_ner_train.csv')
dev_data = load_pd.load('https://automl-mm-bench.s3.amazonaws.com/ner/taobao-ner/chinese_ner_dev.csv')
train_data.head(5)
text_snippet entity_annotations
0 雄争霸点卡/七雄争霸元宝/七雄争霸100元1000元宝直充,自动充值 [{"entity_group": "HCCX", "start": 3, "end": 5...
1 简约韩版粗跟艾熙百思图亲子鞋冬季百搭街头母女圆头翻边绒面厚底 [{"entity_group": "HPPX", "start": 6, "end": 8...
2 羚跑商务背包双肩包男士防盗多功能出差韩版休闲15.6寸电脑包皮潮 [{"entity_group": "HPPX", "start": 0, "end": 2...
3 热水袋防爆充电暖宝Ÿœ卡通毛绒萌萌可爱注水暖宫暖手宝暖水袋 [{"entity_group": "HCCX", "start": 0, "end": 3...
4 童装11周岁13儿童夏装男童套装2017新款10中大童15男孩12秋季5潮7 [{"entity_group": "HCCX", "start": 0, "end": 2...

HPPX, HCCX, XH, and MISC stand for brand, product, pattern, and miscellaneous information (e.g., product specifications), respectively. Let's visualize an example about an online game top-up service.

visualize_ner(train_data["text_snippet"].iloc[0], train_data["entity_annotations"].iloc[0])
雄争霸点卡 HCCX /七雄争霸 MISC 元宝 HCCX /七雄争霸 MISC 100元 MISC 1000 MISC 元宝 HCCX 直充,自动充值
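Each entry in the entity_annotations column is a JSON list of spans, where entity_group is the label and start/end are character offsets into the text. If you want the raw spans instead of the visualize_ner rendering, a minimal sketch (using a short hardcoded example in the dataset's format, not an actual row) is:

```python
import json

# A single example in the dataset's annotation format (hypothetical values).
text = "热水袋防爆充电暖宝"
annotations = json.dumps([{"entity_group": "HCCX", "start": 0, "end": 3}])

# Slice each span out of the text using its character offsets.
spans = [(text[a["start"]:a["end"]], a["entity_group"]) for a in json.loads(annotations)]
print(spans)  # [('热水袋', 'HCCX')]
```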

Training

The process of Chinese named entity recognition with AutoMM is identical to that for English. You just need to select a suitable foundation model checkpoint that has been pretrained on Chinese or multilingual documents. Here we use 'hfl/chinese-lert-small' as the backbone for demonstration.

Now, create a predictor for named entity recognition by setting problem_type to ner and specifying the label column. Then we call predictor.fit() to train the model for a few minutes.

from autogluon.multimodal import MultiModalPredictor
import uuid

label_col = "entity_annotations"
model_path = f"./tmp/{uuid.uuid4().hex}-automm_ner"  # You can rename it to the model path you like
predictor = MultiModalPredictor(problem_type="ner", label=label_col, path=model_path)
predictor.fit(
    train_data=train_data,
    hyperparameters={'model.ner_text.checkpoint_name':'hfl/chinese-lert-small'},
    time_limit=300,  # seconds
)
=================== System Info ===================
AutoGluon Version:  1.3.1b20250508
Python Version:     3.11.9
Operating System:   Linux
Platform Machine:   x86_64
Platform Version:   #1 SMP Wed Mar 12 14:53:59 UTC 2025
CPU Count:          8
Pytorch Version:    2.6.0+cu124
CUDA Version:       12.4
Memory Avail:       28.39 GB / 30.95 GB (91.7%)
Disk Space Avail:   183.22 GB / 255.99 GB (71.6%)
===================================================

AutoMM starts to create your model. ✨✨✨

To track the learning progress, you can open a terminal and launch Tensorboard:
    ```shell
    # Assume you have installed tensorboard
    tensorboard --logdir /home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner
    ```
Seed set to 0
/home/ci/autogluon/multimodal/src/autogluon/multimodal/models/utils.py:1104: UserWarning: provided text_segment_num: 1 is smaller than hfl/chinese-lert-small's default: 2
  warnings.warn(
GPU Count: 1
GPU Count to be Used: 1
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type              | Params | Mode 
----------------------------------------------------------------
0 | model             | HFAutoModelForNER | 15.1 M | train
1 | validation_metric | MulticlassF1Score | 0      | train
2 | loss_func         | CrossEntropyLoss  | 0      | train
----------------------------------------------------------------
15.1 M    Trainable params
0         Non-trainable params
15.1 M    Total params
60.345    Total estimated model params size (MB)
232       Modules in train mode
0         Modules in eval mode
Epoch 0, global step 21: 'val_ner_token_f1' reached 0.20986 (best 0.20986), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner/epoch=0-step=21.ckpt' as top 3
Epoch 0, global step 42: 'val_ner_token_f1' reached 0.61414 (best 0.61414), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner/epoch=0-step=42.ckpt' as top 3
Epoch 1, global step 64: 'val_ner_token_f1' reached 0.73071 (best 0.73071), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner/epoch=1-step=64.ckpt' as top 3
Epoch 1, global step 85: 'val_ner_token_f1' reached 0.73508 (best 0.73508), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner/epoch=1-step=85.ckpt' as top 3
Epoch 2, global step 107: 'val_ner_token_f1' reached 0.78553 (best 0.78553), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner/epoch=2-step=107.ckpt' as top 3
Epoch 2, global step 128: 'val_ner_token_f1' reached 0.81110 (best 0.81110), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner/epoch=2-step=128.ckpt' as top 3
Epoch 3, global step 150: 'val_ner_token_f1' reached 0.80350 (best 0.81110), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner/epoch=3-step=150.ckpt' as top 3
Epoch 3, global step 171: 'val_ner_token_f1' reached 0.82953 (best 0.82953), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner/epoch=3-step=171.ckpt' as top 3
Epoch 4, global step 193: 'val_ner_token_f1' reached 0.82999 (best 0.82999), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner/epoch=4-step=193.ckpt' as top 3
Epoch 4, global step 214: 'val_ner_token_f1' reached 0.85073 (best 0.85073), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner/epoch=4-step=214.ckpt' as top 3
Epoch 5, global step 236: 'val_ner_token_f1' reached 0.87399 (best 0.87399), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner/epoch=5-step=236.ckpt' as top 3
Epoch 5, global step 257: 'val_ner_token_f1' reached 0.84428 (best 0.87399), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner/epoch=5-step=257.ckpt' as top 3
Epoch 6, global step 279: 'val_ner_token_f1' reached 0.86570 (best 0.87399), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner/epoch=6-step=279.ckpt' as top 3
Epoch 6, global step 300: 'val_ner_token_f1' reached 0.88459 (best 0.88459), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner/epoch=6-step=300.ckpt' as top 3
Epoch 7, global step 322: 'val_ner_token_f1' reached 0.87146 (best 0.88459), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner/epoch=7-step=322.ckpt' as top 3
Epoch 7, global step 343: 'val_ner_token_f1' reached 0.87584 (best 0.88459), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner/epoch=7-step=343.ckpt' as top 3
Time limit reached. Elapsed time is 0:05:00. Signaling Trainer to stop.
Epoch 8, global step 358: 'val_ner_token_f1' reached 0.87722 (best 0.88459), saving model to '/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner/epoch=8-step=358.ckpt' as top 3
Start to fuse 3 checkpoints via the greedy soup algorithm.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
AutoMM has created your model. 🎉🎉🎉

To load the model, use the code below:
    ```python
    from autogluon.multimodal import MultiModalPredictor
    predictor = MultiModalPredictor.load("/home/ci/autogluon/docs/tutorials/multimodal/text_prediction/tmp/b44a379585654f55b25b6582300ab75f-automm_ner")
    ```

If you are not satisfied with the model, try to increase the training time, 
adjust the hyperparameters (https://autogluon.cn/stable/tutorials/multimodal/advanced_topics/customization.html),
or post issues on GitHub (https://github.com/autogluon/autogluon/issues).
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7f4fd1e08650>

Evaluation

To check the model's performance on the test dataset, you just need to call predictor.evaluate(...).

predictor.evaluate(dev_data)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
{'hccx': {'precision': np.float64(0.7979464613127979),
  'recall': np.float64(0.8529988239905919),
  'f1': np.float64(0.8245547555892383),
  'number': np.int64(2551)},
 'hppx': {'precision': np.float64(0.6426229508196721),
  'recall': np.float64(0.7050359712230215),
  'f1': np.float64(0.6723842195540308),
  'number': np.int64(278)},
 'misc': {'precision': np.float64(0.7),
  'recall': np.float64(0.7361111111111112),
  'f1': np.float64(0.7176015473887815),
  'number': np.int64(504)},
 'xh': {'precision': np.float64(0.7272727272727273),
  'recall': np.float64(0.773109243697479),
  'f1': np.float64(0.7494908350305498),
  'number': np.int64(238)},
 'overall_precision': np.float64(0.7672346002621232),
 'overall_recall': np.float64(0.8196583590030804),
 'overall_f1': np.float64(0.792580557812077),
 'overall_accuracy': 0.884914356469698}
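The report contains entity-level precision, recall, and F1 for each entity group, plus overall scores. F1 is the harmonic mean of precision and recall, so you can sanity-check any row from the printed numbers; for example, for hccx:

```python
# F1 is the harmonic mean of precision and recall: F1 = 2PR / (P + R).
precision = 0.7979464613127979  # hccx precision from the report above
recall = 0.8529988239905919     # hccx recall from the report above
f1 = 2 * precision * recall / (precision + recall)
print(f1)  # matches the reported hccx f1 of ~0.8246
```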

Prediction and Visualization

You can easily obtain predictions for given input sentences by calling predictor.predict(...).

output = predictor.predict(dev_data)
visualize_ner(dev_data["text_snippet"].iloc[0], output[0])
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
家用防尘厨房厨师帽子 HCCX 车间工厂鸭舌 HCCX 工作帽 HCCX 男女食堂餐厅食品卫生帽 HCCX

Now, let's make predictions on the rabbit toy example.

sentence = "2023年兔年挂件新年装饰品小挂饰乔迁之喜门挂小兔子"
predictions = predictor.predict({'text_snippet': [sentence]})
visualize_ner(sentence, predictions[0])
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
2023年兔年挂件 HCCX 新年装饰品 HCCX 小挂饰 HCCX MISC HPPX HCCX 门挂 HCCX 小兔子 HCCX
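The predictions come back in the same span format as the training labels, so they are easy to post-process beyond visualization. As a sketch, assuming the output follows the entity_group/start/end schema shown earlier (the spans below are hypothetical, not actual model output), you could group the extracted entity strings by label:

```python
from collections import defaultdict

# Hypothetical predicted spans for a product title, in the
# entity_group/start/end format used by the dataset and predictions.
sentence = "2023年兔年挂件新年装饰品小挂饰"
predicted = [
    {"entity_group": "HCCX", "start": 7, "end": 9},
    {"entity_group": "HCCX", "start": 14, "end": 17},
]

# Collect the entity substrings under their predicted label.
entities = defaultdict(list)
for span in predicted:
    entities[span["entity_group"]].append(sentence[span["start"]:span["end"]])
print(dict(entities))
```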

Other Examples

You may go to AutoMM Examples to explore other examples about AutoMM.

Customization

To learn how to customize AutoMM, refer to Customize AutoMM.