代码示例 / 自然语言处理 / 使用 BART 进行抽象文本摘要

使用 BART 进行抽象文本摘要

作者: Abheesht Sharma
创建日期 2023/07/08
上次修改 2024/03/20
描述: 使用 KerasHub 在抽象摘要任务上微调 BART。

ⓘ 此示例使用 Keras 2

在 Colab 中查看 GitHub 源代码


介绍

在信息过载的时代,提取长文档或对话的核心内容并用几句话表达出来变得至关重要。由于摘要在不同领域具有广泛的应用,近年来它已成为一项关键的、经过充分研究的 NLP 任务。

双向自回归 Transformer (BART) 是一种基于 Transformer 的编码器-解码器模型,常用于摘要和神经机器翻译等序列到序列任务。BART 在大型文本语料库上以自监督方式进行预训练。在预训练期间,文本被破坏,BART 被训练用来重建原始文本(因此被称为“去噪自动编码器”)。一些预训练任务包括令牌掩码、令牌删除、句子置换(打乱句子并训练 BART 来修正顺序)等。

在本示例中,我们将演示如何使用 KerasHub 在抽象摘要任务(关于对话!)上微调 BART,并使用微调后的模型生成摘要。


设置

在我们开始实现流水线之前,让我们安装并导入所有需要的库。我们将使用 KerasHub 库。我们还需要几个实用程序库。

