代码示例 / 生成式深度学习 / 微调 Stable Diffusion

微调 Stable Diffusion

作者: Sayak Paul, Chansung Park
创建日期 2022/12/28
最后修改 2023/01/13
描述: 使用自定义图像-字幕数据集微调 Stable Diffusion。

ⓘ 此示例使用 Keras 2

在 Colab 中查看 GitHub 源代码


简介

本教程演示了如何在自定义的 {图像,字幕} 对数据集上微调 Stable Diffusion 模型。 我们基于 Hugging Face 此处提供的微调脚本构建。

我们假设您对 Stable Diffusion 模型有较高的理解。 如果您正在寻找更多相关信息,以下资源可能会有所帮助

强烈建议您使用至少 30GB 内存的 GPU 来执行代码。

在本指南结束时,您将能够生成有趣的宝可梦图像

custom-pokemons

本教程依赖于 KerasCV 0.4.0。 此外,我们需要至少 TensorFlow 2.11 才能将 AdamW 与混合精度一起使用。

!pip install keras-cv==0.6.0 -q
!pip install -U tensorflow -q
!pip install keras-core -q

我们正在微调什么?

Stable Diffusion 模型可以分解为几个关键模型

  • 一个文本编码器,将输入的提示投影到潜在空间。(与图像关联的标题被称为“提示”。)
  • 一个变分自编码器 (VAE),将输入图像投影到用作图像向量空间的潜在空间。
  • 一个扩散模型,该模型在编码的文本提示的条件下,细化潜在向量并生成另一个潜在向量
  • 一个解码器,用于生成给定来自扩散模型的潜在向量的图像。

值得注意的是,在从文本提示生成图像的过程中,通常不使用图像编码器。

但是,在微调过程中,工作流程如下所示

  1. 输入文本提示通过文本编码器投影到潜在空间。
  2. 输入图像通过 VAE 的图像编码器部分投影到潜在空间。
  3. 在给定的时间步长,少量噪声被添加到图像潜在向量中。
  4. 扩散模型使用来自这两个空间的潜在向量以及时间步长嵌入来预测添加到图像潜在中的噪声。
  5. 在预测的噪声与步骤 3 中添加的原始噪声之间计算重建损失。
  6. 最后,使用梯度下降相对于该损失优化扩散模型参数。

请注意,在微调期间仅更新扩散模型参数,而(预训练的)文本和图像编码器则保持冻结。

如果这听起来很复杂,请不要担心。 代码比这简单得多!


导入

from textwrap import wrap
import os

import keras_cv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer
from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel
from keras_cv.models.stable_diffusion.image_encoder import ImageEncoder
from keras_cv.models.stable_diffusion.noise_scheduler import NoiseScheduler
from keras_cv.models.stable_diffusion.text_encoder import TextEncoder
from tensorflow import keras

数据加载

我们使用数据集 Pokémon BLIP 字幕。 但是,我们将使用一个略有不同的版本,该版本是从原始数据集派生出来的,以便更好地与 tf.data 配合使用。 有关更多详细信息,请参阅 文档

data_path = tf.keras.utils.get_file(
    origin="https://hugging-face.cn/datasets/sayakpaul/pokemon-blip-original-version/resolve/main/pokemon_dataset.tar.gz",
    untar=True,
)

data_frame = pd.read_csv(os.path.join(data_path, "data.csv"))

data_frame["image_path"] = data_frame["image_path"].apply(
    lambda x: os.path.join(data_path, x)
)
data_frame.head()
image_path 字幕
0 /home/jupyter/.keras/datasets/pokemon_dataset/... 一个画着绿色宝可梦的图画,它有红眼睛
1 /home/jupyter/.keras/datasets/pokemon_dataset/... 一个绿色和黄色的玩具,上面有一个红鼻子
2 /home/jupyter/.keras/datasets/pokemon_dataset/... 一个红色和白色的球,上面画着愤怒的表情...
3 /home/jupyter/.keras/datasets/pokemon_dataset/... 一个卡通球,脸上带着微笑
4 /home/jupyter/.keras/datasets/pokemon_dataset/... 一堆脸上画着图案的球

