作者: fchollet,lukewood,divamgupta
创建日期 2022/09/25
上次修改日期 2022/09/25
描述:使用 KerasCV 的 Stable Diffusion 模型生成新图像。
在本指南中,我们将展示如何使用 KerasCV 实现的 stability.ai 文本到图像模型 Stable Diffusion,根据文本提示生成新颖的图像。
Stable Diffusion 是一种强大且开源的文本到图像生成模型。虽然存在多个开源实现,使您可以轻松地根据文本提示创建图像,但 KerasCV 提供了一些独特的优势。这些优势包括 XLA 编译 和 混合精度 支持,两者结合实现了最先进的生成速度。
在本指南中,我们将探索 KerasCV 的 Stable Diffusion 实现,展示如何使用这些强大的性能提升,并探讨它们带来的性能优势。
注意:要在 torch
后端运行本指南,请在所有地方设置 jit_compile=False
。Stable Diffusion 的 XLA 编译目前不适用于 torch。
首先,让我们安装一些依赖项并整理一些导入
!pip install -q --upgrade keras-cv
!pip install -q --upgrade keras # Upgrade to Keras 3.
import time
import keras_cv
import keras
import matplotlib.pyplot as plt
与大多数教程不同,我们不会先解释主题,然后展示如何实现它,在文本到图像生成中,展示比解释更容易理解。
体验 keras_cv.models.StableDiffusion()
的强大功能。
首先,我们构建一个模型
model = keras_cv.models.StableDiffusion(
img_width=512, img_height=512, jit_compile=False
)
By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE
接下来,我们提供一个提示
images = model.text_to_image("photograph of an astronaut riding a horse", batch_size=3)
def plot_images(images):
plt.figure(figsize=(20, 20))
for i in range(len(images)):
ax = plt.subplot(1, len(images), i + 1)
plt.imshow(images[i])
plt.axis("off")
plot_images(images)
50/50 ━━━━━━━━━━━━━━━━━━━━ 63s 211ms/step
非常棒!
但这并不是这个模型所能做的全部。让我们尝试一个更复杂的提示
images = model.text_to_image(
"cute magical flying dog, fantasy art, "
"golden color, high quality, highly detailed, elegant, sharp focus, "
"concept art, character concepts, digital painting, mystery, adventure",
batch_size=3,
)
plot_images(images)
50/50 ━━━━━━━━━━━━━━━━━━━━ 10s 209ms/step
可能性实际上是无限的(或者至少扩展到 Stable Diffusion 潜在流形的边界)。
与您此时可能预期的不同,Stable Diffusion 实际上并非依赖于魔法。它是一种“潜在扩散模型”。让我们深入了解一下这意味着什么。
您可能熟悉超分辨率的概念:可以训练一个深度学习模型来去噪输入图像——从而将其转换为更高分辨率的版本。深度学习模型并非通过神奇地恢复从噪声、低分辨率输入中丢失的信息来做到这一点——而是,模型利用其训练数据分布来幻化出给定输入最可能的视觉细节。要详细了解超分辨率,您可以查看以下 Keras.io 教程
当您将此概念推向极致时,您可能会开始思考——如果我们只是在纯噪声上运行这样的模型会怎样?然后,模型将“去噪噪声”并开始幻化出一张全新的图像。通过多次重复该过程,您可以将一小块噪声变成越来越清晰且高分辨率的人工图像。
这是潜在扩散的关键思想,于 2020 年在 使用潜在扩散模型进行高分辨率图像合成 中提出。要深入了解扩散,您可以查看 Keras.io 教程 去噪扩散隐式模型。
现在,要从潜在扩散转向文本到图像系统,您仍然需要添加一个关键功能:能够通过提示关键字控制生成的视觉内容。这是通过“条件化”完成的,这是一种经典的深度学习技术,它包括将表示一些文本的向量连接到噪声块,然后在 {图像:标题} 对的数据集上训练模型。
这产生了 Stable Diffusion 架构。Stable Diffusion 由三个部分组成
首先,您的文本提示通过文本编码器投影到潜在向量空间,文本编码器只是一个预训练的、冻结的语言模型。然后,将该提示向量连接到随机生成的噪声块,该噪声块在连续的“步骤”中由扩散模型反复“去噪”(运行的步骤越多,图像越清晰和漂亮——默认值为 50 步)。
最后,将 64x64 的潜在图像发送到解码器以正确地将其渲染为高分辨率。
总而言之,这是一个非常简单的系统——Keras 实现包含四个文件,总共不到 500 行代码
但是,当您在数十亿张图片及其标题上进行训练时,这个相对简单的系统开始看起来像魔法一样。正如费曼所说:“宇宙并不复杂,只是有很多东西!”
既然有几个 Stable Diffusion 的开源实现可用,为什么您应该使用 keras_cv.models.StableDiffusion
?
除了易于使用的 API 之外,KerasCV 的 Stable Diffusion 模型还具有一些强大的优势,包括
jit_compile=True
进行 XLA 编译当这些功能结合在一起时,KerasCV Stable Diffusion 模型的运行速度比朴素实现快几个数量级。本节将展示如何启用所有这些功能,以及使用它们带来的性能提升。
为了进行比较,我们运行了一些基准测试,将 HuggingFace diffusers 实现的 Stable Diffusion 与 KerasCV 实现的运行时间进行了比较。这两个实现都承担了生成 3 张图像的任务,每张图像的步数为 50。在本基准测试中,我们使用了 Tesla T4 GPU。
我们所有的基准测试都开源在 GitHub 上,可以在 Colab 上重新运行以复制结果。 下表显示了基准测试的结果
GPU | 模型 | 运行时间 |
---|---|---|
Tesla T4 | KerasCV(预热启动) | 28.97 秒 |
Tesla T4 | diffusers(预热启动) | 41.33 秒 |
Tesla V100 | KerasCV(预热启动) | 12.45 |
Tesla V100 | diffusers(预热启动) | 12.72 |
Tesla T4 上的执行时间提高了 30%!虽然在 V100 上的改进幅度要低得多,但我们通常期望基准测试的结果始终偏向于 KerasCV 在所有 NVIDIA GPU 上。
为了完整起见,报告了冷启动和预热启动的生成时间。冷启动执行时间包括模型创建和编译的一次性成本,因此在生产环境中可以忽略不计(在生产环境中,您将多次重用相同的模型实例)。无论如何,以下是冷启动数据
GPU | 模型 | 运行时间 |
---|---|---|
Tesla T4 | KerasCV(冷启动) | 83.47 秒 |
Tesla T4 | diffusers(冷启动) | 46.27 秒 |
Tesla V100 | KerasCV(冷启动) | 76.43 |
Tesla V100 | diffusers(冷启动) | 13.90 |
虽然运行本指南产生的运行时间结果可能会有所不同,但在我们的测试中,KerasCV 实现的 Stable Diffusion 明显快于其 PyTorch 对等物。这在很大程度上可能是由于 XLA 编译。
注意:每种优化的性能优势在不同的硬件设置之间可能会有很大的差异。
首先,让我们对未优化的模型进行基准测试
benchmark_result = []
start = time.time()
images = model.text_to_image(
"A cute otter in a rainbow whirlpool holding shells, watercolor",
batch_size=3,
)
end = time.time()
benchmark_result.append(["Standard", end - start])
plot_images(images)
print(f"Standard model: {(end - start):.2f} seconds")
keras.backend.clear_session() # Clear session to preserve memory.
50/50 ━━━━━━━━━━━━━━━━━━━━ 10s 209ms/step
Standard model: 10.57 seconds
“混合精度”包括使用 float16
精度执行计算,同时以 float32
格式存储权重。这样做是为了利用这样一个事实,即在现代 NVIDIA GPU 上,float16
操作比其 float32
对等操作由速度快得多的内核支持。
在 Keras(以及 keras_cv.models.StableDiffusion
)中启用混合精度计算就像调用以下命令一样简单
keras.mixed_precision.set_global_policy("mixed_float16")
就是这样。开箱即用 - 它可以正常工作。
model = keras_cv.models.StableDiffusion(jit_compile=False)
print("Compute dtype:", model.diffusion_model.compute_dtype)
print(
"Variable dtype:",
model.diffusion_model.variable_dtype,
)
By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE
Compute dtype: float16
Variable dtype: float32
如您所见,上面构建的模型现在使用混合精度计算;利用 float16
操作的速度进行计算,同时以 float32
精度存储变量。
# Warm up model to run graph tracing before benchmarking.
model.text_to_image("warming up the model", batch_size=3)
start = time.time()
images = model.text_to_image(
"a cute magical flying dog, fantasy art, "
"golden color, high quality, highly detailed, elegant, sharp focus, "
"concept art, character concepts, digital painting, mystery, adventure",
batch_size=3,
)
end = time.time()
benchmark_result.append(["Mixed Precision", end - start])
plot_images(images)
print(f"Mixed precision model: {(end - start):.2f} seconds")
keras.backend.clear_session()
50/50 ━━━━━━━━━━━━━━━━━━━━ 42s 132ms/step
50/50 ━━━━━━━━━━━━━━━━━━━━ 6s 129ms/step
Mixed precision model: 6.65 seconds
TensorFlow 和 JAX 内置了 XLA:加速线性代数 编译器。 keras_cv.models.StableDiffusion
原生支持 jit_compile
参数。将此参数设置为 True
将启用 XLA 编译,从而显著提高速度。
让我们在下面使用它
# Set back to the default for benchmarking purposes.
keras.mixed_precision.set_global_policy("float32")
model = keras_cv.models.StableDiffusion(jit_compile=True)
# Before we benchmark the model, we run inference once to make sure the TensorFlow
# graph has already been traced.
images = model.text_to_image("An avocado armchair", batch_size=3)
plot_images(images)
By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE
50/50 ━━━━━━━━━━━━━━━━━━━━ 48s 209ms/step
让我们对我们的 XLA 模型进行基准测试
start = time.time()
images = model.text_to_image(
"A cute otter in a rainbow whirlpool holding shells, watercolor",
batch_size=3,
)
end = time.time()
benchmark_result.append(["XLA", end - start])
plot_images(images)
print(f"With XLA: {(end - start):.2f} seconds")
keras.backend.clear_session()
50/50 ━━━━━━━━━━━━━━━━━━━━ 11s 210ms/step
With XLA: 10.63 seconds
在 A100 GPU 上,我们获得了大约 2 倍的速度提升。太棒了!
那么,如何构建性能最高的稳定扩散推理管道(截至 2022 年 9 月)呢?
只需这两行代码
keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_cv.models.StableDiffusion(jit_compile=True)
By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE
以及使用方法...
# Let's make sure to warm up the model
images = model.text_to_image(
"Teddy bears conducting machine learning research",
batch_size=3,
)
plot_images(images)
50/50 ━━━━━━━━━━━━━━━━━━━━ 48s 131ms/step
到底有多快?让我们来找出答案!
start = time.time()
images = model.text_to_image(
"A mysterious dark stranger visits the great pyramids of egypt, "
"high quality, highly detailed, elegant, sharp focus, "
"concept art, character concepts, digital painting",
batch_size=3,
)
end = time.time()
benchmark_result.append(["XLA + Mixed Precision", end - start])
plot_images(images)
print(f"XLA + mixed precision: {(end - start):.2f} seconds")
50/50 ━━━━━━━━━━━━━━━━━━━━ 6s 130ms/step
XLA + mixed precision: 6.66 seconds
让我们看看结果
print("{:<22} {:<22}".format("Model", "Runtime"))
for result in benchmark_result:
name, runtime = result
print("{:<22} {:<22}".format(name, runtime))
Model Runtime
Standard 10.572920799255371
Mixed Precision 6.651048421859741
XLA 10.632121562957764
XLA + Mixed Precision 6.659237861633301
我们的完全优化的模型仅用四秒钟就在 A100 GPU 上根据文本提示生成了三张新颖的图像。
KerasCV 提供了最先进的 Stable Diffusion 实现——并且通过使用 XLA 和混合精度,它提供了截至 2022 年 9 月可用的最快 Stable Diffusion 管道。
通常,在 keras.io 教程的结尾,我们会提供一些未来学习的方向。这一次,我们给你一个想法
使用自己的提示运行模型!绝对令人兴奋!
如果您有自己的 NVIDIA GPU 或 M1 MacBookPro,您也可以在本地机器上运行模型。(请注意,在 M1 MacBookPro 上运行时,不应启用混合精度,因为它尚未得到 Apple 的 Metal 运行时的良好支持。)