开发者指南 / KerasCV / 使用 KerasCV 高性能生成图像

使用 KerasCV 高性能生成图像

作者: fcholletlukewooddivamgupta
创建日期 2022/09/25
上次修改日期 2022/09/25
描述:使用 KerasCV 的 Stable Diffusion 模型生成新图像。

在 Colab 中查看 GitHub 源代码


概述

在本指南中,我们将展示如何使用 KerasCV 实现的 stability.ai 文本到图像模型 Stable Diffusion,根据文本提示生成新颖的图像。

Stable Diffusion 是一种强大且开源的文本到图像生成模型。虽然存在多个开源实现,使您可以轻松地根据文本提示创建图像,但 KerasCV 提供了一些独特的优势。这些优势包括 XLA 编译混合精度 支持,两者结合实现了最先进的生成速度。

在本指南中,我们将探索 KerasCV 的 Stable Diffusion 实现,展示如何使用这些强大的性能提升,并探讨它们带来的性能优势。

注意:要在 torch 后端运行本指南,请在所有地方设置 jit_compile=False。Stable Diffusion 的 XLA 编译目前不适用于 torch。

首先,让我们安装一些依赖项并整理一些导入

!pip install -q --upgrade keras-cv
!pip install -q --upgrade keras  # Upgrade to Keras 3.
import time
import keras_cv
import keras
import matplotlib.pyplot as plt

简介

与大多数教程不同,我们不会先解释主题,然后展示如何实现它,在文本到图像生成中,展示比解释更容易理解。

体验 keras_cv.models.StableDiffusion() 的强大功能。

首先,我们构建一个模型

model = keras_cv.models.StableDiffusion(
    img_width=512, img_height=512, jit_compile=False
)
By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE

接下来,我们提供一个提示

images = model.text_to_image("photograph of an astronaut riding a horse", batch_size=3)


def plot_images(images):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        plt.imshow(images[i])
        plt.axis("off")


plot_images(images)
 50/50 ━━━━━━━━━━━━━━━━━━━━ 63s 211ms/step

png

非常棒!

但这并不是这个模型所能做的全部。让我们尝试一个更复杂的提示

images = model.text_to_image(
    "cute magical flying dog, fantasy art, "
    "golden color, high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting, mystery, adventure",
    batch_size=3,
)
plot_images(images)
 50/50 ━━━━━━━━━━━━━━━━━━━━ 10s 209ms/step

png

可能性实际上是无限的(或者至少扩展到 Stable Diffusion 潜在流形的边界)。


等等,这到底是如何运作的?

与您此时可能预期的不同,Stable Diffusion 实际上并非依赖于魔法。它是一种“潜在扩散模型”。让我们深入了解一下这意味着什么。

您可能熟悉超分辨率的概念:可以训练一个深度学习模型来去噪输入图像——从而将其转换为更高分辨率的版本。深度学习模型并非通过神奇地恢复从噪声、低分辨率输入中丢失的信息来做到这一点——而是,模型利用其训练数据分布来幻化出给定输入最可能的视觉细节。要详细了解超分辨率,您可以查看以下 Keras.io 教程

Super-resolution

当您将此概念推向极致时,您可能会开始思考——如果我们只是在纯噪声上运行这样的模型会怎样?然后,模型将“去噪噪声”并开始幻化出一张全新的图像。通过多次重复该过程,您可以将一小块噪声变成越来越清晰且高分辨率的人工图像。

这是潜在扩散的关键思想,于 2020 年在 使用潜在扩散模型进行高分辨率图像合成 中提出。要深入了解扩散,您可以查看 Keras.io 教程 去噪扩散隐式模型

Denoising diffusion

现在,要从潜在扩散转向文本到图像系统,您仍然需要添加一个关键功能:能够通过提示关键字控制生成的视觉内容。这是通过“条件化”完成的,这是一种经典的深度学习技术,它包括将表示一些文本的向量连接到噪声块,然后在 {图像:标题} 对的数据集上训练模型。

这产生了 Stable Diffusion 架构。Stable Diffusion 由三个部分组成

  • 文本编码器,它将您的提示转换为潜在向量。
  • 扩散模型,它反复“去噪”64x64 的潜在图像块。
  • 解码器,它将最终的 64x64 潜在块转换为更高分辨率的 512x512 图像。

