
GPT2 Text Generation with KerasHub

Author: Chen Qian
Date created: 2023/04/17
Last modified: 2024/04/12
Description: Use KerasHub GPT2 model and samplers to do text generation.

ⓘ This example uses Keras 3

View in Colab • GitHub source

In this tutorial, you will learn to use KerasHub to load a pre-trained Large Language Model (LLM) - the GPT-2 model (originally invented by OpenAI), finetune it to a specific text style, and generate text based on users' input (also known as a prompt). You will also learn how GPT2 adapts quickly to non-English languages, such as Chinese.


Before we begin

Colab offers different kinds of runtimes. Make sure to go to Runtime -> Change runtime type and choose the GPU Hardware Accelerator runtime (which should have >12G host RAM and ~15G GPU RAM), since you will finetune the GPT-2 model. Running this tutorial on a CPU runtime will take hours.


Install KerasHub, Choose Backend and Import Dependencies

This example uses Keras 3 to work in any of "tensorflow", "jax" or "torch". Support for Keras 3 is baked into KerasHub; simply change the "KERAS_BACKEND" environment variable to select the backend of your choice. We select the JAX backend below.

!pip install git+https://github.com/keras-team/keras-hub.git -q
import os

os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" or "torch"

import keras_hub
import keras
import tensorflow as tf
import time

keras.mixed_precision.set_global_policy("mixed_float16")


Introduction to Generative Large Language Models (LLMs)

Large Language Models (LLMs) are a type of machine learning model trained on a large corpus of text data to generate outputs for various Natural Language Processing (NLP) tasks, such as text generation, question answering, and machine translation.

Generative LLMs are typically based on deep learning neural networks, such as the Transformer architecture invented by Google researchers in 2017, and are trained on massive amounts of text data, often involving billions of words. These models, such as Google's LaMDA and PaLM, are trained with large datasets from various data sources, which allows them to generate output for many tasks. The core of a generative LLM is predicting the next word in a sentence, often referred to as **Causal Language Modeling pretraining**. In this way LLMs can generate coherent text based on user prompts. For a more pedagogical discussion on language models, you can refer to the Stanford CS324 LLM class.
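
As a tiny illustration of what causal language modeling means in practice, the sketch below (not part of the original walkthrough; the token ids are made-up toy values) shows how inputs and labels are derived from the same sequence by shifting it one position:

# Illustrative sketch of causal language modeling data: the label at each
# position is simply the next token of the same sequence.
# The token ids below are made-up toy values.
tokens = [464, 3807, 373, 845, 922]

inputs = tokens[:-1]  # what the model sees
labels = tokens[1:]  # what it learns to predict

for x, y in zip(inputs, labels):
    print(f"given token {x}, predict token {y}")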


Introduction to KerasHub

Large Language Models are complex to build and expensive to train from scratch. Luckily there are pretrained LLMs available for use right away. KerasHub provides a large number of pre-trained checkpoints that allow you to experiment with state-of-the-art models without needing to train them yourself.

KerasHub is a natural language processing library that supports users through their entire development cycle. KerasHub offers both pretrained models and modularized building blocks, so developers can easily reuse pretrained models or stack their own LLM.

In a nutshell, for generative LLMs, KerasHub offers:

  • Pretrained models with a generate() method, e.g., keras_hub.models.GPT2CausalLM.
  • Sampler classes that implement generation algorithms such as Top-K, Beam and contrastive search; these samplers can be used to generate text with custom models.


Load a pre-trained GPT-2 model and generate some text

KerasHub provides a number of pre-trained models, such as Google Bert and GPT-2. You can see the list of models available in the KerasHub repository.

Loading the GPT-2 model is as simple as shown below:

# To speed up training and generation, we use preprocessor of length 128
# instead of full length 1024.
preprocessor = keras_hub.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en",
    sequence_length=128,
)
gpt2_lm = keras_hub.models.GPT2CausalLM.from_preset(
    "gpt2_base_en", preprocessor=preprocessor
)

Once the model is loaded, you can use it to generate some text right away. Run the cells below to give it a try. It's as simple as calling a single function generate():

start = time.time()

output = gpt2_lm.generate("My trip to Yosemite was", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
GPT-2 output:
My trip to Yosemite was pretty awesome. The first time I went I didn't know how to go and it was pretty hard to get around. It was a bit like going on an adventure with a friend. The only things I could do were hike and climb the mountain. It's really cool to know you're not alone in this world. It's a lot of fun. I'm a little worried that I might not get to the top of the mountain in time to see the sunrise and sunset of the day. I think the weather is going to get a little warmer in the coming years.
This post is a little more in-depth on how to go on the trail. It covers how to hike on the Sierra Nevada, how to hike with the Sierra Nevada, how to hike in the Sierra Nevada, how to get to the top of the mountain, and how to get to the top with your own gear.
The Sierra Nevada is a very popular trail in Yosemite
TOTAL TIME ELAPSED: 25.36s

Try another one:

start = time.time()

output = gpt2_lm.generate("That Italian restaurant is", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
GPT-2 output:
That Italian restaurant is known for its delicious food, and the best part is that it has a full bar, with seating for a whole host of guests. And that's only because it's located at the heart of the neighborhood.
The menu at the Italian restaurant is pretty straightforward:
The menu consists of three main dishes:
Italian sausage
Bolognese
Sausage
Bolognese with cheese
Sauce with cream
Italian sausage with cheese
Bolognese with cheese
And the main menu consists of a few other things.
There are two tables: the one that serves a menu of sausage and bolognese with cheese (the one that serves the menu of sausage and bolognese with cheese) and the one that serves the menu of sausage and bolognese with cheese. The two tables are also open 24 hours a day, 7 days a week.
TOTAL TIME ELAPSED: 1.55s

Notice how much faster the second call is. This is because the computational graph is XLA-compiled in the first run and re-used in the second behind the scenes.

The quality of the generated text looks OK, but we can improve it via fine-tuning.


More on the GPT-2 model from KerasHub

Next up, we will actually fine-tune the model to update its parameters, but before we do, let's take a look at the full set of tools we have for working with GPT2.

The code of GPT2 can be found here. Conceptually, GPT2CausalLM can be hierarchically broken down into several modules in KerasHub, all of which have a from_preset() function that loads a pretrained model:

  • keras_hub.models.GPT2Tokenizer: the tokenizer used by the GPT2 model, which is a byte-pair encoder.
  • keras_hub.models.GPT2CausalLMPreprocessor: the preprocessor used for GPT2 causal LM training. It does the tokenization along with other preprocessing work, such as creating the labels and appending the end token.
  • keras_hub.models.GPT2Backbone: the GPT2 model, which is a stack of keras_hub.layers.TransformerDecoder layers. This is usually just referred to as GPT2.
  • keras_hub.models.GPT2CausalLM: wraps GPT2Backbone, multiplying its output by the embedding matrix to produce logits over the vocabulary tokens.
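
As a quick, purely illustrative check of this breakdown (assuming the same "gpt2_base_en" preset used above), you can load some of these modules on their own via from_preset():

# Illustrative sketch: load sub-modules of GPT2CausalLM individually.
tokenizer = keras_hub.models.GPT2Tokenizer.from_preset("gpt2_base_en")
backbone = keras_hub.models.GPT2Backbone.from_preset("gpt2_base_en")

print(tokenizer("My trip to Yosemite was"))  # token ids from the byte-pair encoder
print(backbone.count_params())  # parameter count of the GPT-2 backbone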


Finetune on Reddit dataset

Now that you have the knowledge of the GPT-2 model from KerasHub, you can take one step further to finetune the model so that it generates text in a specific style, short or long, strict or casual. In this tutorial, we will use the Reddit dataset as an example.

import tensorflow_datasets as tfds

reddit_ds = tfds.load("reddit_tifu", split="train", as_supervised=True)

Let's take a look inside the sample data from the Reddit TensorFlow Dataset. There are two features:

  • document: text of the post.
  • title: title of the post.
for document, title in reddit_ds:
    print(document.numpy())
    print(title.numpy())
    break
b"me and a friend decided to go to the beach last sunday. we loaded up and headed out. we were about half way there when i decided that i was not leaving till i had seafood. \n\nnow i'm not talking about red lobster. no friends i'm talking about a low country boil. i found the restaurant and got directions. i don't know if any of you have heard about the crab shack on tybee island but let me tell you it's worth it. \n\nwe arrived and was seated quickly. we decided to get a seafood sampler for two and split it. the waitress bought it out on separate platters for us. the amount of food was staggering. two types of crab, shrimp, mussels, crawfish, andouille sausage, red potatoes, and corn on the cob. i managed to finish it and some of my friends crawfish and mussels. it was a day to be a fat ass. we finished paid for our food and headed to the beach. \n\nfunny thing about seafood. it runs through me faster than a kenyan \n\nwe arrived and walked around a bit. it was about 45min since we arrived at the beach when i felt a rumble from the depths of my stomach. i ignored it i didn't want my stomach to ruin our fun. i pushed down the feeling and continued. about 15min later the feeling was back and stronger than before. again i ignored it and continued. 5min later it felt like a nuclear reactor had just exploded in my stomach. i started running. i yelled to my friend to hurry the fuck up. \n\nrunning in sand is extremely hard if you did not know this. we got in his car and i yelled at him to floor it. my stomach was screaming and if he didn't hurry i was gonna have this baby in his car and it wasn't gonna be pretty. after a few red lights and me screaming like a woman in labor we made it to the store. \n\ni practically tore his car door open and ran inside. i ran to the bathroom opened the door and barely got my pants down before the dam burst and a flood of shit poured from my ass. \n\ni finished up when i felt something wet on my ass. i rubbed it thinking it was back splash. no, mass was covered in the after math of me abusing the toilet. i grabbed all the paper towels i could and gave my self a whores bath right there. \n\ni sprayed the bathroom down with the air freshener and left. an elderly lady walked in quickly and closed the door. i was just about to walk away when i heard gag. instead of walking i ran. i got to the car and told him to get the hell out of there."
b'liking seafood'

In our case, we are performing next-word prediction in a language model, so we only need the 'document' feature.

train_ds = (
    reddit_ds.map(lambda document, _: document)
    .batch(32)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

Now you can finetune the model using the familiar fit() function. Note that preprocessor will be automatically called inside fit because GPT2CausalLM is a keras_hub.models.Task instance.

This step takes quite a bit of GPU memory and a long time if we were to train it all the way to a fully trained state. Here we just use part of the dataset for demo purposes.

train_ds = train_ds.take(500)
num_epochs = 1

# Linearly decaying learning rate.
learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)
 500/500 ━━━━━━━━━━━━━━━━━━━━ 75s 120ms/step - accuracy: 0.3189 - loss: 3.3653

<keras.src.callbacks.history.History at 0x7f2af3fda410>

After fine-tuning is done, you can again generate text using the same generate() function. This time, the text will be closer to the Reddit writing style, and the generated length will be close to our preset length in the training set.

start = time.time()

output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)

end = time.time()
print(f"TOTAL TIME ELAPSED: {end - start:.2f}s")
GPT-2 output:
I like basketball. it has the greatest shot of all time and the best shot of all time. i have to play a little bit more and get some practice time.
today i got the opportunity to play in a tournament in a city that is very close to my school so i was excited to see how it would go. i had just been playing with a few other guys, so i thought i would go and play a couple games with them. 
after a few games i was pretty confident and confident in myself. i had just gotten the opportunity and had to get some practice time. 
so i go to the
TOTAL TIME ELAPSED: 21.13s

Into the Sampling Method

In KerasHub, we offer a few sampling methods, e.g., contrastive search, Top-K and beam sampling. By default, our GPT2CausalLM uses Top-K search, but you can choose your own sampling method.

Much like optimizers and activations, there are two ways to specify your custom sampler:

  • Use a string identifier, such as "greedy"; you use the default configuration this way.
  • Pass a keras_hub.samplers.Sampler instance; you can use a custom configuration this way (a sketch with a custom-configured sampler follows at the end of this section).
# Use a string identifier.
gpt2_lm.compile(sampler="top_k")
output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)

# Use a `Sampler` instance. `GreedySampler` tends to repeat itself.
greedy_sampler = keras_hub.samplers.GreedySampler()
gpt2_lm.compile(sampler=greedy_sampler)

output = gpt2_lm.generate("I like basketball", max_length=200)
print("\nGPT-2 output:")
print(output)
GPT-2 output:
I like basketball, and this is a pretty good one. 
first off, my wife is pretty good, she is a very good basketball player and she is really, really good at playing basketball. 
she has an amazing game called basketball, it is a pretty fun game. 
i play it on the couch.  i'm sitting there, watching the game on the couch.  my wife is playing with her phone.  she's playing on the phone with a bunch of people. 
my wife is sitting there and watching basketball.  she's sitting there watching
GPT-2 output:
I like basketball, but i don't like to play it. 
so i was playing basketball at my local high school, and i was playing with my friends. 
i was playing with my friends, and i was playing with my brother, who was playing basketball with his brother. 
so i was playing with my brother, and he was playing with his brother's brother. 
so i was playing with my brother, and he was playing with his brother's brother. 
so i was playing with my brother, and he was playing with his brother's brother. 
so i was playing with my brother, and he was playing with his brother's brother. 
so i was playing with my brother, and he was playing with his brother's brother. 
so i was playing with my brother, and he was playing with his brother

For more details on the KerasHub Sampler class, you can check the code here.
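
If you want non-default settings, you can also construct a sampler yourself before compiling. Below is a small illustrative sketch (the value k=10 is an arbitrary choice) that passes a custom-configured keras_hub.samplers.TopKSampler instance to compile():

# Illustrative sketch: pass a `Sampler` instance with custom configuration.
# k=10 is an arbitrary choice for demonstration.
custom_top_k_sampler = keras_hub.samplers.TopKSampler(k=10)
gpt2_lm.compile(sampler=custom_top_k_sampler)

output = gpt2_lm.generate("I like basketball", max_length=200)
print(output)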


Finetune on a Chinese Poem Dataset

We can also finetune GPT2 on non-English datasets. For readers knowing Chinese, this part illustrates how to fine-tune GPT2 on a Chinese poem dataset to teach our model to become a poet!

Because GPT2 uses a byte-pair encoder, and the original pretraining dataset contains some Chinese characters, we can use the original vocabulary to finetune on a Chinese dataset.
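
As a quick illustrative sanity check (not in the original walkthrough), you can run a Chinese string through the preprocessor's byte-pair tokenizer and confirm it round-trips through token ids:

# Illustrative sketch: the GPT-2 byte-pair tokenizer can encode Chinese text
# into token ids and decode it back, so the original vocabulary is usable.
sample = "昨夜雨疏风骤"
token_ids = preprocessor.tokenizer(sample)
print(token_ids)
print(preprocessor.tokenizer.detokenize(token_ids))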

!# Load chinese poetry dataset.
!git clone https://github.com/chinese-poetry/chinese-poetry.git
Cloning into 'chinese-poetry'...

Load text from the JSON files. We only use《全唐诗》(Complete Tang Poems) for demo purposes.

import os
import json

poem_collection = []
for file in os.listdir("chinese-poetry/全唐诗"):
    if ".json" not in file or "poet" not in file:
        continue
    full_filename = "%s/%s" % ("chinese-poetry/全唐诗", file)
    with open(full_filename, "r") as f:
        content = json.load(f)
        poem_collection.extend(content)

paragraphs = ["".join(data["paragraphs"]) for data in poem_collection]

Let's take a look at the sample data.

print(paragraphs[0])
毋謂支山險,此山能幾何。崎嶔十年夢,知歷幾蹉跎。

Similar to the Reddit example, we convert it to a TF dataset and only use partial data to train.

train_ds = (
    tf.data.Dataset.from_tensor_slices(paragraphs)
    .batch(16)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

# Running through the whole dataset takes a long time, so we only take `500`
# batches and run 1 epoch for demo purposes.
train_ds = train_ds.take(500)
num_epochs = 1

learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-4,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)
 500/500 ━━━━━━━━━━━━━━━━━━━━ 49s 71ms/step - accuracy: 0.2357 - loss: 2.8196

<keras.src.callbacks.history.History at 0x7f2b2c192bc0>

Let's check the result!

output = gpt2_lm.generate("昨夜雨疏风骤", max_length=200)
print(output)
昨夜雨疏风骤,爲臨江山院短靜。石淡山陵長爲羣,臨石山非處臨羣。美陪河埃聲爲羣,漏漏漏邊陵塘

Not bad 😀