代码示例 / Keras 快速示例 / 使用 LoRA 和 QLoRA 对 Gemma 进行参数高效微调

使用 LoRA 和 QLoRA 对 Gemma 进行参数高效微调

作者: Hongyu Chiu, Abheesht Sharma, Matthew Watson
创建日期 2024/08/06
最后修改 2024/08/06
描述: 使用 KerasHub 通过 LoRA 和 QLoRA 微调 Gemma LLM。

ⓘ 本示例使用 Keras 3

在 Colab 中查看 GitHub 源代码


引言

大型语言模型(LLM)已被证明在各种 NLP 任务中表现出色。LLM 首先通过自监督方式在一个大型文本语料库上进行预训练。预训练帮助 LLM 学习通用知识,例如词语之间的统计关系。然后可以在感兴趣的下游任务(如情感分析)上对 LLM 进行微调。

然而,LLM 的规模非常大,我们在微调时不需要训练模型中的所有参数,特别是因为用于微调的数据集相对较小。另一种说法是,LLM 对于微调来说是过参数化的。低秩适应(LoRA)正是在此发挥作用;它显著减少了可训练参数的数量。这导致训练时间和 GPU 内存使用量减少,同时保持输出的质量。

此外,量化低秩适应(QLoRA)扩展了 LoRA,通过量化技术提高效率,而不会降低性能。

在本示例中,我们将使用 LoRA 和 QLoRA,在下一词元预测任务上微调 KerasHub 的 Gemma 模型。

请注意,本示例可在 Keras 支持的所有后端上运行。TensorFlow 仅用于数据预处理。


设置

在开始实现流水线之前,让我们安装并导入所需的所有库。我们将使用 KerasHub 库。

其次,我们将精度设置为 bfloat16。这将有助于我们减少内存使用量和训练时间。

另外,请确保已正确配置 KAGGLE_USERNAMEKAGGLE_KEY 以访问 Gemma 模型。

# We might need the latest code from Keras and KerasHub
!pip install -q git+https://github.com/keras-team/keras.git git+https://github.com/keras-team/keras-hub.git
import gc
import os

os.environ["KERAS_BACKEND"] = "jax"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # Suppress verbose logging from TF

# os.environ["KAGGLE_USERNAME"] = "..."
# os.environ["KAGGLE_KEY"] = "..."

import keras
import keras_hub
import tensorflow as tf
import tensorflow_datasets as tfds

keras.config.set_dtype_policy("bfloat16")

数据集

我们将使用 MTNT(噪声文本机器翻译)数据集,该数据集可从 TensorFlow Datasets 获取。在本示例中,我们将使用该数据集的法语到英语部分。

train_ds = tfds.load("mtnt/fr-en", split="train")

我们可以打印一些样本。数据集中的每个样本包含两个条目:

  • src:原始法语句子。
  • dst:对应的英语翻译。
examples = train_ds.take(3)
examples = examples.as_numpy_iterator()

for idx, example in enumerate(examples):
    print(f"Example {idx}:")
    for key, val in example.items():
        print(f"{key}: {val}")
    print()
Example 0:
dst: b'Yep, serious...'
src: b"Le journal l'est peut-\xc3\xaatre, mais m\xc3\xaame moi qui suit de droite je les trouve limite de temps en temps..."
Example 1:
dst: b'Finally, I explained to you in what context this copy-pasting is relevant: when we are told padamalgame etc.'
src: b"Enfin je t'ai expliqu\xc3\xa9 dans quel cadre ce copypasta est pertinent : quand on nous dit padamalgame etc."
Example 2:
dst: b'Gift of Ubiquity: Fran\xc3\xa7ois Baroin is now advisor to the Barclays Bank, mayor, president of the agglomeration, professor at HEC Paris, president of the Association of Mayors of France and Advocate Counselor, it must take him half a day each month.'
src: b"Don d'Ubiquit\xc3\xa9 : Fran\xc3\xa7ois Baroin est d\xc3\xa9sormais conseiller \xc3\xa0 la Banque Barclays, maire, pr\xc3\xa9sident d'agglom\xc3\xa9ration, professeur \xc3\xa0 HEC Paris, pr\xc3\xa9sident de l'association des maires de France et avocat  Conseiller, \xc3\xa7a doit lui prendre une demi journ\xc3\xa9e par mois."

