代码示例 / 生成式深度学习 / 使用 KerasHub 进行 GPT2 文本生成

使用 KerasHub 进行 GPT2 文本生成

作者: Chen Qian
创建日期 2023/04/17
上次修改 2024/04/12
描述: 使用 KerasHub GPT2 模型和 samplers 进行文本生成。

ⓘ 此示例使用 Keras 3

在 Colab 中查看 GitHub 源码

在本教程中,你将学习如何使用 KerasHub 加载一个预训练的大型语言模型(LLM)—— GPT-2 模型(最初由 OpenAI 发明),将其微调到特定的文本风格,并根据用户输入(也称为提示)生成文本。你还将学习 GPT2 如何快速适应非英语语言,例如中文。


开始之前

Colab 提供不同类型的运行时。请务必前往运行时 -> 更改运行时类型,并选择 GPU 硬件加速器运行时(应具有 >12G 主机内存和约 15G GPU 内存),因为你需要对 GPT-2 模型进行微调。在 CPU 运行时上运行本教程将耗费数小时。


安装 KerasHub,选择后端并导入依赖项

本示例使用 Keras 3,可在 "tensorflow""jax""torch" 中的任何后端运行。Keras 3 的支持已内置于 KerasHub 中,只需更改 "KERAS_BACKEND" 环境变量即可选择你喜欢的后端。我们在下面选择了 JAX 后端。

!pip install git+https://github.com/keras-team/keras-hub.git -q
import os

os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"

import keras_hub
import keras
import tensorflow as tf
import time

keras.mixed_precision.set_global_policy("mixed_float16")


生成式大型语言模型 (LLMs) 简介

大型语言模型 (LLMs) 是一种机器学习模型,它们在大规模文本数据集上进行训练,以为各种自然语言处理 (NLP) 任务生成输出,例如文本生成、问答和机器翻译。

生成式 LLMs 通常基于深度学习神经网络,例如 Google 研究人员在 2017 年发明的 Transformer 架构,并在海量文本数据(通常包含数十亿词汇)上进行训练。这些模型,例如 Google 的 LaMDAPaLM,使用来自各种数据源的大型数据集进行训练,这使得它们能够为许多任务生成输出。生成式 LLMs 的核心是预测句子中的下一个词,这通常被称为因果语言模型预训练。通过这种方式,LLMs 可以根据用户提示生成连贯的文本。要更具教学性地讨论语言模型,你可以参考斯坦福大学 CS324 LLM 课程


KerasHub 简介

大型语言模型构建复杂,从头开始训练成本高昂。幸运的是,有许多预训练的 LLMs 可以直接使用。KerasHub 提供了大量预训练的检查点,让你无需自己训练即可使用最先进的模型进行实验。

KerasHub 是一个自然语言处理库,支持用户完成整个开发周期。KerasHub 提供了预训练模型和模块化构建块,因此开发者可以轻松重用预训练模型或构建自己的 LLM。

简而言之,对于生成式 LLM,KerasHub 提供


加载预训练的 GPT-2 模型并生成一些文本

KerasHub 提供了许多预训练模型,例如 Google BertGPT-2。你可以在 KerasHub 仓库中查看可用模型的列表。

加载 GPT-2 模型非常简单,如下所示

# To speed up training and generation, we use preprocessor of length 128
# instead of full length 1024.
preprocessor = keras_hub.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en",
    sequence_length=128,
)
gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset(
    "gpt2_base_en", preprocessor=preprocessor
)

模型加载后,你可以立即使用它生成一些文本。运行下面的单元格试一试。这就像调用一个简单的函数 generate() 一样简单。

start = time.time()

