代码示例 / 生成式深度学习 / 使用 Stable Diffusion 3 漫步隐空间

使用 Stable Diffusion 3 漫步隐空间

作者: Hongyu Chiu, Ian Stenbit, fchollet, lukewood
创建日期 2024/11/11
最后修改日期 2024/11/11
描述: 探索 Stable Diffusion 3 的隐空间流形。

ⓘ 此示例使用 Keras 3

在 Colab 中查看 GitHub 源代码


概述

生成式图像模型学习视觉世界的“隐空间流形”:一个低维向量空间,其中每个点映射到一个图像。从流形上的某个点回到可显示的图像称为“解码”——在 Stable Diffusion 模型中,这由“解码器”模型处理。

Stable Diffusion 3 Medium Architecture

这个图像的隐空间流形是连续且可插值的,这意味着

  1. 在流形上移动一点只会使对应的图像变化一点(连续性)。
  2. 对于流形上的任意两点 A 和 B(即任意两个图像),可以通过一条路径从 A 移动到 B,路径上的每个中间点也都在流形上(即也是一个有效的图像)。中间点被称为两个起始图像之间的“插值”。

Stable Diffusion 不仅仅是一个图像模型,它也是一个自然语言模型。它有两个隐空间:训练期间编码器学习的图像表示空间,以及结合预训练和训练时微调学习到的提示词隐空间。

隐空间漫步,或称隐空间探索,是指在隐空间中采样一个点并逐步改变隐表示的过程。其最常见的应用是生成动画,其中每个采样点被送入解码器并作为最终动画的一帧存储。对于高质量的隐表示,这会产生连贯的动画。这些动画可以提供对隐空间特征图的洞察,并最终改进训练过程。下面展示了这样一个 GIF

dog_to_cat_64.gif

在本指南中,我们将展示如何利用 KerasHub 中的 TextToImage API 对 Stable Diffusion 3 的视觉隐空间流形以及文本编码器的隐空间进行提示词插值和循环漫步。

本指南假定读者对 Stable Diffusion 3 有一定的了解。如果你还不了解,建议先阅读 KerasHub 中的 Stable Diffusion 3

另外值得注意的是,预设的 "stable_diffusion_3_medium" 不包含 T5XXL 文本编码器,因为它需要显着更多的 GPU 内存。在大多数情况下,性能下降可以忽略不计。包括 T5XXL 在内的权重将很快在 KerasHub 上提供。

!# Use the latest version of KerasHub
!!pip install -Uq git+https://github.com/keras-team/keras-hub.git
import math

import keras
import keras_hub
import matplotlib.pyplot as plt
from keras import ops
from keras import random
from PIL import Image

height, width = 512, 512
num_steps = 28
guidance_scale = 7.0
dtype = "float16"

# Instantiate the Stable Diffusion 3 model and the preprocessor
backbone = keras_hub.models.StableDiffusion3Backbone.from_preset(
    "stable_diffusion_3_medium", image_shape=(height, width, 3), dtype=dtype
)
preprocessor = keras_hub.models.StableDiffusion3TextToImagePreprocessor.from_preset(
    "stable_diffusion_3_medium"
)

让我们为本示例定义一些辅助函数。

def get_text_embeddings(prompt):
    """Get the text embeddings for a given prompt."""
    token_ids = preprocessor.generate_preprocess([prompt])
    negative_token_ids = preprocessor.generate_preprocess([""])
    (
        positive_embeddings,
        negative_embeddings,
        positive_pooled_embeddings,
        negative_pooled_embeddings,
    ) = backbone.encode_text_step(token_ids, negative_token_ids)
    return (
        positive_embeddings,
        negative_embeddings,
        positive_pooled_embeddings,
        negative_pooled_embeddings,
    )


def decode_to_images(x, height, width):
    """Concatenate and normalize the images to uint8 dtype."""
    x = ops.concatenate(x, axis=0)
    x = ops.reshape(x, (-1, height, width, 3))
    x = ops.clip(ops.divide(ops.add(x, 1.0), 2.0), 0.0, 1.0)
    return ops.cast(ops.round(ops.multiply(x, 255.0)), "uint8")


def generate_with_latents_and_embeddings(
    latents, embeddings, num_steps, guidance_scale
):
    """Generate images from latents and text embeddings."""

    def body_fun(step, latents):
        return backbone.denoise_step(
            latents,
            embeddings,
            step,
            num_steps,
            guidance_scale,
        )

    latents = ops.fori_loop(0, num_steps, body_fun, latents)
    return backbone.decode_step(latents)