由于我们将微调模型以执行法语到英语的翻译任务,因此我们应该对用于指令调整的输入进行格式化。例如,我们可以像这样格式化本示例中的翻译任务:

<start_of_turn>user
Translate French into English:
{src}<end_of_turn>
<start_of_turn>model
{dst}<end_of_turn>

user、model 和 等特殊词元用于 Gemma 模型。您可以从 https://ai.google.dev/gemma/docs/formatting 了解更多信息。

train_ds = train_ds.map(
    lambda x: tf.strings.join(
        [
            "<start_of_turn>user\n",
            "Translate French into English:\n",
            x["src"],
            "<end_of_turn>\n",
            "<start_of_turn>model\n",
            "Translation:\n",
            x["dst"],
            "<end_of_turn>",
        ]
    )
)
examples = train_ds.take(3)
examples = examples.as_numpy_iterator()

for idx, example in enumerate(examples):
    print(f"Example {idx}:")
    print(example)
    print()
Example 0:
b"<start_of_turn>user\nTranslate French into English:\nLe journal l'est peut-\xc3\xaatre, mais m\xc3\xaame moi qui suit de droite je les trouve limite de temps en temps...<end_of_turn>\n<start_of_turn>model\nTranslation:\nYep, serious...<end_of_turn>"
Example 1:
b"<start_of_turn>user\nTranslate French into English:\nEnfin je t'ai expliqu\xc3\xa9 dans quel cadre ce copypasta est pertinent : quand on nous dit padamalgame etc.<end_of_turn>\n<start_of_turn>model\nTranslation:\nFinally, I explained to you in what context this copy-pasting is relevant: when we are told padamalgame etc.<end_of_turn>"
Example 2:
b"<start_of_turn>user\nTranslate French into English:\nDon d'Ubiquit\xc3\xa9 : Fran\xc3\xa7ois Baroin est d\xc3\xa9sormais conseiller \xc3\xa0 la Banque Barclays, maire, pr\xc3\xa9sident d'agglom\xc3\xa9ration, professeur \xc3\xa0 HEC Paris, pr\xc3\xa9sident de l'association des maires de France et avocat  Conseiller, \xc3\xa7a doit lui prendre une demi journ\xc3\xa9e par mois.<end_of_turn>\n<start_of_turn>model\nTranslation:\nGift of Ubiquity: Fran\xc3\xa7ois Baroin is now advisor to the Barclays Bank, mayor, president of the agglomeration, professor at HEC Paris, president of the Association of Mayors of France and Advocate Counselor, it must take him half a day each month.<end_of_turn>"

出于本示例的目的,我们将使用数据集的一个子集。

train_ds = train_ds.batch(1).take(100)

模型

KerasHub 提供了许多流行模型架构的实现。在本示例中,我们将使用 GemmaCausalLM,这是一个用于因果语言建模的端到端 Gemma 模型。因果语言模型根据之前的词元预测下一个词元。

请注意,sequence_length 设置为 256 以加快拟合速度。

preprocessor = keras_hub.models.GemmaCausalLMPreprocessor.from_preset(
    "gemma_1.1_instruct_2b_en", sequence_length=256
)
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset(
    "gemma_1.1_instruct_2b_en", preprocessor=preprocessor
)
gemma_lm.summary()
Preprocessor: "gemma_causal_lm_preprocessor"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Tokenizer (type)                                                                                Vocab # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ gemma_tokenizer (GemmaTokenizer)                   │                                             256,000 │
└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘
Model: "gemma_causal_lm"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                   Output Shape                       Param #  Connected to               ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ gemma_backbone                │ (None, None, 2048)        │   2,506,172,416 │ padding_mask[0][0],        │
│ (GemmaBackbone)               │                           │                 │ token_ids[0][0]            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ token_embedding               │ (None, None, 256000)      │     524,288,000 │ gemma_backbone[0][0]       │
│ (ReversibleEmbedding)         │                           │                 │                            │
└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘
 Total params: 2,506,172,416 (4.67 GB)
 Trainable params: 2,506,172,416 (4.67 GB)
 Non-trainable params: 0 (0.00 B)

