代码示例 / 生成式深度学习 / 微调 Stable Diffusion

微调 Stable Diffusion

作者: Sayak PaulChansung Park
创建日期 2022/12/28
上次修改 2023/01/13
描述: 使用自定义图像-标题数据集微调 Stable Diffusion。

ⓘ 此示例使用 Keras 2

在 Colab 中查看 GitHub 源代码


介绍

本教程演示了如何在自定义的 {image, caption} 对数据集上微调 Stable Diffusion 模型。 我们基于 Hugging Face 提供的微调脚本,可以 在此处 查看。

我们假设您对 Stable Diffusion 模型有一定的了解。 如果您想了解更多相关信息,以下资源可能会有所帮助。

强烈建议您使用至少具有 30GB 内存的 GPU 来执行代码。

在本指南结束时,您将能够生成有趣的宝可梦图像。

custom-pokemons

本教程依赖于 KerasCV 0.4.0。 此外,我们需要至少使用 TensorFlow 2.11 才能将 AdamW 与混合精度一起使用。

!pip install keras-cv==0.6.0 -q
!pip install -U tensorflow -q
!pip install keras-core -q

我们在微调什么?

Stable Diffusion 模型可以分解为几个关键模型。

  • 文本编码器,将输入提示投影到潜在空间。(与图像关联的标题称为“提示”。)
  • 变分自动编码器(VAE),将输入图像投影到潜在空间,充当图像向量空间。
  • 扩散模型,细化潜在向量并生成另一个潜在向量,以编码的文本提示为条件。
  • 解码器,根据扩散模型的潜在向量生成图像。

值得注意的是,在从文本提示生成图像的过程中,通常不使用图像编码器。

但是,在微调过程中,工作流程如下。

  1. 文本编码器将输入文本提示投影到潜在空间。
  2. 图像编码器(VAE 的一部分)将输入图像投影到潜在空间。
  3. 在给定时间步长的情况下,将少量噪声添加到图像潜在向量。
  4. 扩散模型使用来自这两个空间的潜在向量以及时间步长嵌入来预测添加到图像潜在向量中的噪声。
  5. 计算预测噪声和步骤 3 中添加的原始噪声之间的重建损失。
  6. 最后,使用梯度下降优化扩散模型参数,相对于此损失。

请注意,在微调期间,只有扩散模型参数会更新,而(预训练的)文本和图像编码器保持冻结状态。

如果这听起来很复杂,请不要担心。 代码比这简单得多!


导入

from textwrap import wrap
import os

import keras_cv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.experimental.numpy as tnp
from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer
from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel
from keras_cv.models.stable_diffusion.image_encoder import ImageEncoder
from keras_cv.models.stable_diffusion.noise_scheduler import NoiseScheduler
from keras_cv.models.stable_diffusion.text_encoder import TextEncoder
from tensorflow import keras

数据加载

我们使用数据集 宝可梦 BLIP 标题。 但是,我们将使用一个略微不同的版本,该版本是从原始数据集中派生而来,以更好地适应 tf.data。 有关更多详细信息,请参考 文档

data_path = tf.keras.utils.get_file(
    origin="https://hugging-face.cn/datasets/sayakpaul/pokemon-blip-original-version/resolve/main/pokemon_dataset.tar.gz",
    untar=True,
)

data_frame = pd.read_csv(os.path.join(data_path, "data.csv"))

data_frame["image_path"] = data_frame["image_path"].apply(
    lambda x: os.path.join(data_path, x)
)
data_frame.head()
image_path caption
0 /home/jupyter/.keras/datasets/pokemon_dataset/... 一只红色眼睛的绿色宝可梦的图画
1 /home/jupyter/.keras/datasets/pokemon_dataset/... 一个绿色和黄色的玩具,有一个红色的鼻子
2 /home/jupyter/.keras/datasets/pokemon_dataset/... 一个红色和白色的球,脸上带着愤怒的表情...
3 /home/jupyter/.keras/datasets/pokemon_dataset/... 一个卡通球,脸上带着微笑
4 /home/jupyter/.keras/datasets/pokemon_dataset/... 一堆球,上面画着面孔