由于我们只有 833 个 {图像,字幕} 对,因此我们可以从字幕中预先计算文本嵌入。 此外,文本编码器在微调过程中将保持冻结,因此我们可以通过这样做来节省一些计算量。

在使用文本编码器之前,我们需要对字幕进行标记化。

# The padding token and maximum prompt length are specific to the text encoder.
# If you're using a different text encoder be sure to change them accordingly.
PADDING_TOKEN = 49407
MAX_PROMPT_LENGTH = 77

# Load the tokenizer.
tokenizer = SimpleTokenizer()

#  Method to tokenize and pad the tokens.
def process_text(caption):
    tokens = tokenizer.encode(caption)
    tokens = tokens + [PADDING_TOKEN] * (MAX_PROMPT_LENGTH - len(tokens))
    return np.array(tokens)


# Collate the tokenized captions into an array.
tokenized_texts = np.empty((len(data_frame), MAX_PROMPT_LENGTH))

all_captions = list(data_frame["caption"].values)
for i, caption in enumerate(all_captions):
    tokenized_texts[i] = process_text(caption)

准备一个 tf.data.Dataset

在本节中,我们将从输入图像文件路径及其对应的字幕标记准备一个 tf.data.Dataset 对象。 该部分将包括以下内容

  • 从标记化的字幕预先计算文本嵌入。
  • 加载和增强输入图像。
  • 数据集的随机排序和批处理。
RESOLUTION = 256
AUTO = tf.data.AUTOTUNE
POS_IDS = tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)

augmenter = keras.Sequential(
    layers=[
        keras_cv.layers.CenterCrop(RESOLUTION, RESOLUTION),
        keras_cv.layers.RandomFlip(),
        tf.keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1),
    ]
)
text_encoder = TextEncoder(MAX_PROMPT_LENGTH)


def process_image(image_path, tokenized_text):
    image = tf.io.read_file(image_path)
    image = tf.io.decode_png(image, 3)
    image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
    return image, tokenized_text


def apply_augmentation(image_batch, token_batch):
    return augmenter(image_batch), token_batch


def run_text_encoder(image_batch, token_batch):
    return (
        image_batch,
        token_batch,
        text_encoder([token_batch, POS_IDS], training=False),
    )


def prepare_dict(image_batch, token_batch, encoded_text_batch):
    return {
        "images": image_batch,
        "tokens": token_batch,
        "encoded_text": encoded_text_batch,
    }


def prepare_dataset(image_paths, tokenized_texts, batch_size=1):
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, tokenized_texts))
    dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.map(process_image, num_parallel_calls=AUTO).batch(batch_size)
    dataset = dataset.map(apply_augmentation, num_parallel_calls=AUTO)
    dataset = dataset.map(run_text_encoder, num_parallel_calls=AUTO)
    dataset = dataset.map(prepare_dict, num_parallel_calls=AUTO)
    return dataset.prefetch(AUTO)

基线 Stable Diffusion 模型是使用 512x512 分辨率的图像进行训练的。 使用较高分辨率图像训练的模型不太可能很好地转移到较低分辨率的图像。 但是,如果我们将分辨率保持为 512x512(不启用混合精度),则当前模型将导致 OOM。 因此,为了进行交互式演示,我们将输入分辨率保持为 256x256。

# Prepare the dataset.
training_dataset = prepare_dataset(
    np.array(data_frame["image_path"]), tokenized_texts, batch_size=4
)

# Take a sample batch and investigate.
sample_batch = next(iter(training_dataset))

for k in sample_batch:
    print(k, sample_batch[k].shape)
images (4, 256, 256, 3)
tokens (4, 77)
encoded_text (4, 77, 768)

我们还可以看一下训练图像及其对应的字幕。

plt.figure(figsize=(20, 10))

for i in range(3):
    ax = plt.subplot(1, 4, i + 1)
    plt.imshow((sample_batch["images"][i] + 1) / 2)

    text = tokenizer.decode(sample_batch["tokens"][i].numpy().squeeze())
    text = text.replace("<|startoftext|>", "")
    text = text.replace("<|endoftext|>", "")
    text = "\n".join(wrap(text, 12))
    plt.title(text, fontsize=15)

    plt.axis("off")