LoRA 微调

LoRA 到底是什么?

低秩适应(LoRA)是一种用于 LLM 的参数高效微调技术。它冻结 LLM 的权重,并注入可训练的秩分解矩阵。让我们更清楚地理解这一点。

假设我们有一个 n x n 的预训练全连接层(或权重矩阵)W0。我们初始化两个全连接层 A 和 B,它们的形状分别为 n x rank 和 rank x n。rank 远小于 n。在论文中,rank 值在 1 到 4 之间被证明效果良好。

LoRA 方程

原始方程是 output = W0x + b0,其中 x 是输入,W0 和 b0 是原始全连接层(已冻结)的权重矩阵和偏置项。LoRA 方程是:output = W0x + b0 + BAx,其中 A 和 B 是秩分解矩阵。

LoRA 基于这样一个思想:由于预训练语言模型是过参数化的,对预训练语言模型权重的更新具有较低的“内在秩”。即使将 W0 的更新约束在低秩分解矩阵内,也可以复制完整微调的预测性能。

可训练参数数量

让我们快速计算一下。假设 n 是 768,rank 是 4。W0 有 768 x 768 = 589,824 个参数,而 LoRA 层 A 和 B 加起来有 768 x 4 + 4 x 768 = 6,144 个参数。因此,对于全连接层,可训练参数从 589,824 个减少到 6,144 个!

LoRA 为什么会减少内存占用?

即使总参数数量增加(因为我们添加了 LoRA 层),内存占用也会减少,因为可训练参数的数量减少了。让我们深入探讨这一点。

模型的内存使用量可以分为四个部分:

  • 模型内存:这是存储模型权重所需的内存。对于 LoRA,这会比原始模型略高。
  • 前向传播内存:这主要取决于批次大小、序列长度等。为了公平比较,我们对两个模型都保持不变。
  • 反向传播内存:这是存储梯度所需的内存。请注意,梯度仅针对可训练参数计算。
  • 优化器内存:这是存储优化器状态所需的内存。例如,Adam 优化器会存储可训练参数的“一阶矩向量”和“二阶矩向量”。

由于 LoRA 极大地减少了可训练参数的数量,LoRA 的优化器内存以及存储梯度所需的内存远少于原始模型。这是大部分内存节省的发生之处。

  • 减少 GPU 内存使用;
  • 更快的训练;以及
  • 没有额外的推理延迟。

使用 KerasHub 时,我们可以通过一行代码 API 启用 LoRA:enable_lora(rank=4)

gemma_lm.summary() 中,我们可以看到启用 LoRA 显著减少了可训练参数的数量(从 25 亿减少到 130 万)。

gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
Preprocessor: "gemma_causal_lm_preprocessor"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Tokenizer (type)                                                                                Vocab # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ gemma_tokenizer (GemmaTokenizer)                   │                                             256,000 │
└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘
Model: "gemma_causal_lm"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                   Output Shape                       Param #  Connected to               ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ gemma_backbone                │ (None, None, 2048)        │   2,507,536,384 │ padding_mask[0][0],        │
│ (GemmaBackbone)               │                           │                 │ token_ids[0][0]            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ token_embedding               │ (None, None, 256000)      │     524,288,000 │ gemma_backbone[0][0]       │
│ (ReversibleEmbedding)         │                           │                 │                            │
└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘
 Total params: 2,507,536,384 (4.67 GB)
 Trainable params: 1,363,968 (2.60 MB)
 Non-trainable params: 2,506,172,416 (4.67 GB)