output = gpt2_lm.generate("My trip to Yosemite was", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
GPT-2 output:
My trip to Yosemite was pretty awesome. The first time I went I didn't know how to go and it was pretty hard to get around. It was a bit like going on an adventure with a friend. The only things I could do were hike and climb the mountain. It's really cool to know you're not alone in this world. It's a lot of fun. I'm a little worried that I might not get to the top of the mountain in time to see the sunrise and sunset of the day. I think the weather is going to get a little warmer in the coming years.
This post is a little more in-depth on how to go on the trail. It covers how to hike on the Sierra Nevada, how to hike with the Sierra Nevada, how to hike in the Sierra Nevada, how to get to the top of the mountain, and how to get to the top with your own gear.
The Sierra Nevada is a very popular trail in Yosemite
TOTAL TIME ELAPSED: 25.36s

再试一次

start = time.time()

output = gpt2_lm.generate("That Italian restaurant is", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
GPT-2 output:
That Italian restaurant is known for its delicious food, and the best part is that it has a full bar, with seating for a whole host of guests. And that's only because it's located at the heart of the neighborhood.
The menu at the Italian restaurant is pretty straightforward:
The menu consists of three main dishes:
Italian sausage
Bolognese
Sausage
Bolognese with cheese
Sauce with cream
Italian sausage with cheese
Bolognese with cheese
And the main menu consists of a few other things.
There are two tables: the one that serves a menu of sausage and bolognese with cheese (the one that serves the menu of sausage and bolognese with cheese) and the one that serves the menu of sausage and bolognese with cheese. The two tables are also open 24 hours a day, 7 days a week.
TOTAL TIME ELAPSED: 1.55s

注意第二次调用快了很多。这是因为计算图在第一次运行时被 XLA 编译,并在第二次运行时在后台被重用。

生成的文本质量看起来还可以,但我们可以通过微调来改进它。


关于 KerasHub 中的 GPT-2 模型

接下来,我们将实际微调模型以更新其参数,但在开始之前,让我们先看看我们拥有的用于处理 GPT2 的全套工具。

GPT2 的代码可以在这里找到。从概念上讲,GPT2CausalLM 在 KerasHub 中可以分层分解为几个模块,所有这些模块都有一个用于加载预训练模型的 from_preset() 函数


在 Reddit 数据集上进行微调

现在你已经了解了 KerasHub 中的 GPT-2 模型,你可以更进一步对其进行微调,使其生成特定风格的文本,无论是短文本还是长文本,严格还是随意。在本教程中,我们将以 reddit 数据集为例。

import tensorflow_datasets as tfds

reddit_ds = tfds.load("reddit_tifu", split="train", as_supervised=True)

让我们看看 reddit TensorFlow 数据集中的样本数据。它有两个特征

  • document:帖子的文本。
  • title:标题。
for document, title in reddit_ds:
    print(document.numpy())
    print(title.numpy())
    break
b"me and a friend decided to go to the beach last sunday. we loaded up and headed out. we were about half way there when i decided that i was not leaving till i had seafood. \n\nnow i'm not talking about red lobster. no friends i'm talking about a low country boil. i found the restaurant and got directions. i don't know if any of you have heard about the crab shack on tybee island but let me tell you it's worth it. \n\nwe arrived and was seated quickly. we decided to get a seafood sampler for two and split it. the waitress bought it out on separate platters for us. the amount of food was staggering. two types of crab, shrimp, mussels, crawfish, andouille sausage, red potatoes, and corn on the cob. i managed to finish it and some of my friends crawfish and mussels. it was a day to be a fat ass. we finished paid for our food and headed to the beach. \n\nfunny thing about seafood. it runs through me faster than a kenyan \n\nwe arrived and walked around a bit. it was about 45min since we arrived at the beach when i felt a rumble from the depths of my stomach. i ignored it i didn't want my stomach to ruin our fun. i pushed down the feeling and continued. about 15min later the feeling was back and stronger than before. again i ignored it and continued. 5min later it felt like a nuclear reactor had just exploded in my stomach. i started running. i yelled to my friend to hurry the fuck up. \n\nrunning in sand is extremely hard if you did not know this. we got in his car and i yelled at him to floor it. my stomach was screaming and if he didn't hurry i was gonna have this baby in his car and it wasn't gonna be pretty. after a few red lights and me screaming like a woman in labor we made it to the store. \n\ni practically tore his car door open and ran inside. i ran to the bathroom opened the door and barely got my pants down before the dam burst and a flood of shit poured from my ass. \n\ni finished up when i felt something wet on my ass. i rubbed it thinking it was back splash. no, mass was covered in the after math of me abusing the toilet. i grabbed all the paper towels i could and gave my self a whores bath right there. \n\ni sprayed the bathroom down with the air freshener and left. an elderly lady walked in quickly and closed the door. i was just about to walk away when i heard gag. instead of walking i ran. i got to the car and told him to get the hell out of there."
b'liking seafood'

在本例中,我们在语言模型中执行下一个词预测,因此我们只需要 'document' 特征。

train_ds = (
    reddit_ds.map(lambda document, _: document)
    .batch(32)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

现在你可以使用熟悉的 fit() 函数对模型进行微调。请注意,由于 GPT2CausalLMkeras_hub.models.Task 实例,因此 preprocessor 将在 fit 方法内部自动调用。

如果我们要将其训练到完全收敛,这一步会占用相当多的 GPU 内存,并且需要很长时间。这里我们只使用部分数据集进行演示。

train_ds = train_ds.take(500)
num_epochs = 1

# Linearly decaying learning rate.
learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)
 500/500 ━━━━━━━━━━━━━━━━━━━━ 75s 120ms/step - accuracy: 0.3189 - loss: 3.3653

<keras.src.callbacks.history.History at 0x7f2af3fda410>

微调完成后,你可以再次使用相同的 generate() 函数生成文本。这次,文本将更接近 Reddit 的写作风格,并且生成的长度将接近我们在训练集中预设的长度。

start = time.time()

output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
GPT-2 output:
I like basketball. it has the greatest shot of all time and the best shot of all time. i have to play a little bit more and get some practice time.
today i got the opportunity to play in a tournament in a city that is very close to my school so i was excited to see how it would go. i had just been playing with a few other guys, so i thought i would go and play a couple games with them. 
after a few games i was pretty confident and confident in myself. i had just gotten the opportunity and had to get some practice time. 
so i go to the
TOTAL TIME ELAPSED: 21.13s

采样方法介绍

在 KerasHub 中,我们提供了一些采样方法,例如对比搜索、Top-K 和 Beam 采样。默认情况下,我们的 GPT2CausalLM 使用 Top-k 搜索,但你可以选择自己的采样方法。

与优化器和激活函数类似,有两种方法可以指定你的自定义采样器

  • 使用字符串标识符,例如 "greedy",这样你就使用了默认配置。
  • 传递一个 keras_hub.samplers.Sampler 实例,这样你可以使用自定义配置。
# Use a string identifier.
gpt2_lm.compile(sampler="top_k")
output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)

# Use a `Sampler` instance. `GreedySampler` tends to repeat itself,
greedy_sampler = keras_hub.samplers.GreedySampler()
gpt2_lm.compile(sampler=greedy_sampler)

output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)
GPT-2 output:
I like basketball, and this is a pretty good one. 
first off, my wife is pretty good, she is a very good basketball player and she is really, really good at playing basketball. 
she has an amazing game called basketball, it is a pretty fun game. 
i play it on the couch.  i'm sitting there, watching the game on the couch.  my wife is playing with her phone.  she's playing on the phone with a bunch of people. 
my wife is sitting there and watching basketball.  she's sitting there watching
GPT-2 output:
I like basketball, but i don't like to play it. 
so i was playing basketball at my local high school, and i was playing with my friends. 
i was playing with my friends, and i was playing with my brother, who was playing basketball with his brother. 
so i was playing with my brother, and he was playing with his brother's brother. 
so i was playing with my brother, and he was playing with his brother's brother. 
so i was playing with my brother, and he was playing with his brother's brother. 
so i was playing with my brother, and he was playing with his brother's brother. 
so i was playing with my brother, and he was playing with his brother's brother. 
so i was playing with my brother, and he was playing with his brother