由于我们只有 833 个 {image, caption} 对,因此我们可以预先计算标题的文本嵌入。 此外,文本编码器在微调过程中将保持冻结状态,因此我们可以通过这样做节省一些计算量。

在使用文本编码器之前,我们需要对标题进行分词。

# The padding token and maximum prompt length are specific to the text encoder.
# If you're using a different text encoder be sure to change them accordingly.
PADDING_TOKEN = 49407
MAX_PROMPT_LENGTH = 77

# Load the tokenizer.
tokenizer = SimpleTokenizer()

#  Method to tokenize and pad the tokens.
def process_text(caption):
    tokens = tokenizer.encode(caption)
    tokens = tokens + [PADDING_TOKEN] * (MAX_PROMPT_LENGTH - len(tokens))
    return np.array(tokens)


# Collate the tokenized captions into an array.
tokenized_texts = np.empty((len(data_frame), MAX_PROMPT_LENGTH))

all_captions = list(data_frame["caption"].values)
for i, caption in enumerate(all_captions):
    tokenized_texts[i] = process_text(caption)

准备 tf.data.Dataset

在本节中,我们将根据输入图像文件路径及其相应的标题标记来准备 tf.data.Dataset 对象。 本节将包括以下内容。

  • 从分词的标题中预先计算文本嵌入。
  • 加载和增强输入图像。
  • 对数据集进行洗牌和批处理。
RESOLUTION = 256
AUTO = tf.data.AUTOTUNE
POS_IDS = tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)

augmenter = keras.Sequential(
    layers=[
        keras_cv.layers.CenterCrop(RESOLUTION, RESOLUTION),
        keras_cv.layers.RandomFlip(),
        tf.keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1),
    ]
)
text_encoder = TextEncoder(MAX_PROMPT_LENGTH)


def process_image(image_path, tokenized_text):
    image = tf.io.read_file(image_path)
    image = tf.io.decode_png(image, 3)
    image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
    return image, tokenized_text


def apply_augmentation(image_batch, token_batch):
    return augmenter(image_batch), token_batch


def run_text_encoder(image_batch, token_batch):
    return (
        image_batch,
        token_batch,
        text_encoder([token_batch, POS_IDS], training=False),
    )


def prepare_dict(image_batch, token_batch, encoded_text_batch):
    return {
        "images": image_batch,
        "tokens": token_batch,
        "encoded_text": encoded_text_batch,
    }


def prepare_dataset(image_paths, tokenized_texts, batch_size=1):
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, tokenized_texts))
    dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.map(process_image, num_parallel_calls=AUTO).batch(batch_size)
    dataset = dataset.map(apply_augmentation, num_parallel_calls=AUTO)
    dataset = dataset.map(run_text_encoder, num_parallel_calls=AUTO)
    dataset = dataset.map(prepare_dict, num_parallel_calls=AUTO)
    return dataset.prefetch(AUTO)

基础 Stable Diffusion 模型使用分辨率为 512x512 的图像进行训练。 使用更高分辨率图像训练的模型不太可能很好地转移到较低分辨率的图像。 但是,如果我们将分辨率保持为 512x512(不启用混合精度),则当前模型将导致 OOM。 因此,为了实现交互式演示,我们将输入分辨率保持为 256x256。

# Prepare the dataset.
training_dataset = prepare_dataset(
    np.array(data_frame["image_path"]), tokenized_texts, batch_size=4
)

# Take a sample batch and investigate.
sample_batch = next(iter(training_dataset))

for k in sample_batch:
    print(k, sample_batch[k].shape)
images (4, 256, 256, 3)
tokens (4, 77)
encoded_text (4, 77, 768)

我们还可以查看训练图像及其相应的标题。

plt.figure(figsize=(20, 10))

for i in range(3):
    ax = plt.subplot(1, 4, i + 1)
    plt.imshow((sample_batch["images"][i] + 1) / 2)

    text = tokenizer.decode(sample_batch["tokens"][i].numpy().squeeze())
    text = text.replace("<|startoftext|>", "")
    text = text.replace("<|endoftext|>", "")
    text = "\n".join(wrap(text, 12))
    plt.title(text, fontsize=15)

    plt.axis("off")