首先,您的文本提示通过文本编码器投影到潜在向量空间,文本编码器只是一个预训练的、冻结的语言模型。然后,将该提示向量连接到随机生成的噪声块,该噪声块在连续的“步骤”中由扩散模型反复“去噪”(运行的步骤越多,图像越清晰和漂亮——默认值为 50 步)。

最后,将 64x64 的潜在图像发送到解码器以正确地将其渲染为高分辨率。

The Stable Diffusion architecture

总而言之,这是一个非常简单的系统——Keras 实现包含四个文件,总共不到 500 行代码

但是,当您在数十亿张图片及其标题上进行训练时,这个相对简单的系统开始看起来像魔法一样。正如费曼所说:“宇宙并不复杂,只是有很多东西!”


KerasCV 的优势

既然有几个 Stable Diffusion 的开源实现可用,为什么您应该使用 keras_cv.models.StableDiffusion

除了易于使用的 API 之外,KerasCV 的 Stable Diffusion 模型还具有一些强大的优势,包括

  • 图模式执行
  • 通过 jit_compile=True 进行 XLA 编译
  • 支持混合精度计算

当这些功能结合在一起时,KerasCV Stable Diffusion 模型的运行速度比朴素实现快几个数量级。本节将展示如何启用所有这些功能,以及使用它们带来的性能提升。

为了进行比较,我们运行了一些基准测试,将 HuggingFace diffusers 实现的 Stable Diffusion 与 KerasCV 实现的运行时间进行了比较。这两个实现都承担了生成 3 张图像的任务,每张图像的步数为 50。在本基准测试中,我们使用了 Tesla T4 GPU。

我们所有的基准测试都开源在 GitHub 上,可以在 Colab 上重新运行以复制结果。 下表显示了基准测试的结果

GPU 模型 运行时间
Tesla T4 KerasCV(预热启动) 28.97 秒
Tesla T4 diffusers(预热启动) 41.33 秒
Tesla V100 KerasCV(预热启动) 12.45
Tesla V100 diffusers(预热启动) 12.72

Tesla T4 上的执行时间提高了 30%!虽然在 V100 上的改进幅度要低得多,但我们通常期望基准测试的结果始终偏向于 KerasCV 在所有 NVIDIA GPU 上。

为了完整起见,报告了冷启动和预热启动的生成时间。冷启动执行时间包括模型创建和编译的一次性成本,因此在生产环境中可以忽略不计(在生产环境中,您将多次重用相同的模型实例)。无论如何,以下是冷启动数据

GPU 模型 运行时间
Tesla T4 KerasCV(冷启动) 83.47 秒
Tesla T4 diffusers(冷启动) 46.27 秒
Tesla V100 KerasCV(冷启动) 76.43
Tesla V100 diffusers(冷启动) 13.90

虽然运行本指南产生的运行时间结果可能会有所不同,但在我们的测试中,KerasCV 实现的 Stable Diffusion 明显快于其 PyTorch 对等物。这在很大程度上可能是由于 XLA 编译。

注意:每种优化的性能优势在不同的硬件设置之间可能会有很大的差异。

首先,让我们对未优化的模型进行基准测试

benchmark_result = []
start = time.time()
images = model.text_to_image(
    "A cute otter in a rainbow whirlpool holding shells, watercolor",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["Standard", end - start])
plot_images(images)

print(f"Standard model: {(end - start):.2f} seconds")
keras.backend.clear_session()  # Clear session to preserve memory.
 50/50 ━━━━━━━━━━━━━━━━━━━━ 10s 209ms/step
Standard model: 10.57 seconds

png

混合精度

“混合精度”包括使用 float16 精度执行计算,同时以 float32 格式存储权重。这样做是为了利用这样一个事实,即在现代 NVIDIA GPU 上,float16 操作比其 float32 对等操作由速度快得多的内核支持。

在 Keras(以及 keras_cv.models.StableDiffusion)中启用混合精度计算就像调用以下命令一样简单

keras.mixed_precision.set_global_policy("mixed_float16")

就是这样。开箱即用 - 它可以正常工作。

model = keras_cv.models.StableDiffusion(jit_compile=False)

