代码示例 / 生成式深度学习 / 微调 Stable Diffusion

微调 Stable Diffusion

作者: Sayak PaulChansung Park
创建日期 2022/12/28
上次修改日期 2023/01/13
描述:使用自定义图像-标题数据集微调 Stable Diffusion。

ⓘ 此示例使用 Keras 2

在 Colab 中查看 GitHub 源代码


简介

本教程展示了如何在{image, caption}对的自定义数据集上微调Stable Diffusion 模型。我们基于 Hugging Face 此处提供的微调脚本。

我们假设您对 Stable Diffusion 模型有高级理解。如果您想了解更多信息,以下资源可能会有所帮助

强烈建议您使用至少 30GB 内存的 GPU 来执行代码。

在本指南结束时,您将能够生成有趣的宝可梦图像

custom-pokemons

本教程依赖于 KerasCV 0.4.0。此外,我们需要至少 TensorFlow 2.11 才能使用带混合精度的 AdamW。

!pip install keras-cv==0.6.0 -q
!pip install -U tensorflow -q
!pip install keras-core -q

我们正在微调什么?

Stable Diffusion 模型可以分解成几个关键模型

  • 文本编码器,将输入提示投影到潜在空间。(与图像关联的标题称为“提示”。)
  • 变分自动编码器 (VAE),将输入图像投影到充当图像向量空间的潜在空间。
  • 扩散模型,根据编码的文本提示细化潜在向量并生成另一个潜在向量
  • 解码器,根据扩散模型的潜在向量生成图像。

值得注意的是,在从文本提示生成图像的过程中,通常不会使用图像编码器。

但是,在微调过程中,工作流程如下

  1. 文本编码器将输入文本提示投影到潜在空间。
  2. VAE 的图像编码器部分将输入图像投影到潜在空间。
  3. 为给定时间步长向图像潜在向量添加少量噪声。
  4. 扩散模型使用来自这两个空间的潜在向量以及时间步长嵌入来预测添加到图像潜在向量中的噪声。
  5. 在预测的噪声和步骤 3 中添加的原始噪声之间计算重建损失。
  6. 最后,使用梯度下降针对此损失优化扩散模型参数。

请注意,在微调期间仅更新扩散模型参数,而(预训练的)文本和图像编码器保持冻结状态。

如果这听起来很复杂,请不要担心。代码比这简单得多!


导入

from textwrap import wrap
import os

import keras_cv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer
from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel
from keras_cv.models.stable_diffusion.image_encoder import ImageEncoder
from keras_cv.models.stable_diffusion.noise_scheduler import NoiseScheduler
from keras_cv.models.stable_diffusion.text_encoder import TextEncoder
from tensorflow import keras

数据加载

我们使用数据集Pokémon BLIP 标题。但是,我们将使用一个略有不同的版本,该版本源自原始数据集,以便更好地与tf.data配合使用。有关更多详细信息,请参阅文档

data_path = tf.keras.utils.get_file(
    origin="https://hugging-face.cn/datasets/sayakpaul/pokemon-blip-original-version/resolve/main/pokemon_dataset.tar.gz",
    untar=True,
)

data_frame = pd.read_csv(os.path.join(data_path, "data.csv"))

data_frame["image_path"] = data_frame["image_path"].apply(
    lambda x: os.path.join(data_path, x)
)
data_frame.head()
image_path caption
0 /home/jupyter/.keras/datasets/pokemon_dataset/... 一只绿色的宝可梦的图画,眼睛是红色的
1 /home/jupyter/.keras/datasets/pokemon_dataset/... 一个绿色和黄色的玩具,有一个红色的鼻子
2 /home/jupyter/.keras/datasets/pokemon_dataset/... 一个红色和白色的球,脸上露出了愤怒的表情...
3 /home/jupyter/.keras/datasets/pokemon_dataset/... 一个卡通球,脸上带着笑容
4 /home/jupyter/.keras/datasets/pokemon_dataset/... 一堆球,上面画着人脸