!pip install git+https://github.com/keras-team/keras-hub.git py7zr -q
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.4/66.4 kB 1.4 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 34.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 412.3/412.3 kB 30.4 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 138.8/138.8 kB 15.1 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 49.8/49.8 kB 5.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.7/2.7 MB 61.4 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 93.1/93.1 kB 10.1 MB/s eta 0:00:00
[?25h  Building wheel for keras-hub (pyproject.toml) ... [?25l[?25hdone

此示例使用 Keras 3 在任何 "tensorflow""jax""torch" 中工作。对 Keras 3 的支持已内置到 KerasHub 中,只需更改 "KERAS_BACKEND" 环境变量即可选择您选择的后台。我们在下面选择 JAX 后台。

import os

os.environ["KERAS_BACKEND"] = "jax"

导入所有必要的库。

import py7zr
import time

import keras_hub
import keras
import tensorflow as tf
import tensorflow_datasets as tfds
Using JAX backend.

我们也来定义一下超参数。

BATCH_SIZE = 8
NUM_BATCHES = 600
EPOCHS = 1  # Can be set to a higher value for better results
MAX_ENCODER_SEQUENCE_LENGTH = 512
MAX_DECODER_SEQUENCE_LENGTH = 128
MAX_GENERATION_LENGTH = 40

数据集

让我们加载 SAMSum 数据集。该数据集包含约 15,000 对对话/对话和摘要。

# Download the dataset.
filename = keras.utils.get_file(
    "corpus.7z",
    origin="https://hugging-face.cn/datasets/samsum/resolve/main/data/corpus.7z",
)

# Extract the `.7z` file.
with py7zr.SevenZipFile(filename, mode="r") as z:
    z.extractall(path="/root/tensorflow_datasets/downloads/manual")

# Load data using TFDS.
samsum_ds = tfds.load("samsum", split="train", as_supervised=True)
Downloading data from https://hugging-face.cn/datasets/samsum/resolve/main/data/corpus.7z
 2944100/2944100 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step
Downloading and preparing dataset Unknown size (download: Unknown size, generated: 10.71 MiB, total: 10.71 MiB) to /root/tensorflow_datasets/samsum/1.0.0...

Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/14732 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-train.tfrecord*...:   0%|          | …

Generating validation examples...:   0%|          | 0/818 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-validation.tfrecord*...:   0%|       …

Generating test examples...:   0%|          | 0/819 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-test.tfrecord*...:   0%|          | 0…

Dataset samsum downloaded and prepared to /root/tensorflow_datasets/samsum/1.0.0. Subsequent calls will reuse this data.

该数据集有两个字段:dialoguesummary。让我们看一个示例。

for dialogue, summary in samsum_ds:
    print(dialogue.numpy())
    print(summary.numpy())
    break
b"Carter: Hey Alexis, I just wanted to let you know that I had a really nice time with you tonight. \r\nAlexis: Thanks Carter. Yeah, I really enjoyed myself as well. \r\nCarter: If you are up for it, I would really like to see you again soon.\r\nAlexis: Thanks Carter, I'm flattered. But I have a really busy week coming up.\r\nCarter: Yeah, no worries. I totally understand. But if you ever want to go grab dinner again, just let me know. \r\nAlexis: Yeah of course. Thanks again for tonight. \r\nCarter: Sure. Have a great night. "
b'Alexis and Carter met tonight. Carter would like to meet again, but Alexis is busy.'

我们现在将对数据集进行批处理,并仅保留数据集的一个子集以用于本示例。对话被输入到编码器,相应的摘要作为解码器的输入。因此,我们将数据集的格式更改为包含两个键的字典:"encoder_text""decoder_text"。这就是 keras_hub.models.BartSeq2SeqLMPreprocessor 期望的输入格式。

train_ds = (
    samsum_ds.map(
        lambda dialogue, summary: {"encoder_text": dialogue, "decoder_text": summary}
    )
    .batch(BATCH_SIZE)
    .cache()
)
train_ds = train_ds.take(NUM_BATCHES)

微调 BART

让我们首先加载模型和预处理器。我们使用编码器和解码器的序列长度分别为 512 和 128,而不是 1024(这是默认的序列长度)。这将使我们能够在 Colab 上快速运行此示例。

如果您仔细观察,预处理器已附加到模型。这意味着我们不必担心文本输入的预处理;所有操作都将在内部完成。预处理器将对编码器文本和解码器文本进行标记化,添加特殊标记并对其进行填充。为了生成用于自回归训练的标签,预处理器将解码器文本向右移动一位。这是因为在每个时间步长,模型都经过训练以预测下一个标记。

preprocessor = keras_hub.models.BartSeq2SeqLMPreprocessor.from_preset(
    "bart_base_en",
    encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,
    decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,
)
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
    "bart_base_en", preprocessor=preprocessor
)

bart_lm.summary()
Downloading data from https://storage.googleapis.com/keras-hub/models/bart_base_en/v1/vocab.json
 898823/898823 ━━━━━━━━━━━━━━━━━━━━ 1s 1us/step
Downloading data from https://storage.googleapis.com/keras-hub/models/bart_base_en/v1/merges.txt
 456318/456318 ━━━━━━━━━━━━━━━━━━━━ 1s 1us/step
Downloading data from https://storage.googleapis.com/keras-hub/models/bart_base_en/v1/model.h5
 557969120/557969120 ━━━━━━━━━━━━━━━━━━━━ 29s 0us/step
Preprocessor: "bart_seq2_seq_lm_preprocessor"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Tokenizer (type)                                                                                Vocab # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ bart_tokenizer (BartTokenizer)                     │                                              50,265 │
└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘
Model: "bart_seq2_seq_lm"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                   Output Shape                   Param #  Connected to                   ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ decoder_padding_mask          │ (None, None)              │           0 │ -                              │
│ (InputLayer)                  │                           │             │                                │
├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
│ decoder_token_ids             │ (None, None)              │           0 │ -                              │
│ (InputLayer)                  │                           │             │                                │
├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
│ encoder_padding_mask          │ (None, None)              │           0 │ -                              │
│ (InputLayer)                  │                           │             │                                │
├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
│ encoder_token_ids             │ (None, None)              │           0 │ -                              │
│ (InputLayer)                  │                           │             │                                │
├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
│ bart_backbone (BartBackbone)  │ [(None, None, 768),       │ 139,417,344 │ decoder_padding_mask[0][0],    │
│                               │ (None, None, 768)]        │             │ decoder_token_ids[0][0],       │
│                               │                           │             │ encoder_padding_mask[0][0],    │
│                               │                           │             │ encoder_token_ids[0][0]        │
├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
│ reverse_embedding             │ (None, 50265)             │  38,603,520 │ bart_backbone[0][0]            │
│ (ReverseEmbedding)            │                           │             │                                │
└───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘
 Total params: 139,417,344 (4.15 GB)
 Trainable params: 139,417,344 (4.15 GB)
 Non-trainable params: 0 (0.00 B)

定义优化器和损失。我们使用带有线性衰减学习率的 Adam 优化器。编译模型。

optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
    epsilon=1e-6,
    global_clipnorm=1.0,  # Gradient clipping.
)
# Exclude layernorm and bias terms from weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias"])
optimizer.exclude_from_weight_decay(var_names=["gamma"])
optimizer.exclude_from_weight_decay(var_names=["beta"])

loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

bart_lm.compile(
    optimizer=optimizer,
    loss=loss,
    weighted_metrics=["accuracy"],
)

让我们训练模型!

bart_lm.fit(train_ds, epochs=EPOCHS)
 600/600 ━━━━━━━━━━━━━━━━━━━━ 398s 586ms/step - loss: 0.4330

<keras_core.src.callbacks.history.History at 0x7ae2faf3e110>

生成摘要并对其进行评估!

现在模型已经训练好了,让我们开始有趣的部分——生成摘要!让我们从验证集中选择前 100 个样本并为其生成摘要。我们将使用默认的解码策略,即贪婪搜索。

KerasHub 中的生成高度优化。它由 XLA 的强大功能支持。其次,解码器中自注意力层和交叉注意力层的键值张量被缓存,以避免在每个时间步长重新计算。

def generate_text(model, input_text, max_length=200, print_time_taken=False):
    start = time.time()
    output = model.generate(input_text, max_length=max_length)
    end = time.time()
    print(f"Total Time Elapsed: {end - start:.2f}s")
    return output


# Load the dataset.
val_ds = tfds.load("samsum", split="validation", as_supervised=True)
val_ds = val_ds.take(100)

dialogues = []
ground_truth_summaries = []
for dialogue, summary in val_ds:
    dialogues.append(dialogue.numpy())
    ground_truth_summaries.append(summary.numpy())

# Let's make a dummy call - the first call to XLA generally takes a bit longer.
_ = generate_text(bart_lm, "sample text", max_length=MAX_GENERATION_LENGTH)

# Generate summaries.
generated_summaries = generate_text(
    bart_lm,
    val_ds.map(lambda dialogue, _: dialogue).batch(8),
    max_length=MAX_GENERATION_LENGTH,
    print_time_taken=True,
)
Total Time Elapsed: 21.22s
Total Time Elapsed: 49.00s

让我们看看一些摘要。

for dialogue, generated_summary, ground_truth_summary in zip(
    dialogues[:5], generated_summaries[:5], ground_truth_summaries[:5]
):
    print("Dialogue:", dialogue)
    print("Generated Summary:", generated_summary)
    print("Ground Truth Summary:", ground_truth_summary)
    print("=============================")
Dialogue: b'Tony: Is the boss in?\r\nClaire: Not yet.\r\nTony: Could let me know when he comes, please? \r\nClaire: Of course.\r\nTony: Thank you.'
Generated Summary: Tony will let Claire know when her boss comes.
Ground Truth Summary: b"The boss isn't in yet. Claire will let Tony know when he comes."
=============================
Dialogue: b"James: What shouldl I get her?\r\nTim: who?\r\nJames: gees Mary my girlfirend\r\nTim: Am I really the person you should be asking?\r\nJames: oh come on it's her birthday on Sat\r\nTim: ask Sandy\r\nTim: I honestly am not the right person to ask this\r\nJames: ugh fine!"
Generated Summary: Mary's girlfriend is birthday. James and Tim are going to ask Sandy to buy her.
Ground Truth Summary: b"Mary's birthday is on Saturday. Her boyfriend, James, is looking for gift ideas. Tim suggests that he ask Sandy."
=============================
Dialogue: b"Mary: So, how's Israel? Have you been on the beach?\r\nKate: It's so expensive! But they say, it's Tel Aviv... Tomorrow we are going to Jerusalem.\r\nMary: I've heard Israel is expensive, Monica was there on vacation last year, she complained about how pricey it is. Are you going to the Dead Sea before it dies? ahahahha\r\nKate: ahahhaha yup, in few days."
Generated Summary: Kate is on vacation in Tel Aviv. Mary will visit the Dead Sea in a few days.
Ground Truth Summary: b'Mary and Kate discuss how expensive Israel is. Kate is in Tel Aviv now, planning to travel to Jerusalem tomorrow, and to the Dead Sea few days later.'
=============================
Dialogue: b"Giny: do we have rice?\r\nRiley: nope, it's finished\r\nGiny: fuck!\r\nGiny: ok, I'll buy"
Generated Summary: Giny wants to buy rice from Riley.
Ground Truth Summary: b"Giny and Riley don't have any rice left. Giny will buy some."
=============================
Dialogue: b"Jude: i'll be in warsaw at the beginning of december so we could meet again\r\nLeon: !!!\r\nLeon: at the beginning means...?\r\nLeon: cuz I won't be here during the first weekend\r\nJude: 10\r\nJude: but i think it's a monday, so never mind i guess :D\r\nLeon: yeah monday doesn't really work for me :D\r\nLeon: :<\r\nJude: oh well next time :d\r\nLeon: yeah...!"
Generated Summary: Jude and Leon will meet again this weekend at 10 am.
Ground Truth Summary: b'Jude is coming to Warsaw on the 10th of December and wants to see Leon. Leon has no time.'
=============================

生成的摘要看起来很棒!对于只在一个 epoch 上训练了 5000 个示例的模型来说还不错 :)