png


用于微调循环的训练器类

class Trainer(tf.keras.Model):
    # Reference:
    # https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py

    def __init__(
        self,
        diffusion_model,
        vae,
        noise_scheduler,
        use_mixed_precision=False,
        max_grad_norm=1.0,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.diffusion_model = diffusion_model
        self.vae = vae
        self.noise_scheduler = noise_scheduler
        self.max_grad_norm = max_grad_norm

        self.use_mixed_precision = use_mixed_precision
        self.vae.trainable = False

    def train_step(self, inputs):
        images = inputs["images"]
        encoded_text = inputs["encoded_text"]
        batch_size = tf.shape(images)[0]

        with tf.GradientTape() as tape:
            # Project image into the latent space and sample from it.
            latents = self.sample_from_encoder_outputs(self.vae(images, training=False))
            # Know more about the magic number here:
            # https://keras.org.cn/examples/generative/fine_tune_via_textual_inversion/
            latents = latents * 0.18215

            # Sample noise that we'll add to the latents.
            noise = tf.random.normal(tf.shape(latents))

            # Sample a random timestep for each image.
            timesteps = tnp.random.randint(
                0, self.noise_scheduler.train_timesteps, (batch_size,)
            )

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process).
            noisy_latents = self.noise_scheduler.add_noise(
                tf.cast(latents, noise.dtype), noise, timesteps
            )

            # Get the target for loss depending on the prediction type
            # just the sampled noise for now.
            target = noise  # noise_schedule.predict_epsilon == True

            # Predict the noise residual and compute loss.
            timestep_embedding = tf.map_fn(
                lambda t: self.get_timestep_embedding(t), timesteps, dtype=tf.float32
            )
            timestep_embedding = tf.squeeze(timestep_embedding, 1)
            model_pred = self.diffusion_model(
                [noisy_latents, timestep_embedding, encoded_text], training=True
            )
            loss = self.compiled_loss(target, model_pred)
            if self.use_mixed_precision:
                loss = self.optimizer.get_scaled_loss(loss)

        # Update parameters of the diffusion model.
        trainable_vars = self.diffusion_model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        if self.use_mixed_precision:
            gradients = self.optimizer.get_unscaled_gradients(gradients)
        gradients = [tf.clip_by_norm(g, self.max_grad_norm) for g in gradients]
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        return {m.name: m.result() for m in self.metrics}

    def get_timestep_embedding(self, timestep, dim=320, max_period=10000):
        half = dim // 2
        log_max_period = tf.math.log(tf.cast(max_period, tf.float32))
        freqs = tf.math.exp(
            -log_max_period * tf.range(0, half, dtype=tf.float32) / half
        )
        args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
        embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
        embedding = tf.reshape(embedding, [1, -1])
        return embedding

    def sample_from_encoder_outputs(self, outputs):
        mean, logvar = tf.split(outputs, 2, axis=-1)
        logvar = tf.clip_by_value(logvar, -30.0, 20.0)
        std = tf.exp(0.5 * logvar)
        sample = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
        return mean + std * sample

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        # Overriding this method will allow us to use the `ModelCheckpoint`
        # callback directly with this trainer class. In this case, it will
        # only checkpoint the `diffusion_model` since that's what we're training
        # during fine-tuning.
        self.diffusion_model.save_weights(
            filepath=filepath,
            overwrite=overwrite,
            save_format=save_format,
            options=options,
        )

这里需要注意的一个重要实现细节:我们不是直接获取图像编码器(即 VAE)产生的潜在向量,而是从它预测的均值和对数方差中进行采样。 这样,我们可以实现更好的样本质量和多样性。

通常,会添加对混合精度训练的支持以及模型权重的指数移动平均,以微调这些模型。 但是,为了简洁起见,我们舍弃了这些元素。 本教程稍后将对此进行更多介绍。


初始化训练器并编译它

# Enable mixed-precision training if the underlying GPU has tensor cores.
USE_MP = True
if USE_MP:
    keras.mixed_precision.set_global_policy("mixed_float16")