由于我们只有 833 个{image, caption}对,因此我们可以预先计算标题的文本嵌入。此外,文本编码器将在微调过程中保持冻结状态,因此我们可以通过这样做来节省一些计算量。

在使用文本编码器之前,我们需要对标题进行标记化。

# The padding token and maximum prompt length are specific to the text encoder.
# If you're using a different text encoder be sure to change them accordingly.
PADDING_TOKEN = 49407
MAX_PROMPT_LENGTH = 77

# Load the tokenizer.
tokenizer = SimpleTokenizer()

#  Method to tokenize and pad the tokens.
def process_text(caption):
    tokens = tokenizer.encode(caption)
    tokens = tokens + [PADDING_TOKEN] * (MAX_PROMPT_LENGTH - len(tokens))
    return np.array(tokens)


# Collate the tokenized captions into an array.
tokenized_texts = np.empty((len(data_frame), MAX_PROMPT_LENGTH))

all_captions = list(data_frame["caption"].values)
for i, caption in enumerate(all_captions):
    tokenized_texts[i] = process_text(caption)

准备一个tf.data.Dataset

在本节中,我们将从输入图像文件路径及其对应的标题标记准备一个tf.data.Dataset对象。本节将包括以下内容

  • 预先计算标记化标题的文本嵌入。
  • 加载和增强输入图像。
  • 对数据集进行洗牌和批处理。
RESOLUTION = 256
AUTO = tf.data.AUTOTUNE
POS_IDS = tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)

augmenter = keras.Sequential(
    layers=[
        keras_cv.layers.CenterCrop(RESOLUTION, RESOLUTION),
        keras_cv.layers.RandomFlip(),
        tf.keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1),
    ]
)
text_encoder = TextEncoder(MAX_PROMPT_LENGTH)


def process_image(image_path, tokenized_text):
    image = tf.io.read_file(image_path)
    image = tf.io.decode_png(image, 3)
    image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
    return image, tokenized_text


def apply_augmentation(image_batch, token_batch):
    return augmenter(image_batch), token_batch


def run_text_encoder(image_batch, token_batch):
    return (
        image_batch,
        token_batch,
        text_encoder([token_batch, POS_IDS], training=False),
    )


def prepare_dict(image_batch, token_batch, encoded_text_batch):
    return {
        "images": image_batch,
        "tokens": token_batch,
        "encoded_text": encoded_text_batch,
    }


def prepare_dataset(image_paths, tokenized_texts, batch_size=1):
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, tokenized_texts))
    dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.map(process_image, num_parallel_calls=AUTO).batch(batch_size)
    dataset = dataset.map(apply_augmentation, num_parallel_calls=AUTO)
    dataset = dataset.map(run_text_encoder, num_parallel_calls=AUTO)
    dataset = dataset.map(prepare_dict, num_parallel_calls=AUTO)
    return dataset.prefetch(AUTO)

基线 Stable Diffusion 模型使用 512x512 分辨率的图像进行训练。使用更高分辨率图像训练的模型不太可能很好地转移到较低分辨率的图像。但是,如果我们将分辨率保持在 512x512(不启用混合精度),则当前模型将导致 OOM。因此,为了交互式演示的利益,我们将输入分辨率保持在 256x256。

# Prepare the dataset.
training_dataset = prepare_dataset(
    np.array(data_frame["image_path"]), tokenized_texts, batch_size=4
)

# Take a sample batch and investigate.
sample_batch = next(iter(training_dataset))

for k in sample_batch:
    print(k, sample_batch[k].shape)
images (4, 256, 256, 3)
tokens (4, 77)
encoded_text (4, 77, 768)

我们还可以查看训练图像及其对应的标题。

plt.figure(figsize=(20, 10))

for i in range(3):
    ax = plt.subplot(1, 4, i + 1)
    plt.imshow((sample_batch["images"][i] + 1) / 2)

    text = tokenizer.decode(sample_batch["tokens"][i].numpy().squeeze())
    text = text.replace("<|startoftext|>", "")
    text = text.replace("<|endoftext|>", "")
    text = "\n".join(wrap(text, 12))
    plt.title(text, fontsize=15)

    plt.axis("off")