print("Compute dtype:", model.diffusion_model.compute_dtype)
print(
    "Variable dtype:",
    model.diffusion_model.variable_dtype,
)
By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE
Compute dtype: float16
Variable dtype: float32

如您所见,上面构建的模型现在使用混合精度计算;利用 float16 操作的速度进行计算,同时以 float32 精度存储变量。

# Warm up model to run graph tracing before benchmarking.
model.text_to_image("warming up the model", batch_size=3)

start = time.time()
images = model.text_to_image(
    "a cute magical flying dog, fantasy art, "
    "golden color, high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting, mystery, adventure",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["Mixed Precision", end - start])
plot_images(images)

print(f"Mixed precision model: {(end - start):.2f} seconds")
keras.backend.clear_session()
 50/50 ━━━━━━━━━━━━━━━━━━━━ 42s 132ms/step
 50/50 ━━━━━━━━━━━━━━━━━━━━ 6s 129ms/step
Mixed precision model: 6.65 seconds

png

XLA 编译

TensorFlow 和 JAX 内置了 XLA:加速线性代数 编译器。 keras_cv.models.StableDiffusion 原生支持 jit_compile 参数。将此参数设置为 True 将启用 XLA 编译,从而显著提高速度。

让我们在下面使用它

# Set back to the default for benchmarking purposes.
keras.mixed_precision.set_global_policy("float32")

model = keras_cv.models.StableDiffusion(jit_compile=True)
# Before we benchmark the model, we run inference once to make sure the TensorFlow
# graph has already been traced.
images = model.text_to_image("An avocado armchair", batch_size=3)
plot_images(images)
By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE
 50/50 ━━━━━━━━━━━━━━━━━━━━ 48s 209ms/step

png

让我们对我们的 XLA 模型进行基准测试

start = time.time()
images = model.text_to_image(
    "A cute otter in a rainbow whirlpool holding shells, watercolor",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["XLA", end - start])
plot_images(images)

print(f"With XLA: {(end - start):.2f} seconds")
keras.backend.clear_session()
 50/50 ━━━━━━━━━━━━━━━━━━━━ 11s 210ms/step
With XLA: 10.63 seconds

png

在 A100 GPU 上,我们获得了大约 2 倍的速度提升。太棒了!


综合起来

那么,如何构建性能最高的稳定扩散推理管道(截至 2022 年 9 月)呢?

只需这两行代码

keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_cv.models.StableDiffusion(jit_compile=True)
By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE

以及使用方法...

# Let's make sure to warm up the model
images = model.text_to_image(
    "Teddy bears conducting machine learning research",
    batch_size=3,
)
plot_images(images)
 50/50 ━━━━━━━━━━━━━━━━━━━━ 48s 131ms/step

png

到底有多快?让我们来找出答案!

start = time.time()
images = model.text_to_image(
    "A mysterious dark stranger visits the great pyramids of egypt, "
    "high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["XLA + Mixed Precision", end - start])
plot_images(images)

print(f"XLA + mixed precision: {(end - start):.2f} seconds")
 50/50 ━━━━━━━━━━━━━━━━━━━━ 6s 130ms/step
XLA + mixed precision: 6.66 seconds

png

让我们看看结果

print("{:<22} {:<22}".format("Model", "Runtime"))
for result in benchmark_result:
    name, runtime = result
    print("{:<22} {:<22}".format(name, runtime))
Model                  Runtime               
Standard               10.572920799255371    
Mixed Precision        6.651048421859741     
XLA                    10.632121562957764    
XLA + Mixed Precision  6.659237861633301     

我们的完全优化的模型仅用四秒钟就在 A100 GPU 上根据文本提示生成了三张新颖的图像。


结论

KerasCV 提供了最先进的 Stable Diffusion 实现——并且通过使用 XLA 和混合精度,它提供了截至 2022 年 9 月可用的最快 Stable Diffusion 管道。

通常,在 keras.io 教程的结尾,我们会提供一些未来学习的方向。这一次,我们给你一个想法

使用自己的提示运行模型!绝对令人兴奋!

如果您有自己的 NVIDIA GPU 或 M1 MacBookPro,您也可以在本地机器上运行模型。(请注意,在 M1 MacBookPro 上运行时,不应启用混合精度,因为它尚未得到 Apple 的 Metal 运行时的良好支持。)