def export_as_gif(filename, images, frames_per_second=10, no_rubber_band=False):
    if not no_rubber_band:
        images += images[2:-1][::-1]  # Makes a rubber band: A->B->A
    images[0].save(
        filename,
        save_all=True,
        append_images=images[1:],
        duration=1000 // frames_per_second,
        loop=0,
    )

我们将使用自定义的隐向量和嵌入来生成图像,因此需要实现 generate_with_latents_and_embeddings 函数。此外,编译此函数以加速生成过程也很重要。

if keras.config.backend() == "torch":
    import torch

    @torch.no_grad()
    def wrapped_function(*args, **kwargs):
        return generate_with_latents_and_embeddings(*args, **kwargs)

    generate_function = wrapped_function
elif keras.config.backend() == "tensorflow":
    import tensorflow as tf

    generate_function = tf.function(
        generate_with_latents_and_embeddings, jit_compile=True
    )
elif keras.config.backend() == "jax":
    import itertools

    import jax

    @jax.jit
    def compiled_function(state, *args, **kwargs):
        (trainable_variables, non_trainable_variables) = state
        mapping = itertools.chain(
            zip(backbone.trainable_variables, trainable_variables),
            zip(backbone.non_trainable_variables, non_trainable_variables),
        )
        with keras.StatelessScope(state_mapping=mapping):
            return generate_with_latents_and_embeddings(*args, **kwargs)

    def wrapped_function(*args, **kwargs):
        state = (
            [v.value for v in backbone.trainable_variables],
            [v.value for v in backbone.non_trainable_variables],
        )
        return compiled_function(state, *args, **kwargs)

    generate_function = wrapped_function

在文本提示词之间进行插值

在 Stable Diffusion 3 中,文本提示词被编码成多个向量,然后用于指导扩散过程。这些隐编码向量对于正向和负向提示词的形状分别为 154x4096 和 2048——相当大!当我们向 Stable Diffusion 3 输入文本提示词时,我们从这个隐空间流形上的一个点生成图像。

为了探索更多这个流形,我们可以在两个文本编码之间进行插值,并在这些插值点生成图像

prompt_1 = "A cute dog in a beautiful field of lavander colorful flowers "
prompt_1 += "everywhere, perfect lighting, leica summicron 35mm f2.0, kodak "
prompt_1 += "portra 400, film grain"
prompt_2 = prompt_1.replace("dog", "cat")
interpolation_steps = 5

encoding_1 = get_text_embeddings(prompt_1)
encoding_2 = get_text_embeddings(prompt_2)


# Show the size of the latent manifold
print(f"Positive embeddings shape: {encoding_1[0].shape}")
print(f"Negative embeddings shape: {encoding_1[1].shape}")
print(f"Positive pooled embeddings shape: {encoding_1[2].shape}")
print(f"Negative pooled embeddings shape: {encoding_1[3].shape}")
Positive embeddings shape: (1, 154, 4096)
Negative embeddings shape: (1, 154, 4096)
Positive pooled embeddings shape: (1, 2048)
Negative pooled embeddings shape: (1, 2048)

在此示例中,我们希望使用球面线性插值 (slerp) 而不是简单的线性插值。Slerp 通常用于计算机图形学中以平滑地制作旋转动画,也可应用于生成模型中使用的隐向量等高维数据点的插值。

来源是 Andrej Karpathy 的 gist:https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355

关于此方法的更详细解释,请参见:https://en.wikipedia.org/wiki/Slerp

def slerp(v1, v2, num):
    ori_dtype = v1.dtype
    # Cast to float32 for numerical stability.
    v1 = ops.cast(v1, "float32")
    v2 = ops.cast(v2, "float32")

    def interpolation(t, v1, v2, dot_threshold=0.9995):
        """helper function to spherically interpolate two arrays."""
        dot = ops.sum(
            v1 * v2 / (ops.linalg.norm(ops.ravel(v1)) * ops.linalg.norm(ops.ravel(v2)))
        )
        if ops.abs(dot) > dot_threshold:
            v2 = (1 - t) * v1 + t * v2
        else:
            theta_0 = ops.arccos(dot)
            sin_theta_0 = ops.sin(theta_0)
            theta_t = theta_0 * t
            sin_theta_t = ops.sin(theta_t)
            s0 = ops.sin(theta_0 - theta_t) / sin_theta_0
            s1 = sin_theta_t / sin_theta_0
            v2 = s0 * v1 + s1 * v2
        return v2

    t = ops.linspace(0, 1, num)
    interpolated = ops.stack([interpolation(t[i], v1, v2) for i in range(num)], axis=0)
    return ops.cast(interpolated, ori_dtype)