png


用于微调循环的训练器类

class Trainer(tf.keras.Model):
    # Reference:
    # https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py

    def __init__(
        self,
        diffusion_model,
        vae,
        noise_scheduler,
        use_mixed_precision=False,
        max_grad_norm=1.0,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.diffusion_model = diffusion_model
        self.vae = vae
        self.noise_scheduler = noise_scheduler
        self.max_grad_norm = max_grad_norm

        self.use_mixed_precision = use_mixed_precision
        self.vae.trainable = False

    def train_step(self, inputs):
        images = inputs["images"]
        encoded_text = inputs["encoded_text"]
        batch_size = tf.shape(images)[0]

        with tf.GradientTape() as tape:
            # Project image into the latent space and sample from it.
            latents = self.sample_from_encoder_outputs(self.vae(images, training=False))
            # Know more about the magic number here:
            # https://keras.org.cn/examples/generative/fine_tune_via_textual_inversion/
            latents = latents * 0.18215

            # Sample noise that we'll add to the latents.
            noise = tf.random.normal(tf.shape(latents))

            # Sample a random timestep for each image.
            timesteps = tnp.random.randint(
                0, self.noise_scheduler.train_timesteps, (batch_size,)
            )

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process).
            noisy_latents = self.noise_scheduler.add_noise(
                tf.cast(latents, noise.dtype), noise, timesteps
            )

            # Get the target for loss depending on the prediction type
            # just the sampled noise for now.
            target = noise  # noise_schedule.predict_epsilon == True

            # Predict the noise residual and compute loss.
            timestep_embedding = tf.map_fn(
                lambda t: self.get_timestep_embedding(t), timesteps, dtype=tf.float32
            )
            timestep_embedding = tf.squeeze(timestep_embedding, 1)
            model_pred = self.diffusion_model(
                [noisy_latents, timestep_embedding, encoded_text], training=True
            )
            loss = self.compiled_loss(target, model_pred)
            if self.use_mixed_precision:
                loss = self.optimizer.get_scaled_loss(loss)

        # Update parameters of the diffusion model.
        trainable_vars = self.diffusion_model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        if self.use_mixed_precision:
            gradients = self.optimizer.get_unscaled_gradients(gradients)
        gradients = [tf.clip_by_norm(g, self.max_grad_norm) for g in gradients]
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        return {m.name: m.result() for m in self.metrics}

    def get_timestep_embedding(self, timestep, dim=320, max_period=10000):
        half = dim // 2
        log_max_period = tf.math.log(tf.cast(max_period, tf.float32))
        freqs = tf.math.exp(
            -log_max_period * tf.range(0, half, dtype=tf.float32) / half
        )
        args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
        embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
        embedding = tf.reshape(embedding, [1, -1])
        return embedding

    def sample_from_encoder_outputs(self, outputs):
        mean, logvar = tf.split(outputs, 2, axis=-1)
        logvar = tf.clip_by_value(logvar, -30.0, 20.0)
        std = tf.exp(0.5 * logvar)
        sample = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
        return mean + std * sample

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        # Overriding this method will allow us to use the `ModelCheckpoint`
        # callback directly with this trainer class. In this case, it will
        # only checkpoint the `diffusion_model` since that's what we're training
        # during fine-tuning.
        self.diffusion_model.save_weights(
            filepath=filepath,
            overwrite=overwrite,
            save_format=save_format,
            options=options,
        )

这里需要注意一个重要的实现细节:我们不是直接获取图像编码器(它是 VAE)生成的潜在向量,而是从它预测的均值和对数方差中采样。这样,我们可以获得更好的样本质量和多样性。

通常会添加对混合精度训练的支持以及模型权重的指数移动平均,以微调这些模型。但是,为了简洁起见,我们放弃了这些元素。稍后在本教程中将详细介绍。


初始化训练器并对其进行编译

# Enable mixed-precision training if the underlying GPU has tensor cores.
USE_MP = True
if USE_MP:
    keras.mixed_precision.set_global_policy("mixed_float16")

