Author: Abheesht Sharma
Date created: 2023/07/08
Last modified: 2024/03/20
Description: Use KerasHub to fine-tune BART on the abstractive summarization task.
In the era of information overload, it has become crucial to extract the crux of a long document or a conversation and express it in a few sentences. Owing to summarization's wide range of applications in different domains, it has become a key, well-studied NLP task in recent years.
Bidirectional Autoregressive Transformer (BART) is a Transformer-based encoder-decoder model, often used for sequence-to-sequence tasks such as summarization and neural machine translation. BART is pre-trained in a self-supervised fashion on a large text corpus. During pre-training, the text is corrupted and BART is trained to reconstruct the original text (hence called a "denoising autoencoder"). Some pre-training tasks include token masking, token deletion, sentence permutation (sentences are shuffled and BART is trained to recover the order), etc.
In this example, we will demonstrate how to fine-tune BART on the abstractive summarization task (on conversations!) using KerasHub, and generate summaries using the fine-tuned model.
Before we start implementing the pipeline, let's install and import all the libraries we need. We'll be using the KerasHub library. We will also need a couple of utility libraries.
!pip install git+https://github.com/keras-team/keras-hub.git py7zr -q
Installing build dependencies ... done
Getting requirements to build wheel ... done
Preparing metadata (pyproject.toml) ... done
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 66.4/66.4 kB 1.4 MB/s eta 0:00:00
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 34.8 MB/s eta 0:00:00
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 412.3/412.3 kB 30.4 MB/s eta 0:00:00
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 138.8/138.8 kB 15.1 MB/s eta 0:00:00
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 49.8/49.8 kB 5.8 MB/s eta 0:00:00
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.7/2.7 MB 61.4 MB/s eta 0:00:00
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 93.1/93.1 kB 10.1 MB/s eta 0:00:00
Building wheel for keras-hub (pyproject.toml) ... done
This example runs with Keras 3 on any of the "tensorflow", "jax", or "torch" backends. KerasHub has built-in support for Keras 3: just change the "KERAS_BACKEND" environment variable to select the backend of your choice. We select the JAX backend below.
import os
os.environ["KERAS_BACKEND"] = "jax"
Import all necessary libraries.
import py7zr
import time
import keras_hub
import keras
import tensorflow as tf
import tensorflow_datasets as tfds
Using JAX backend.
Let's also define our hyperparameters.
BATCH_SIZE = 8
NUM_BATCHES = 600
EPOCHS = 1 # Can be set to a higher value for better results
MAX_ENCODER_SEQUENCE_LENGTH = 512
MAX_DECODER_SEQUENCE_LENGTH = 128
MAX_GENERATION_LENGTH = 40
Let's load the SAMSum dataset. This dataset contains around 15,000 pairs of conversations/dialogues and summaries.
# Download the dataset.
filename = keras.utils.get_file(
    "corpus.7z",
    origin="https://hugging-face.cn/datasets/samsum/resolve/main/data/corpus.7z",
)

# Extract the `.7z` file.
with py7zr.SevenZipFile(filename, mode="r") as z:
    z.extractall(path="/root/tensorflow_datasets/downloads/manual")

# Load data using TFDS.
samsum_ds = tfds.load("samsum", split="train", as_supervised=True)
Downloading data from https://hugging-face.cn/datasets/samsum/resolve/main/data/corpus.7z
2944100/2944100 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step
Downloading and preparing dataset Unknown size (download: Unknown size, generated: 10.71 MiB, total: 10.71 MiB) to /root/tensorflow_datasets/samsum/1.0.0...
Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s]
Generating train examples...: 0%| | 0/14732 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-train.tfrecord*...: 0%| | …
Generating validation examples...: 0%| | 0/818 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-validation.tfrecord*...: 0%| …
Generating test examples...: 0%| | 0/819 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/samsum/1.0.0.incompleteYA9MAV/samsum-test.tfrecord*...: 0%| | 0…
Dataset samsum downloaded and prepared to /root/tensorflow_datasets/samsum/1.0.0. Subsequent calls will reuse this data.
The dataset has two fields: dialogue and summary. Let's take a look at an example.
for dialogue, summary in samsum_ds:
    print(dialogue.numpy())
    print(summary.numpy())
    break
b"Carter: Hey Alexis, I just wanted to let you know that I had a really nice time with you tonight. \r\nAlexis: Thanks Carter. Yeah, I really enjoyed myself as well. \r\nCarter: If you are up for it, I would really like to see you again soon.\r\nAlexis: Thanks Carter, I'm flattered. But I have a really busy week coming up.\r\nCarter: Yeah, no worries. I totally understand. But if you ever want to go grab dinner again, just let me know. \r\nAlexis: Yeah of course. Thanks again for tonight. \r\nCarter: Sure. Have a great night. "
b'Alexis and Carter met tonight. Carter would like to meet again, but Alexis is busy.'
We'll now batch the dataset and retain only a subset of it for the purpose of this example. The dialogue is fed to the encoder, and the corresponding summary serves as the input to the decoder. We will therefore change the format of the dataset to a dictionary with two keys: "encoder_text" and "decoder_text". This is the input format that keras_hub.models.BartSeq2SeqLMPreprocessor expects.
train_ds = (
    samsum_ds.map(
        lambda dialogue, summary: {"encoder_text": dialogue, "decoder_text": summary}
    )
    .batch(BATCH_SIZE)
    .cache()
)
train_ds = train_ds.take(NUM_BATCHES)
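As a quick, optional sanity check (not part of the original recipe), we can inspect the dataset's element structure to confirm that it now yields dictionaries with the two expected string keys:

# Optional sanity check: the dataset should now yield dictionaries with
# "encoder_text" and "decoder_text" string tensors of shape (batch_size,).
print(train_ds.element_spec)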
Let's load the model and preprocessor first. We use sequence lengths of 512 and 128 for the encoder and decoder, respectively, instead of 1024 (which is the default sequence length). This will allow us to run this example quickly on Colab.
If you observe carefully, the preprocessor is attached to the model. What this means is that we don't have to worry about preprocessing the text inputs; everything will be done internally. The preprocessor tokenizes the encoder text and the decoder text, adds special tokens and pads them. To generate labels for auto-regressive training, the preprocessor shifts the decoder text by one position. This is done because at every timestep, the model is trained to predict the next token.
preprocessor = keras_hub.models.BartSeq2SeqLMPreprocessor.from_preset(
    "bart_base_en",
    encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,
    decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,
)

bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
    "bart_base_en", preprocessor=preprocessor
)
bart_lm.summary()
Downloading data from https://storage.googleapis.com/keras-hub/models/bart_base_en/v1/vocab.json
898823/898823 ━━━━━━━━━━━━━━━━━━━━ 1s 1us/step
Downloading data from https://storage.googleapis.com/keras-hub/models/bart_base_en/v1/merges.txt
456318/456318 ━━━━━━━━━━━━━━━━━━━━ 1s 1us/step
Downloading data from https://storage.googleapis.com/keras-hub/models/bart_base_en/v1/model.h5
557969120/557969120 ━━━━━━━━━━━━━━━━━━━━ 29s 0us/step
Preprocessor: "bart_seq2_seq_lm_preprocessor"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Tokenizer (type)                       ┃                Vocab # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│ bart_tokenizer (BartTokenizer)         │                 50,265 │
└────────────────────────────────────────┴────────────────────────┘
Model: "bart_seq2_seq_lm"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                  ┃ Output Shape              ┃     Param # ┃ Connected to                   ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ decoder_padding_mask          │ (None, None)              │           0 │ -                              │
│ (InputLayer)                  │                           │             │                                │
├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
│ decoder_token_ids             │ (None, None)              │           0 │ -                              │
│ (InputLayer)                  │                           │             │                                │
├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
│ encoder_padding_mask          │ (None, None)              │           0 │ -                              │
│ (InputLayer)                  │                           │             │                                │
├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
│ encoder_token_ids             │ (None, None)              │           0 │ -                              │
│ (InputLayer)                  │                           │             │                                │
├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
│ bart_backbone (BartBackbone)  │ [(None, None, 768),       │ 139,417,344 │ decoder_padding_mask[0][0],    │
│                               │ (None, None, 768)]        │             │ decoder_token_ids[0][0],       │
│                               │                           │             │ encoder_padding_mask[0][0],    │
│                               │                           │             │ encoder_token_ids[0][0]        │
├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤
│ reverse_embedding             │ (None, 50265)             │  38,603,520 │ bart_backbone[0][0]            │
│ (ReverseEmbedding)            │                           │             │                                │
└───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘
Total params: 139,417,344 (4.15 GB)
Trainable params: 139,417,344 (4.15 GB)
Non-trainable params: 0 (0.00 B)
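To get a feel for what the preprocessor produces, we can optionally call it on a single example. The sketch below assumes the standard BartSeq2SeqLMPreprocessor behavior of returning an (x, y, sample_weight) tuple, where x holds the tokenized encoder/decoder inputs and y holds the next-token labels (the decoder token IDs shifted by one position); the sample text itself is just a made-up illustration.

# Optional: peek at the preprocessor output for one (hypothetical) example.
sample = {
    "encoder_text": "Tony: Is the boss in?\nClaire: Not yet.",
    "decoder_text": "The boss isn't in yet.",
}
x, y, sample_weight = preprocessor(sample)
print(x.keys())  # Encoder/decoder token IDs and padding masks.
print(y[:10])  # Labels: the decoder token IDs shifted by one position.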
Define the optimizer and loss. We use the AdamW optimizer with weight decay and, for simplicity, a constant learning rate (a linearly decaying learning rate is a common alternative for fine-tuning; see the optional sketch after the compile step below). We exclude the bias and LayerNorm terms from weight decay, and then compile the model.
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
    epsilon=1e-6,
    global_clipnorm=1.0,  # Gradient clipping.
)
# Exclude layernorm and bias terms from weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias"])
optimizer.exclude_from_weight_decay(var_names=["gamma"])
optimizer.exclude_from_weight_decay(var_names=["beta"])
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
bart_lm.compile(
    optimizer=optimizer,
    loss=loss,
    weighted_metrics=["accuracy"],
)
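If you prefer the linearly decaying learning rate mentioned above, here is a minimal sketch using Keras's built-in PolynomialDecay schedule (with power=1.0, i.e. a linear schedule). This is an optional variation, not part of the original recipe:

# Optional: a linear learning rate decay from 5e-5 down to 0 over training.
# With power=1.0, PolynomialDecay interpolates linearly over `decay_steps`.
lr_schedule = keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=5e-5,
    decay_steps=NUM_BATCHES * EPOCHS,
    end_learning_rate=0.0,
    power=1.0,
)
# To use it, pass `learning_rate=lr_schedule` to `keras.optimizers.AdamW` above.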
Let's train the model!
bart_lm.fit(train_ds, epochs=EPOCHS)
600/600 ━━━━━━━━━━━━━━━━━━━━ 398s 586ms/step - loss: 0.4330
<keras_core.src.callbacks.history.History at 0x7ae2faf3e110>
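Optionally, you may want to persist the fine-tuned weights so they can be reloaded later without re-training. A minimal sketch using standard Keras weight checkpointing (the file name is just an example):

# Optional: save the fine-tuned weights (the file name is arbitrary).
bart_lm.save_weights("bart_samsum.weights.h5")
# Later, rebuild the model from the preset and restore the weights with:
# bart_lm.load_weights("bart_samsum.weights.h5")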
Now that the model has been trained, let's get to the fun part: actually generating summaries! Let's pick the first 100 samples from the validation set and generate summaries for them. We will use the default decoding strategy, i.e., greedy search.
Generation in KerasHub is highly optimized. It is backed by the power of XLA. Secondly, key/value tensors in the self-attention and cross-attention layers of the decoder are cached to avoid recomputation at every timestep.
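The decoding strategy is also configurable. This example sticks with greedy search, but as an optional aside (assuming your KerasHub version exposes the samplers API this way), you can switch strategies by re-compiling the task with a different sampler, for example top-k sampling:

# Optional: switch the decoding strategy used by `generate()`. Top-k sampling
# draws each token from the k most likely candidates instead of the argmax.
bart_lm.compile(sampler=keras_hub.samplers.TopKSampler(k=5))
# Switch back to greedy search, which the rest of this example uses.
bart_lm.compile(sampler="greedy")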
def generate_text(model, input_text, max_length=200, print_time_taken=False):
    start = time.time()
    output = model.generate(input_text, max_length=max_length)
    end = time.time()
    print(f"Total Time Elapsed: {end - start:.2f}s")
    return output
# Load the dataset.
val_ds = tfds.load("samsum", split="validation", as_supervised=True)
val_ds = val_ds.take(100)
dialogues = []
ground_truth_summaries = []
for dialogue, summary in val_ds:
    dialogues.append(dialogue.numpy())
    ground_truth_summaries.append(summary.numpy())
# Let's make a dummy call - the first call to XLA generally takes a bit longer.
_ = generate_text(bart_lm, "sample text", max_length=MAX_GENERATION_LENGTH)
# Generate summaries.
generated_summaries = generate_text(
    bart_lm,
    val_ds.map(lambda dialogue, _: dialogue).batch(8),
    max_length=MAX_GENERATION_LENGTH,
    print_time_taken=True,
)
Total Time Elapsed: 21.22s
Total Time Elapsed: 49.00s
Let's take a look at some of the summaries.
for dialogue, generated_summary, ground_truth_summary in zip(
    dialogues[:5], generated_summaries[:5], ground_truth_summaries[:5]
):
    print("Dialogue:", dialogue)
    print("Generated Summary:", generated_summary)
    print("Ground Truth Summary:", ground_truth_summary)
    print("=============================")
Dialogue: b'Tony: Is the boss in?\r\nClaire: Not yet.\r\nTony: Could let me know when he comes, please? \r\nClaire: Of course.\r\nTony: Thank you.'
Generated Summary: Tony will let Claire know when her boss comes.
Ground Truth Summary: b"The boss isn't in yet. Claire will let Tony know when he comes."
=============================
Dialogue: b"James: What shouldl I get her?\r\nTim: who?\r\nJames: gees Mary my girlfirend\r\nTim: Am I really the person you should be asking?\r\nJames: oh come on it's her birthday on Sat\r\nTim: ask Sandy\r\nTim: I honestly am not the right person to ask this\r\nJames: ugh fine!"
Generated Summary: Mary's girlfriend is birthday. James and Tim are going to ask Sandy to buy her.
Ground Truth Summary: b"Mary's birthday is on Saturday. Her boyfriend, James, is looking for gift ideas. Tim suggests that he ask Sandy."
=============================
Dialogue: b"Mary: So, how's Israel? Have you been on the beach?\r\nKate: It's so expensive! But they say, it's Tel Aviv... Tomorrow we are going to Jerusalem.\r\nMary: I've heard Israel is expensive, Monica was there on vacation last year, she complained about how pricey it is. Are you going to the Dead Sea before it dies? ahahahha\r\nKate: ahahhaha yup, in few days."
Generated Summary: Kate is on vacation in Tel Aviv. Mary will visit the Dead Sea in a few days.
Ground Truth Summary: b'Mary and Kate discuss how expensive Israel is. Kate is in Tel Aviv now, planning to travel to Jerusalem tomorrow, and to the Dead Sea few days later.'
=============================
Dialogue: b"Giny: do we have rice?\r\nRiley: nope, it's finished\r\nGiny: fuck!\r\nGiny: ok, I'll buy"
Generated Summary: Giny wants to buy rice from Riley.
Ground Truth Summary: b"Giny and Riley don't have any rice left. Giny will buy some."
=============================
Dialogue: b"Jude: i'll be in warsaw at the beginning of december so we could meet again\r\nLeon: !!!\r\nLeon: at the beginning means...?\r\nLeon: cuz I won't be here during the first weekend\r\nJude: 10\r\nJude: but i think it's a monday, so never mind i guess :D\r\nLeon: yeah monday doesn't really work for me :D\r\nLeon: :<\r\nJude: oh well next time :d\r\nLeon: yeah...!"
Generated Summary: Jude and Leon will meet again this weekend at 10 am.
Ground Truth Summary: b'Jude is coming to Warsaw on the 10th of December and wants to see Leon. Leon has no time.'
=============================
The generated summaries look awesome! Not bad for a model trained for only 1 epoch and on 5000 examples :)
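Qualitative inspection is a good start; for a quantitative picture, you could additionally score the generated summaries against the ground truth with ROUGE. The sketch below assumes keras_hub.metrics.RougeL is available in your KerasHub version (it mirrors the KerasNLP metric of the same name) and that the rouge-score package is installed:

# Optional: a rough ROUGE-L evaluation of the 100 generated summaries.
# Assumes `keras_hub.metrics.RougeL` exists in your KerasHub version and
# that `rouge-score` is installed (`pip install rouge-score`).
references = [summary.decode("utf-8") for summary in ground_truth_summaries]
rouge_l = keras_hub.metrics.RougeL()
rouge_l(references, generated_summaries)
print(rouge_l.result())  # Dict with precision, recall and F1 score.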