png


微调循环的训练器类

class Trainer(tf.keras.Model):
    # Reference:
    # https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py

    def __init__(
        self,
        diffusion_model,
        vae,
        noise_scheduler,
        use_mixed_precision=False,
        max_grad_norm=1.0,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.diffusion_model = diffusion_model
        self.vae = vae
        self.noise_scheduler = noise_scheduler
        self.max_grad_norm = max_grad_norm

        self.use_mixed_precision = use_mixed_precision
        self.vae.trainable = False

    def train_step(self, inputs):
        images = inputs["images"]
        encoded_text = inputs["encoded_text"]
        batch_size = tf.shape(images)[0]

        with tf.GradientTape() as tape:
            # Project image into the latent space and sample from it.
            latents = self.sample_from_encoder_outputs(self.vae(images, training=False))
            # Know more about the magic number here:
            # https://keras.org.cn/examples/generative/fine_tune_via_textual_inversion/
            latents = latents * 0.18215

            # Sample noise that we'll add to the latents.
            noise = tf.random.normal(tf.shape(latents))

            # Sample a random timestep for each image.
            timesteps = tnp.random.randint(
                0, self.noise_scheduler.train_timesteps, (batch_size,)
            )

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process).
            noisy_latents = self.noise_scheduler.add_noise(
                tf.cast(latents, noise.dtype), noise, timesteps
            )

            # Get the target for loss depending on the prediction type
            # just the sampled noise for now.
            target = noise  # noise_schedule.predict_epsilon == True

            # Predict the noise residual and compute loss.
            timestep_embedding = tf.map_fn(
                lambda t: self.get_timestep_embedding(t), timesteps, dtype=tf.float32
            )
            timestep_embedding = tf.squeeze(timestep_embedding, 1)
            model_pred = self.diffusion_model(
                [noisy_latents, timestep_embedding, encoded_text], training=True
            )
            loss = self.compiled_loss(target, model_pred)
            if self.use_mixed_precision:
                loss = self.optimizer.get_scaled_loss(loss)

        # Update parameters of the diffusion model.
        trainable_vars = self.diffusion_model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        if self.use_mixed_precision:
            gradients = self.optimizer.get_unscaled_gradients(gradients)
        gradients = [tf.clip_by_norm(g, self.max_grad_norm) for g in gradients]
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        return {m.name: m.result() for m in self.metrics}

    def get_timestep_embedding(self, timestep, dim=320, max_period=10000):
        half = dim // 2
        log_max_period = tf.math.log(tf.cast(max_period, tf.float32))
        freqs = tf.math.exp(
            -log_max_period * tf.range(0, half, dtype=tf.float32) / half
        )
        args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
        embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
        embedding = tf.reshape(embedding, [1, -1])
        return embedding

    def sample_from_encoder_outputs(self, outputs):
        mean, logvar = tf.split(outputs, 2, axis=-1)
        logvar = tf.clip_by_value(logvar, -30.0, 20.0)
        std = tf.exp(0.5 * logvar)
        sample = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
        return mean + std * sample

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        # Overriding this method will allow us to use the `ModelCheckpoint`
        # callback directly with this trainer class. In this case, it will
        # only checkpoint the `diffusion_model` since that's what we're training
        # during fine-tuning.
        self.diffusion_model.save_weights(
            filepath=filepath,
            overwrite=overwrite,
            save_format=save_format,
            options=options,
        )

此处需要注意的一个重要实现细节:我们不是直接获取图像编码器(即 VAE)产生的潜在向量,而是从其预测的均值和对数方差中采样。 这样,我们可以获得更好的样本质量和多样性。

通常会添加对混合精度训练的支持以及模型权重的指数移动平均值来微调这些模型。 但是,为了简洁起见,我们放弃了这些元素。 稍后在本教程中对此进行更多介绍。


初始化训练器并编译它

# Enable mixed-precision training if the underlying GPU has tensor cores.
USE_MP = True
if USE_MP:
    keras.mixed_precision.set_global_policy("mixed_float16")