image_encoder = ImageEncoder()
diffusion_ft_trainer = Trainer(
    diffusion_model=DiffusionModel(RESOLUTION, RESOLUTION, MAX_PROMPT_LENGTH),
    # Remove the top layer from the encoder, which cuts off the variance and only
    # returns the mean.
    vae=tf.keras.Model(
        image_encoder.input,
        image_encoder.layers[-2].output,
    ),
    noise_scheduler=NoiseScheduler(),
    use_mixed_precision=USE_MP,
)

# These hyperparameters come from this tutorial by Hugging Face:
# https://hugging-face.cn/docs/diffusers/training/text2image
lr = 1e-5
beta_1, beta_2 = 0.9, 0.999
weight_decay = (1e-2,)
epsilon = 1e-08

optimizer = tf.keras.optimizers.experimental.AdamW(
    learning_rate=lr,
    weight_decay=weight_decay,
    beta_1=beta_1,
    beta_2=beta_2,
    epsilon=epsilon,
)
diffusion_ft_trainer.compile(optimizer=optimizer, loss="mse")

微调

为了使本教程的运行时间较短,我们只对一个 epoch 进行微调。

epochs = 1
ckpt_path = "finetuned_stable_diffusion.h5"
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
    ckpt_path,
    save_weights_only=True,
    monitor="loss",
    mode="min",
)
diffusion_ft_trainer.fit(training_dataset, epochs=epochs, callbacks=[ckpt_callback])

推理

我们以 512x512 的图像分辨率对模型进行了 60 个 epoch 的微调。为了允许以这种分辨率进行训练,我们集成了混合精度支持。您可以查看此存储库以了解更多详细信息。它还提供对微调模型参数的指数移动平均和模型检查点的支持。

对于本节,我们将使用 60 个 epoch 微调后得到的检查点。

weights_path = tf.keras.utils.get_file(
    origin="https://hugging-face.cn/sayakpaul/kerascv_sd_pokemon_finetuned/resolve/main/ckpt_epochs_72_res_512_mp_True.h5"
)

img_height = img_width = 512
pokemon_model = keras_cv.models.StableDiffusion(
    img_width=img_width, img_height=img_height
)
# We just reload the weights of the fine-tuned diffusion model.
pokemon_model.diffusion_model.load_weights(weights_path)
By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE

现在,我们可以对这个模型进行测试。

prompts = ["Yoda", "Hello Kitty", "A pokemon with red eyes"]
images_to_generate = 3
outputs = {}

for prompt in prompts:
    generated_images = pokemon_model.text_to_image(
        prompt, batch_size=images_to_generate, unconditional_guidance_scale=40
    )
    outputs.update({prompt: generated_images})
25/25 [==============================] - 17s 231ms/step
25/25 [==============================] - 6s 229ms/step
25/25 [==============================] - 6s 229ms/step

经过 60 个 epoch 的微调(一个好的数字大约是 70),生成的图像没有达到预期水平。因此,我们尝试了 Stable Diffusion 在推理期间采取的步骤数量以及unconditional_guidance_scale参数。

我们发现使用此检查点并设置unconditional_guidance_scale为 40 可以获得最佳结果。

def plot_images(images, title):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        plt.imshow(images[i])
        plt.title(title, fontsize=12)
        plt.axis("off")


for prompt in outputs:
    plot_images(outputs[prompt], prompt)

png

png

png

我们可以注意到,模型已开始适应我们数据集的风格。您可以查看随附的存储库以获取更多比较和评论。如果您想尝试演示,可以查看此资源


结论和致谢

我们演示了如何在自定义数据集上微调 Stable Diffusion 模型。虽然结果远非美观,但我们相信随着更多 epoch 的微调,它们可能会得到改善。为了实现这一点,支持梯度累积和分布式训练至关重要。这可以被认为是本教程的下一步。

还有另一种有趣的方法可以对 Stable Diffusion 模型进行微调,称为文本反演。您可以参考本教程以了解更多信息。

我们要感谢 Google ML 开发者计划团队提供的 GCP 积分支持。我们要感谢 Hugging Face 团队提供了微调脚本。它非常易读且易于理解。