interpolated_positive_embeddings = slerp(
    encoding_1[0], encoding_2[0], interpolation_steps
)
interpolated_positive_pooled_embeddings = slerp(
    encoding_1[2], encoding_2[2], interpolation_steps
)
# We don't use negative prompts in this example, so there’s no need to
# interpolate them.
negative_embeddings = encoding_1[1]
negative_pooled_embeddings = encoding_1[3]

对编码进行插值后,我们可以从每个点生成图像。请注意,为了保持生成图像之间的一致性,我们在图像之间保持扩散隐向量不变。

latents = random.normal((1, height // 8, width // 8, 16), seed=42)

images = []
progbar = keras.utils.Progbar(interpolation_steps)
for i in range(interpolation_steps):
    images.append(
        generate_function(
            latents,
            (
                interpolated_positive_embeddings[i],
                negative_embeddings,
                interpolated_positive_pooled_embeddings[i],
                negative_pooled_embeddings,
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == interpolation_steps - 1)

现在我们已经生成了一些插值图像,让我们来看看它们!

在本教程中,我们将把图像序列导出为 GIF,以便轻松查看并带有时间上下文。对于概念上首尾图像不匹配的序列,我们将 GIF 进行“橡皮筋”处理(即来回播放)。

如果你在 Colab 中运行,可以通过运行以下代码查看自己的 GIF

from IPython.display import Image as IImage
IImage("dog_to_cat_5.gif")
images = ops.convert_to_numpy(decode_to_images(images, height, width))
export_as_gif(
    "dog_to_cat_5.gif",
    [Image.fromarray(image) for image in images],
    frames_per_second=2,
)

dog_to_cat_5.gif

结果可能看起来令人惊讶。一般来说,在提示词之间进行插值会产生连贯的图像,并且通常会展示两个提示词内容之间的概念逐步转变。这表明表示空间质量很高,它密切反映了视觉世界的自然结构。

为了更好地可视化这一点,我们应该进行更精细的插值,使用更多的步骤。

interpolation_steps = 64
batch_size = 4
batches = interpolation_steps // batch_size

interpolated_positive_embeddings = slerp(
    encoding_1[0], encoding_2[0], interpolation_steps
)
interpolated_positive_pooled_embeddings = slerp(
    encoding_1[2], encoding_2[2], interpolation_steps
)
positive_embeddings_shape = ops.shape(encoding_1[0])
positive_pooled_embeddings_shape = ops.shape(encoding_1[2])
interpolated_positive_embeddings = ops.reshape(
    interpolated_positive_embeddings,
    (
        batches,
        batch_size,
        positive_embeddings_shape[-2],
        positive_embeddings_shape[-1],
    ),
)
interpolated_positive_pooled_embeddings = ops.reshape(
    interpolated_positive_pooled_embeddings,
    (batches, batch_size, positive_pooled_embeddings_shape[-1]),
)
negative_embeddings = ops.tile(encoding_1[1], (batch_size, 1, 1))
negative_pooled_embeddings = ops.tile(encoding_1[3], (batch_size, 1))

latents = random.normal((1, height // 8, width // 8, 16), seed=42)
latents = ops.tile(latents, (batch_size, 1, 1, 1))

images = []
progbar = keras.utils.Progbar(batches)
for i in range(batches):
    images.append(
        generate_function(
            latents,
            (
                interpolated_positive_embeddings[i],
                negative_embeddings,
                interpolated_positive_pooled_embeddings[i],
                negative_pooled_embeddings,
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == batches - 1)

images = ops.convert_to_numpy(decode_to_images(images, height, width))
export_as_gif(
    "dog_to_cat_64.gif",
    [Image.fromarray(image) for image in images],
    frames_per_second=2,
)

dog_to_cat_64.gif

生成的 GIF 显示了两个提示词之间更清晰、更连贯的转变。尝试输入你自己的提示词并进行实验吧!

我们甚至可以将此概念扩展到多个图像。例如,我们可以在四个提示词之间进行插值

prompt_1 = "A watercolor painting of a Golden Retriever at the beach"
prompt_2 = "A still life DSLR photo of a bowl of fruit"
prompt_3 = "The eiffel tower in the style of starry night"
prompt_4 = "An architectural sketch of a skyscraper"

interpolation_steps = 8
batch_size = 4
batches = (interpolation_steps**2) // batch_size

encoding_1 = get_text_embeddings(prompt_1)
encoding_2 = get_text_embeddings(prompt_2)
encoding_3 = get_text_embeddings(prompt_3)
encoding_4 = get_text_embeddings(prompt_4)

positive_embeddings_shape = ops.shape(encoding_1[0])
positive_pooled_embeddings_shape = ops.shape(encoding_1[2])
interpolated_positive_embeddings_12 = slerp(
    encoding_1[0], encoding_2[0], interpolation_steps
)
interpolated_positive_embeddings_34 = slerp(
    encoding_3[0], encoding_4[0], interpolation_steps
)
interpolated_positive_embeddings = slerp(
    interpolated_positive_embeddings_12,
    interpolated_positive_embeddings_34,
    interpolation_steps,
)
interpolated_positive_embeddings = ops.reshape(
    interpolated_positive_embeddings,
    (
        batches,
        batch_size,
        positive_embeddings_shape[-2],
        positive_embeddings_shape[-1],
    ),
)
interpolated_positive_pooled_embeddings_12 = slerp(
    encoding_1[2], encoding_2[2], interpolation_steps
)
interpolated_positive_pooled_embeddings_34 = slerp(
    encoding_3[2], encoding_4[2], interpolation_steps
)
interpolated_positive_pooled_embeddings = slerp(
    interpolated_positive_pooled_embeddings_12,
    interpolated_positive_pooled_embeddings_34,
    interpolation_steps,
)
interpolated_positive_pooled_embeddings = ops.reshape(
    interpolated_positive_pooled_embeddings,
    (batches, batch_size, positive_pooled_embeddings_shape[-1]),
)
negative_embeddings = ops.tile(encoding_1[1], (batch_size, 1, 1))
negative_pooled_embeddings = ops.tile(encoding_1[3], (batch_size, 1))

latents = random.normal((1, height // 8, width // 8, 16), seed=42)
latents = ops.tile(latents, (batch_size, 1, 1, 1))

images = []
progbar = keras.utils.Progbar(batches)
for i in range(batches):
    images.append(
        generate_function(
            latents,
            (
                interpolated_positive_embeddings[i],
                negative_embeddings,
                interpolated_positive_pooled_embeddings[i],
                negative_pooled_embeddings,
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == batches - 1)

让我们以网格形式显示生成的图像,以便更容易理解。

def plot_grid(images, path, grid_size, scale=2):
    fig, axs = plt.subplots(
        grid_size, grid_size, figsize=(grid_size * scale, grid_size * scale)
    )
    fig.tight_layout()
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.axis("off")
    for ax in axs.flat:
        ax.axis("off")

    for i in range(min(grid_size * grid_size, len(images))):
        ax = axs.flat[i]
        ax.imshow(images[i])
        ax.axis("off")

    for i in range(len(images), grid_size * grid_size):
        axs.flat[i].axis("off")
        axs.flat[i].remove()

    plt.savefig(
        fname=path,
        pad_inches=0,
        bbox_inches="tight",
        transparent=False,
        dpi=60,
    )


images = ops.convert_to_numpy(decode_to_images(images, height, width))
plot_grid(images, "4-way-interpolation.jpg", interpolation_steps)

png

我们还可以在插值时通过去掉 seed 参数来让扩散隐向量变化

images = []
progbar = keras.utils.Progbar(batches)
for i in range(batches):
    # Vary diffusion latents for each input.
    latents = random.normal((batch_size, height // 8, width // 8, 16))
    images.append(
        generate_function(
            latents,
            (
                interpolated_positive_embeddings[i],
                negative_embeddings,
                interpolated_positive_pooled_embeddings[i],
                negative_pooled_embeddings,
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == batches - 1)

images = ops.convert_to_numpy(decode_to_images(images, height, width))
plot_grid(images, "4-way-interpolation-varying-latent.jpg", interpolation_steps)

png

接下来——让我们去漫步吧!


围绕一个文本提示词漫步

我们的下一个实验是围绕一个特定提示词生成的点在隐空间流形中进行漫步。

walk_steps = 64
batch_size = 4
batches = walk_steps // batch_size
step_size = 0.01
prompt = "The eiffel tower in the style of starry night"
encoding = get_text_embeddings(prompt)

positive_embeddings = encoding[0]
positive_pooled_embeddings = encoding[2]
negative_embeddings = encoding[1]
negative_pooled_embeddings = encoding[3]

# The shape of `positive_embeddings`: (1, 154, 4096)
# The shape of `positive_pooled_embeddings`: (1, 2048)
positive_embeddings_delta = ops.ones_like(positive_embeddings) * step_size
positive_pooled_embeddings_delta = ops.ones_like(positive_pooled_embeddings) * step_size
positive_embeddings_shape = ops.shape(positive_embeddings)
positive_pooled_embeddings_shape = ops.shape(positive_pooled_embeddings)

walked_positive_embeddings = []
walked_positive_pooled_embeddings = []
for step_index in range(walk_steps):
    walked_positive_embeddings.append(positive_embeddings)
    walked_positive_pooled_embeddings.append(positive_pooled_embeddings)
    positive_embeddings += positive_embeddings_delta
    positive_pooled_embeddings += positive_pooled_embeddings_delta
walked_positive_embeddings = ops.stack(walked_positive_embeddings, axis=0)
walked_positive_pooled_embeddings = ops.stack(walked_positive_pooled_embeddings, axis=0)
walked_positive_embeddings = ops.reshape(
    walked_positive_embeddings,
    (
        batches,
        batch_size,
        positive_embeddings_shape[-2],
        positive_embeddings_shape[-1],
    ),
)
walked_positive_pooled_embeddings = ops.reshape(
    walked_positive_pooled_embeddings,
    (batches, batch_size, positive_pooled_embeddings_shape[-1]),
)
negative_embeddings = ops.tile(encoding_1[1], (batch_size, 1, 1))
negative_pooled_embeddings = ops.tile(encoding_1[3], (batch_size, 1))

latents = random.normal((1, height // 8, width // 8, 16), seed=42)
latents = ops.tile(latents, (batch_size, 1, 1, 1))

images = []
progbar = keras.utils.Progbar(batches)
for i in range(batches):
    images.append(
        generate_function(
            latents,
            (
                walked_positive_embeddings[i],
                negative_embeddings,
                walked_positive_pooled_embeddings[i],
                negative_pooled_embeddings,
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == batches - 1)

images = ops.convert_to_numpy(decode_to_images(images, height, width))
export_as_gif(
    "eiffel-tower-starry-night.gif",
    [Image.fromarray(image) for image in images],
    frames_per_second=2,
)

eiffel-tower-starry-night.gif

也许并不意外,离编码器的隐空间流形太远会产生不连贯的图像。你可以自己尝试设置提示词,并调整 step_size 来增加或减少漫步的幅度。请注意,当漫步幅度变大时,通常会进入产生极度噪点图像的区域。


在单个提示词的扩散隐空间中进行循环漫步

我们的最后一个实验是坚持使用一个提示词,并探索扩散模型可以从该提示词生成的各种图像。我们通过控制用于初始化扩散过程的噪声来实现这一点。

我们创建两个噪声分量,xy,并从 0 到 2π 进行漫步,将 x 分量的余弦和 y 分量的正弦相加来产生噪声。使用这种方法,漫步的终点会回到我们开始漫步时的相同噪声输入,因此我们得到一个“可循环”的结果!

walk_steps = 64
batch_size = 4
batches = walk_steps // batch_size
prompt = "An oil paintings of cows in a field next to a windmill in Holland"
encoding = get_text_embeddings(prompt)

walk_latent_x = random.normal((1, height // 8, width // 8, 16))
walk_latent_y = random.normal((1, height // 8, width // 8, 16))
walk_scale_x = ops.cos(ops.linspace(0.0, 2.0, walk_steps) * math.pi)
walk_scale_y = ops.sin(ops.linspace(0.0, 2.0, walk_steps) * math.pi)
latent_x = ops.tensordot(walk_scale_x, walk_latent_x, axes=0)
latent_y = ops.tensordot(walk_scale_y, walk_latent_y, axes=0)
latents = ops.add(latent_x, latent_y)
latents = ops.reshape(latents, (batches, batch_size, height // 8, width // 8, 16))

images = []
progbar = keras.utils.Progbar(batches)
for i in range(batches):
    images.append(
        generate_function(
            latents[i],
            (
                ops.tile(encoding[0], (batch_size, 1, 1)),
                ops.tile(encoding[1], (batch_size, 1, 1)),
                ops.tile(encoding[2], (batch_size, 1)),
                ops.tile(encoding[3], (batch_size, 1)),
            ),
            ops.convert_to_tensor(num_steps),
            ops.convert_to_tensor(guidance_scale),
        )
    )
    progbar.update(i + 1, finalize=i == batches - 1)

images = ops.convert_to_numpy(decode_to_images(images, height, width))
export_as_gif(
    "cows.gif",
    [Image.fromarray(image) for image in images],
    frames_per_second=4,
    no_rubber_band=True,
)

cows.gif

尝试使用你自己的提示词和不同的参数值进行实验!


结论

Stable Diffusion 3 不仅仅提供单一的文本到图像生成。探索文本编码器的隐空间流形和扩散模型的隐空间是体验此模型强大功能的两种有趣方式,而 KerasHub 让这一切变得轻松!