作者: Abheesht Sharma
创建日期 2023/07/08
上次修改 2024/03/20
描述: 使用 KerasHub 在抽象摘要任务上微调 BART。
在信息过载的时代,提取长文档或对话的核心内容并用几句话表达出来变得至关重要。由于摘要在不同领域具有广泛的应用,近年来它已成为一项关键的、经过充分研究的 NLP 任务。
双向自回归 Transformer (BART) 是一种基于 Transformer 的编码器-解码器模型,常用于摘要和神经机器翻译等序列到序列任务。BART 在大型文本语料库上以自监督方式进行预训练。在预训练期间,文本被破坏,BART 被训练用来重建原始文本(因此被称为“去噪自动编码器”)。一些预训练任务包括令牌掩码、令牌删除、句子置换(打乱句子并训练 BART 来修正顺序)等。
在本示例中,我们将演示如何使用 KerasHub 在抽象摘要任务(关于对话!)上微调 BART,并使用微调后的模型生成摘要。
在我们开始实现流水线之前,让我们安装并导入所有需要的库。我们将使用 KerasHub 库。我们还需要几个实用程序库。
!pip install git+https://github.com/keras-team/keras-hub.git py7zr -q
Installing build dependencies ... [?25l[?25hdone
Getting requirements to build wheel ... [?25l[?25hdone
Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.4/66.4 kB [31m1.4 MB/s eta [36m0:00:00
[2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB [31m34.8 MB/s eta [36m0:00:00
[2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 412.3/412.3 kB [31m30.4 MB/s eta [36m0:00:00
[2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 138.8/138.8 kB [31m15.1 MB/s eta [36m0:00:00
[2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 49.8/49.8 kB [31m5.8 MB/s eta [36m0:00:00
[2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.7/2.7 MB [31m61.4 MB/s eta [36m0:00:00
[2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 93.1/93.1 kB [31m10.1 MB/s eta [36m0:00:00
[?25h Building wheel for keras-hub (pyproject.toml) ... [?25l[?25hdone
此示例使用 Keras 3 在任何 "tensorflow"
、"jax"
或 "torch"
中工作。对 Keras 3 的支持已内置到 KerasHub 中,只需更改 "KERAS_BACKEND"
环境变量即可选择您选择的后台。我们在下面选择 JAX 后台。
import os
os.environ["KERAS_BACKEND"] = "jax"
导入所有必要的库。
import py7zr
import time
import keras_hub
import keras
import tensorflow as tf
import tensorflow_datasets as tfds
Using JAX backend.
我们也来定义一下超参数。
BATCH_SIZE = 8
NUM_BATCHES = 600
EPOCHS = 1 # Can be set to a higher value for better results
MAX_ENCODER_SEQUENCE_LENGTH = 512
MAX_DECODER_SEQUENCE_LENGTH = 128
MAX_GENERATION_LENGTH = 40
让我们加载 SAMSum 数据集。该数据集包含约 15,000 对对话/对话和摘要。
# Download the dataset.
filename = keras.utils.get_file(
"corpus.7z",
origin="https://hugging-face.cn/datasets/samsum/resolve/main/data/corpus.7z",
)
# Extract the `.7z` file.
with py7zr.SevenZipFile(filename, mode="r") as z:
z.extractall(path="/root/tensorflow_datasets/downloads/manual")
# Load data using TFDS.
samsum_ds = tfds.load("samsum", split="train", as_supervised=True)
Downloading data from https://hugging-face.cn/datasets/samsum/resolve/main/data/corpus.7z
2944100/2944100 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step
Downloading and preparing dataset Unknown size (download: Unknown size, generated: 10.71 MiB, total: 10.71 MiB) to /root/tensorflow_datasets/samsum/1.0.0...
Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s]
Generating train examples...: 0%| | 0/14732 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-train.tfrecord*...: 0%| | …
Generating validation examples...: 0%| | 0/818 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-validation.tfrecord*...: 0%| …
Generating test examples...: 0%| | 0/819 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-test.tfrecord*...: 0%| | 0…
Dataset samsum downloaded and prepared to /root/tensorflow_datasets/samsum/1.0.0. Subsequent calls will reuse this data.
该数据集有两个字段:dialogue
和 summary
。让我们看一个示例。
for dialogue, summary in samsum_ds:
print(dialogue.numpy())
print(summary.numpy())
break
b"Carter: Hey Alexis, I just wanted to let you know that I had a really nice time with you tonight. \r\nAlexis: Thanks Carter. Yeah, I really enjoyed myself as well. \r\nCarter: If you are up for it, I would really like to see you again soon.\r\nAlexis: Thanks Carter, I'm flattered. But I have a really busy week coming up.\r\nCarter: Yeah, no worries. I totally understand. But if you ever want to go grab dinner again, just let me know. \r\nAlexis: Yeah of course. Thanks again for tonight. \r\nCarter: Sure. Have a great night. "
b'Alexis and Carter met tonight. Carter would like to meet again, but Alexis is busy.'
我们现在将对数据集进行批处理,并仅保留数据集的一个子集以用于本示例。对话被输入到编码器,相应的摘要作为解码器的输入。因此,我们将数据集的格式更改为包含两个键的字典:"encoder_text"
和 "decoder_text"
。这就是 keras_hub.models.BartSeq2SeqLMPreprocessor
期望的输入格式。
train_ds = (
samsum_ds.map(
lambda dialogue, summary: {"encoder_text": dialogue, "decoder_text": summary}
)
.batch(BATCH_SIZE)
.cache()
)
train_ds = train_ds.take(NUM_BATCHES)
让我们首先加载模型和预处理器。我们使用编码器和解码器的序列长度分别为 512 和 128,而不是 1024(这是默认的序列长度)。这将使我们能够在 Colab 上快速运行此示例。
如果您仔细观察,预处理器已附加到模型。这意味着我们不必担心文本输入的预处理;所有操作都将在内部完成。预处理器将对编码器文本和解码器文本进行标记化,添加特殊标记并对其进行填充。为了生成用于自回归训练的标签,预处理器将解码器文本向右移动一位。这是因为在每个时间步长,模型都经过训练以预测下一个标记。
preprocessor = keras_hub.models.BartSeq2SeqLMPreprocessor.from_preset(
"bart_base_en",
encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,
decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,
)
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
"bart_base_en", preprocessor=preprocessor
)
bart_lm.summary()
Downloading data from https://storage.googleapis.com/keras-hub/models/bart_base_en/v1/vocab.json
898823/898823 ━━━━━━━━━━━━━━━━━━━━ 1s 1us/step
Downloading data from https://storage.googleapis.com/keras-hub/models/bart_base_en/v1/merges.txt
456318/456318 ━━━━━━━━━━━━━━━━━━━━ 1s 1us/step
Downloading data from https://storage.googleapis.com/keras-hub/models/bart_base_en/v1/model.h5
557969120/557969120 ━━━━━━━━━━━━━━━━━━━━ 29s 0us/step
Preprocessor: "bart_seq2_seq_lm_preprocessor"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Tokenizer (type) ┃ Vocab # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ bart_tokenizer (BartTokenizer) │ 50,265 │ └────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘
Model: "bart_seq2_seq_lm"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ decoder_padding_mask │ (None, None) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ decoder_token_ids │ (None, None) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ encoder_padding_mask │ (None, None) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ encoder_token_ids │ (None, None) │ 0 │ - │ │ (InputLayer) │ │ │ │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ bart_backbone (BartBackbone) │ [(None, None, 768), │ 139,417,344 │ decoder_padding_mask[0][0], │ │ │ (None, None, 768)] │ │ decoder_token_ids[0][0], │ │ │ │ │ encoder_padding_mask[0][0], │ │ │ │ │ encoder_token_ids[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ reverse_embedding │ (None, 50265) │ 38,603,520 │ bart_backbone[0][0] │ │ (ReverseEmbedding) │ │ │ │ └───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘
Total params: 139,417,344 (4.15 GB)
Trainable params: 139,417,344 (4.15 GB)
Non-trainable params: 0 (0.00 B)
定义优化器和损失。我们使用带有线性衰减学习率的 Adam 优化器。编译模型。
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
epsilon=1e-6,
global_clipnorm=1.0, # Gradient clipping.
)
# Exclude layernorm and bias terms from weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias"])
optimizer.exclude_from_weight_decay(var_names=["gamma"])
optimizer.exclude_from_weight_decay(var_names=["beta"])
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
bart_lm.compile(
optimizer=optimizer,
loss=loss,
weighted_metrics=["accuracy"],
)
让我们训练模型!
bart_lm.fit(train_ds, epochs=EPOCHS)
600/600 ━━━━━━━━━━━━━━━━━━━━ 398s 586ms/step - loss: 0.4330
<keras_core.src.callbacks.history.History at 0x7ae2faf3e110>
现在模型已经训练好了,让我们开始有趣的部分——生成摘要!让我们从验证集中选择前 100 个样本并为其生成摘要。我们将使用默认的解码策略,即贪婪搜索。
KerasHub 中的生成高度优化。它由 XLA 的强大功能支持。其次,解码器中自注意力层和交叉注意力层的键值张量被缓存,以避免在每个时间步长重新计算。
def generate_text(model, input_text, max_length=200, print_time_taken=False):
start = time.time()
output = model.generate(input_text, max_length=max_length)
end = time.time()
print(f"Total Time Elapsed: {end - start:.2f}s")
return output
# Load the dataset.
val_ds = tfds.load("samsum", split="validation", as_supervised=True)
val_ds = val_ds.take(100)
dialogues = []
ground_truth_summaries = []
for dialogue, summary in val_ds:
dialogues.append(dialogue.numpy())
ground_truth_summaries.append(summary.numpy())
# Let's make a dummy call - the first call to XLA generally takes a bit longer.
_ = generate_text(bart_lm, "sample text", max_length=MAX_GENERATION_LENGTH)
# Generate summaries.
generated_summaries = generate_text(
bart_lm,
val_ds.map(lambda dialogue, _: dialogue).batch(8),
max_length=MAX_GENERATION_LENGTH,
print_time_taken=True,
)
Total Time Elapsed: 21.22s
Total Time Elapsed: 49.00s
让我们看看一些摘要。
for dialogue, generated_summary, ground_truth_summary in zip(
dialogues[:5], generated_summaries[:5], ground_truth_summaries[:5]
):
print("Dialogue:", dialogue)
print("Generated Summary:", generated_summary)
print("Ground Truth Summary:", ground_truth_summary)
print("=============================")
Dialogue: b'Tony: Is the boss in?\r\nClaire: Not yet.\r\nTony: Could let me know when he comes, please? \r\nClaire: Of course.\r\nTony: Thank you.'
Generated Summary: Tony will let Claire know when her boss comes.
Ground Truth Summary: b"The boss isn't in yet. Claire will let Tony know when he comes."
=============================
Dialogue: b"James: What shouldl I get her?\r\nTim: who?\r\nJames: gees Mary my girlfirend\r\nTim: Am I really the person you should be asking?\r\nJames: oh come on it's her birthday on Sat\r\nTim: ask Sandy\r\nTim: I honestly am not the right person to ask this\r\nJames: ugh fine!"
Generated Summary: Mary's girlfriend is birthday. James and Tim are going to ask Sandy to buy her.
Ground Truth Summary: b"Mary's birthday is on Saturday. Her boyfriend, James, is looking for gift ideas. Tim suggests that he ask Sandy."
=============================
Dialogue: b"Mary: So, how's Israel? Have you been on the beach?\r\nKate: It's so expensive! But they say, it's Tel Aviv... Tomorrow we are going to Jerusalem.\r\nMary: I've heard Israel is expensive, Monica was there on vacation last year, she complained about how pricey it is. Are you going to the Dead Sea before it dies? ahahahha\r\nKate: ahahhaha yup, in few days."
Generated Summary: Kate is on vacation in Tel Aviv. Mary will visit the Dead Sea in a few days.
Ground Truth Summary: b'Mary and Kate discuss how expensive Israel is. Kate is in Tel Aviv now, planning to travel to Jerusalem tomorrow, and to the Dead Sea few days later.'
=============================
Dialogue: b"Giny: do we have rice?\r\nRiley: nope, it's finished\r\nGiny: fuck!\r\nGiny: ok, I'll buy"
Generated Summary: Giny wants to buy rice from Riley.
Ground Truth Summary: b"Giny and Riley don't have any rice left. Giny will buy some."
=============================
Dialogue: b"Jude: i'll be in warsaw at the beginning of december so we could meet again\r\nLeon: !!!\r\nLeon: at the beginning means...?\r\nLeon: cuz I won't be here during the first weekend\r\nJude: 10\r\nJude: but i think it's a monday, so never mind i guess :D\r\nLeon: yeah monday doesn't really work for me :D\r\nLeon: :<\r\nJude: oh well next time :d\r\nLeon: yeah...!"
Generated Summary: Jude and Leon will meet again this weekend at 10 am.
Ground Truth Summary: b'Jude is coming to Warsaw on the 10th of December and wants to see Leon. Leon has no time.'
=============================
生成的摘要看起来很棒!对于只在一个 epoch 上训练了 5000 个示例的模型来说还不错 :)