image_encoder = ImageEncoder()
diffusion_ft_trainer = Trainer(
    diffusion_model=DiffusionModel(RESOLUTION, RESOLUTION, MAX_PROMPT_LENGTH),
    # Remove the top layer from the encoder, which cuts off the variance and only
    # returns the mean.
    vae=tf.keras.Model(
        image_encoder.input,
        image_encoder.layers[-2].output,
    ),
    noise_scheduler=NoiseScheduler(),
    use_mixed_precision=USE_MP,
)

# These hyperparameters come from this tutorial by Hugging Face:
# https://hugging-face.cn/docs/diffusers/training/text2image
lr = 1e-5
beta_1, beta_2 = 0.9, 0.999
weight_decay = (1e-2,)
epsilon = 1e-08

optimizer = tf.keras.optimizers.experimental.AdamW(
    learning_rate=lr,
    weight_decay=weight_decay,
    beta_1=beta_1,
    beta_2=beta_2,
    epsilon=epsilon,
)
diffusion_ft_trainer.compile(optimizer=optimizer, loss="mse")

微调

为了使本教程的运行时保持较短,我们仅微调一个 epoch。

epochs = 1
ckpt_path = "finetuned_stable_diffusion.h5"
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
    ckpt_path,
    save_weights_only=True,
    monitor="loss",
    mode="min",
)
diffusion_ft_trainer.fit(training_dataset, epochs=epochs, callbacks=[ckpt_callback])

推理

我们在 512x512 的图像分辨率上将模型微调了 60 个 epoch。 为了允许使用此分辨率进行训练,我们加入了混合精度支持。 您可以查看 此存储库 以获取更多详细信息。 它还为微调模型参数和模型检查点提供了指数移动平均值的支持。

在本节中,我们将使用微调 60 个 epoch 后得到的检查点。

weights_path = tf.keras.utils.get_file(
    origin="https://hugging-face.cn/sayakpaul/kerascv_sd_pokemon_finetuned/resolve/main/ckpt_epochs_72_res_512_mp_True.h5"
)

img_height = img_width = 512
pokemon_model = keras_cv.models.StableDiffusion(
    img_width=img_width, img_height=img_height
)
# We just reload the weights of the fine-tuned diffusion model.
pokemon_model.diffusion_model.load_weights(weights_path)
By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE

现在,我们可以对这个模型进行测试了。

prompts = ["Yoda", "Hello Kitty", "A pokemon with red eyes"]
images_to_generate = 3
outputs = {}

for prompt in prompts:
    generated_images = pokemon_model.text_to_image(
        prompt, batch_size=images_to_generate, unconditional_guidance_scale=40
    )
    outputs.update({prompt: generated_images})
25/25 [==============================] - 17s 231ms/step
25/25 [==============================] - 6s 229ms/step
25/25 [==============================] - 6s 229ms/step

经过 60 个 epoch 的微调(一个好的数字约为 70),生成的图像达不到标准。 因此,我们尝试了 Stable Diffusion 在推理时间内使用的步数以及 unconditional_guidance_scale 参数。

我们发现,将此检查点的 unconditional_guidance_scale 设置为 40 时,结果最好。

def plot_images(images, title):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        plt.imshow(images[i])
        plt.title(title, fontsize=12)
        plt.axis("off")


for prompt in outputs:
    plot_images(outputs[prompt], prompt)

png

png

png

我们可以注意到该模型已经开始适应我们数据集的风格。 您可以查看附带的存储库,了解更多比较和评论。 如果您想尝试演示,可以查看 此资源


结论和致谢

我们演示了如何在自定义数据集上微调 Stable Diffusion 模型。 虽然结果远非美观,但我们相信随着更多 epoch 的微调,它们可能会有所改善。 为了实现这一目标,支持梯度累积和分布式训练至关重要。 这可以被认为是本教程的下一步。

还有另一种有趣的微调 Stable Diffusion 模型的方法,称为文本反演。 您可以参考 本教程 以了解更多信息。

我们要感谢 Google ML 开发者项目团队提供的 GCP 信用支持。 我们要感谢 Hugging Face 团队提供 微调脚本。 它非常易读易懂。