有关 KerasHub Sampler 类的更多详细信息,你可以在此处查看代码。


在中文诗歌数据集上进行微调

对于懂中文的读者,这部分演示了如何在中文诗歌数据集上微调 GPT2,让我们的模型成为一名诗人!

由于 GPT2 使用字节对编码器,并且原始预训练数据集包含一些汉字,我们可以使用原始词汇表在中文数据集上进行微调。

!# Load chinese poetry dataset.
!git clone https://github.com/chinese-poetry/chinese-poetry.git
Cloning into 'chinese-poetry'...

从 json 文件加载文本。我们仅使用《全唐诗》进行演示。

import os
import json

poem_collection = []
for file in os.listdir("chinese-poetry/全唐诗"):
    if ".json" not in file or "poet" not in file:
        continue
    full_filename = "%s/%s" % ("chinese-poetry/全唐诗", file)
    with open(full_filename, "r") as f:
        content = json.load(f)
        poem_collection.extend(content)

paragraphs = ["".join(data["paragraphs"]) for data in poem_collection]

让我们看一下样本数据。

print(paragraphs[0])
毋謂支山險,此山能幾何。崎嶔十年夢,知歷幾蹉跎。

与 Reddit 示例类似,我们转换为 TF 数据集,并且只使用部分数据进行训练。

train_ds = (
    tf.data.Dataset.from_tensor_slices(paragraphs)
    .batch(16)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

# Running through the whole dataset takes long, only take `500` and run 1
# epochs for demo purposes.
train_ds = train_ds.take(500)
num_epochs = 1

learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-4,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)
 500/500 ━━━━━━━━━━━━━━━━━━━━ 49s 71ms/step - accuracy: 0.2357 - loss: 2.8196

<keras.src.callbacks.history.History at 0x7f2b2c192bc0>

让我们看看结果!

output = gpt2_lm.generate("昨夜雨疏风骤", max_length=200)
print(output)
昨夜雨疏风骤,爲臨江山院短靜。石淡山陵長爲羣,臨石山非處臨羣。美陪河埃聲爲羣,漏漏漏邊陵塘

不错 😀