让我们微调 LoRA 模型。

# To save memory, use the SGD optimizer instead of the usual AdamW optimizer.
# For this specific example, SGD is more than enough.
optimizer = keras.optimizers.SGD(learning_rate=1e-4)
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(train_ds, epochs=1)

微调后,响应将遵循提示中提供的指令。

template = (
    "<start_of_turn>user\n"
    "Translate French into English:\n"
    "{inputs}"
    "<end_of_turn>\n"
    "<start_of_turn>model\n"
    "Translation:\n"
)
prompt = template.format(inputs="Bonjour, je m'appelle Morgane.")
outputs = gemma_lm.generate(prompt, max_length=256)
print("Translation:\n", outputs.replace(prompt, ""))
Translation:
 Hello, my name is Morgane.

释放内存。

del preprocessor
del gemma_lm
del optimizer
gc.collect()

QLoRA 微调

量化低秩适应(QLoRA)扩展了 LoRA,通过将模型权重从 float32 等高精度数据类型量化到 int8 等低精度数据类型来提高效率。这会减少内存使用并加快计算速度。保存的模型权重也小得多。

请注意,这里的 QLoRA 实现与原始版本相比是简化版本。区别在于:

  • 未使用 4 位 NormalFloat 格式,因为没有后端支持。
  • 没有二次量化。
  • 没有分页优化器。

要在 KerasHub 中启用 QLoRA,请按照以下步骤操作:

  1. 实例化模型。
  2. 使用动态 int8 量化对权重进行量化。
  3. 启用 LoRA。

步骤 2 和 3 通过一行代码 API 实现:

  • quantize("int8")
  • enable_lora(...)
preprocessor = keras_hub.models.GemmaCausalLMPreprocessor.from_preset(
    "gemma_1.1_instruct_2b_en", sequence_length=256
)
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset(
    "gemma_1.1_instruct_2b_en", preprocessor=preprocessor
)
gemma_lm.quantize("int8")
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
Preprocessor: "gemma_causal_lm_preprocessor_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Tokenizer (type)                                                                                Vocab # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ gemma_tokenizer (GemmaTokenizer)                   │                                             256,000 │
└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘
Model: "gemma_causal_lm_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                   Output Shape                       Param #  Connected to               ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ padding_mask (InputLayer)     │ (None, None)              │               0 │ -                          │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ token_ids (InputLayer)        │ (None, None)              │               0 │ -                          │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ gemma_backbone                │ (None, None, 2048)        │   2,508,502,016 │ padding_mask[0][0],        │
│ (GemmaBackbone)               │                           │                 │ token_ids[0][0]            │
├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤
│ token_embedding               │ (None, None, 256000)      │     524,544,000 │ gemma_backbone[0][0]       │
│ (ReversibleEmbedding)         │                           │                 │                            │
└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘
 Total params: 2,508,502,016 (2.34 GB)
 Trainable params: 1,363,968 (2.60 MB)
 Non-trainable params: 2,507,138,048 (2.34 GB)

让我们微调 QLoRA 模型。

如果您使用的设备支持 int8 加速,您应该会看到训练速度有所提高。

optimizer = keras.optimizers.SGD(learning_rate=1e-4)
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(train_ds, epochs=1)

使用 QLoRA 微调后,您应该会获得类似的输出。

prompt = template.format(inputs="Bonjour, je m'appelle Morgane.")
outputs = gemma_lm.generate(prompt, max_length=256)
print("Translation:\n", outputs.replace(prompt, ""))
Translation:
 Hello, my name is Morgane.

我们完成了!

请注意,出于演示目的,本示例仅在数据集的一个小子集上对模型进行了一次 epoch 的微调,并且使用了较低的 LoRA rank 值。为了从微调后的模型获得更好的响应,您可以尝试以下方法:

  • 增加微调数据集的大小。
  • 进行更多步骤(epoch)的训练。
  • 设置更高的 LoRA rank。
  • 修改超参数值,例如 learning_rateweight_decay