image_encoder = ImageEncoder()
diffusion_ft_trainer = Trainer(
    diffusion_model=DiffusionModel(RESOLUTION, RESOLUTION, MAX_PROMPT_LENGTH),
    # Remove the top layer from the encoder, which cuts off the variance and only
    # returns the mean.
    vae=tf.keras.Model(
        image_encoder.input,
        image_encoder.layers[-2].output,
    ),
    noise_scheduler=NoiseScheduler(),
    use_mixed_precision=USE_MP,
)

# These hyperparameters come from this tutorial by Hugging Face:
# https://hugging-face.cn/docs/diffusers/training/text2image
lr = 1e-5
beta_1, beta_2 = 0.9, 0.999
weight_decay = (1e-2,)
epsilon = 1e-08

optimizer = tf.keras.optimizers.experimental.AdamW(
    learning_rate=lr,
    weight_decay=weight_decay,
    beta_1=beta_1,
    beta_2=beta_2,
    epsilon=epsilon,
)
diffusion_ft_trainer.compile(optimizer=optimizer, loss="mse")

微调

为了使本教程的运行时间缩短,我们只对一个纪元进行微调。

epochs = 1
ckpt_path = "finetuned_stable_diffusion.h5"
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
    ckpt_path,
    save_weights_only=True,
    monitor="loss",
    mode="min",
)
diffusion_ft_trainer.fit(training_dataset, epochs=epochs, callbacks=[ckpt_callback])

推理

我们对模型进行了 60 个纪元的微调,图像分辨率为 512x512。 为了允许使用此分辨率进行训练,我们加入了混合精度支持。 您可以在 此仓库 中找到更多详细信息。 它还提供了对微调模型参数的指数移动平均和模型检查点的支持。

在本节中,我们将使用 60 个纪元微调后获得的检查点。

weights_path = tf.keras.utils.get_file(
    origin="https://hugging-face.cn/sayakpaul/kerascv_sd_pokemon_finetuned/resolve/main/ckpt_epochs_72_res_512_mp_True.h5"
)

img_height = img_width = 512
pokemon_model = keras_cv.models.StableDiffusion(
    img_width=img_width, img_height=img_height
)
# We just reload the weights of the fine-tuned diffusion model.
pokemon_model.diffusion_model.load_weights(weights_path)
By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE

现在,我们可以对该模型进行测试驱动。

prompts = ["Yoda", "Hello Kitty", "A pokemon with red eyes"]
images_to_generate = 3
outputs = {}

for prompt in prompts:
    generated_images = pokemon_model.text_to_image(
        prompt, batch_size=images_to_generate, unconditional_guidance_scale=40
    )
    outputs.update({prompt: generated_images})
25/25 [==============================] - 17s 231ms/step
25/25 [==============================] - 6s 229ms/step
25/25 [==============================] - 6s 229ms/step

使用 60 个纪元的微调(一个好的数字大约是 70),生成的图像并不理想。 因此,我们对 Stable Diffusion 在推理期间执行的步骤数量和 unconditional_guidance_scale 参数进行了实验。

我们发现,使用此检查点和 unconditional_guidance_scale 设置为 40 时,效果最佳。

def plot_images(images, title):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        plt.imshow(images[i])
        plt.title(title, fontsize=12)
        plt.axis("off")


for prompt in outputs:
    plot_images(outputs[prompt], prompt)

png

png

png

我们可以注意到,该模型已经开始适应我们数据集的风格。 您可以在 配套仓库 中找到更多比较和评论。 如果您想尝试演示,可以查看 此资源


结论和致谢

我们演示了如何在自定义数据集上微调 Stable Diffusion 模型。 虽然结果远非美观,但我们相信,通过更多纪元的微调,它们可能会得到改善。 为了实现这一点,支持梯度累积和分布式训练至关重要。 这可以被认为是本教程的下一步。

还有一种有趣的方法可以对 Stable Diffusion 模型进行微调,称为文本反转。 您可以参考 本教程 了解有关它的更多信息。

我们要感谢 Google 的 ML 开发者计划团队提供的 GCP 积分支持。 我们要感谢 Hugging Face 团队提供的 微调脚本。 它非常易读且易于理解。