Author: A_K_Nain
Date created: 2022/11/30
Last modified: 2022/12/07
Description: Generating images of flowers with denoising diffusion probabilistic models.
Generative modeling has experienced tremendous progress over the last five years. Models like VAEs, GANs, and flow-based models proved to be quite successful at generating high-quality content, especially images. Diffusion models are a new type of generative model that has proven to be better than previous approaches.
Diffusion models are inspired by non-equilibrium thermodynamics, and they learn to generate by denoising. Learning by denoising consists of two processes, each of which is a Markov chain. They are:
The forward process: In the forward process, we slowly add random noise to the data over a series of time steps (t1, t2, ..., tn). The sample at the current time step is drawn from a Gaussian distribution whose mean is conditioned on the sample at the previous time step, and whose variance follows a fixed schedule. At the end of the forward process, the samples end up with a pure noise distribution.
The reverse process: In the reverse process, we try to undo the added noise at every time step. We start with the pure noise distribution (the last step of the forward process) and attempt to denoise the samples in the backward direction (tn, tn-1, ..., t1).
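To make the forward process concrete, here is a minimal NumPy sketch of a single diffusion step and of the closed-form jump from x_0 to an arbitrary x_t. The schedule values mirror the GaussianDiffusion utility defined later in this example, but the snippet itself is only illustrative.
import numpy as np
# Illustrative sketch of the forward process (standard DDPM math,
# not code from the original example).
betas = np.linspace(1e-4, 0.02, 1000)  # fixed variance schedule
alphas_cumprod = np.cumprod(1.0 - betas)  # "alpha bar" in the paper
def forward_step(x_prev, t):
    """One step: x_t ~ N(sqrt(1 - beta_t) * x_{t-1}, beta_t * I)."""
    noise = np.random.normal(size=x_prev.shape)
    return np.sqrt(1.0 - betas[t]) * x_prev + np.sqrt(betas[t]) * noise
def forward_jump(x0, t):
    """Closed form: x_t ~ N(sqrt(abar_t) * x_0, (1 - abar_t) * I)."""
    noise = np.random.normal(size=x0.shape)
    return np.sqrt(alphas_cumprod[t]) * x0 + np.sqrt(1.0 - alphas_cumprod[t]) * noise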
In this code example, we implement the Denoising Diffusion Probabilistic Models paper, or DDPM for short. It was the first paper demonstrating the use of diffusion models for generating high-quality images. The authors proved that a certain parameterization of diffusion models reveals an equivalence with denoising score matching over multiple noise levels during training, and with annealed Langevin dynamics during sampling, which generates the best-quality results.
This paper replicates both of the Markov chains involved in the diffusion process (the forward process and the reverse process), but for images. The forward process is fixed and gradually adds Gaussian noise to the images according to a fixed variance schedule, denoted by beta in the paper. This is what the diffusion process looks like in the case of images: (image -> noise :: noise -> image)
The paper describes two algorithms: one for training the model, and one for sampling from the trained model. Training is performed by optimizing the usual variational bound on the negative log-likelihood. The objective function is further simplified, and the network is treated as a noise prediction network. Once optimized, we can sample from the network to generate new images from noise samples. Here is an overview of both algorithms as presented in the paper, sketched in pseudocode below.
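Since the algorithm figures from the paper are not reproduced here, the following pseudocode sketches both procedures under the standard DDPM formulation; the concrete implementations appear in DiffusionModel.train_step and generate_images later in this example.
# Sketch of Algorithm 1 (training) from the DDPM paper:
#   repeat:
#     x0  ~ data distribution
#     t   ~ Uniform({1, ..., T})
#     eps ~ N(0, I)
#     x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps
#     gradient step on || eps - eps_theta(x_t, t) ||^2
#
# Sketch of Algorithm 2 (sampling):
#   x_T ~ N(0, I)
#   for t = T, ..., 1:
#     z = N(0, I) if t > 1 else 0
#     x_{t-1} = (x_t - (1 - a_t) / sqrt(1 - abar_t) * eps_theta(x_t, t)) / sqrt(a_t)
#               + sigma_t * z
#   return x_0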
Note: DDPM is just one way of implementing a diffusion model. Also, the sampling algorithm in DDPM replicates the complete Markov chain, so it is slow at generating new samples compared to other generative models like GANs. Lots of research effort has gone into addressing this issue. One such example is Denoising Diffusion Implicit Models, or DDIM for short, where the authors replaced the Markov chain with a non-Markovian process to sample faster. You can find the code example for DDIM here.
Implementing a DDPM model is simple. We define a model that takes two inputs: images and randomly sampled time steps. At each training step, we perform the following operations to train our model (these mirror the numbered comments in train_step below):
1. Sample time steps uniformly.
2. Sample random noise to be added to the images in the batch.
3. Diffuse the images with the sampled noise via the forward process.
4. Pass the diffused images and the time steps to the network, which outputs a noise prediction.
5. Compute the loss between the true and predicted noise, then update the network weights.
Given that our model knows how to denoise a noisy sample at a given time step, we can leverage this idea to generate new samples, starting from a pure noise distribution.
import math
import numpy as np
import matplotlib.pyplot as plt
# Requires TensorFlow >=2.11 for the GroupNormalization layer.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds
batch_size = 32
num_epochs = 1 # Just for the sake of demonstration
total_timesteps = 1000
norm_groups = 8 # Number of groups used in GroupNormalization layer
learning_rate = 2e-4
img_size = 64
img_channels = 3
clip_min = -1.0
clip_max = 1.0
first_conv_channels = 64
channel_multiplier = [1, 2, 4, 8]
widths = [first_conv_channels * mult for mult in channel_multiplier]
has_attention = [False, False, True, True]
num_res_blocks = 2 # Number of residual blocks
dataset_name = "oxford_flowers102"
splits = ["train"]
We use the Oxford Flowers 102 dataset for generating images of flowers. In terms of preprocessing, we use center cropping to resize the images to the desired image size, and we rescale the pixel values to the range [-1.0, 1.0]. This is in line with the range of pixel values applied by the authors of the DDPM paper. For augmenting the training data, we randomly flip the images left/right.
# Load the dataset
(ds,) = tfds.load(dataset_name, split=splits, with_info=False, shuffle_files=True)
def augment(img):
"""Flips an image left/right randomly."""
return tf.image.random_flip_left_right(img)
def resize_and_rescale(img, size):
"""Resize the image to the desired size first and then
rescale the pixel values in the range [-1.0, 1.0].
Args:
img: Image tensor
size: Desired image size for resizing
Returns:
Resized and rescaled image tensor
"""
height = tf.shape(img)[0]
width = tf.shape(img)[1]
crop_size = tf.minimum(height, width)
img = tf.image.crop_to_bounding_box(
img,
(height - crop_size) // 2,
(width - crop_size) // 2,
crop_size,
crop_size,
)
# Resize
img = tf.cast(img, dtype=tf.float32)
img = tf.image.resize(img, size=size, antialias=True)
# Rescale the pixel values
img = img / 127.5 - 1.0
img = tf.clip_by_value(img, clip_min, clip_max)
return img
def train_preprocessing(x):
img = x["image"]
img = resize_and_rescale(img, size=(img_size, img_size))
img = augment(img)
return img
train_ds = (
ds.map(train_preprocessing, num_parallel_calls=tf.data.AUTOTUNE)
.batch(batch_size, drop_remainder=True)
.shuffle(batch_size * 2)
.prefetch(tf.data.AUTOTUNE)
)
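Before moving on, you can sanity-check the pipeline by inspecting one batch; this quick check is an addition, not part of the original example.
# One batch should have shape (batch_size, img_size, img_size, 3) with
# pixel values inside [-1.0, 1.0].
for batch in train_ds.take(1):
    print(batch.shape, tf.reduce_min(batch).numpy(), tf.reduce_max(batch).numpy())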
We define the forward process and the reverse process as a separate utility. Most of the code in this utility has been borrowed from the original implementation, with some slight modifications.
class GaussianDiffusion:
"""Gaussian diffusion utility.
Args:
beta_start: Start value of the scheduled variance
beta_end: End value of the scheduled variance
timesteps: Number of time steps in the forward process
"""
def __init__(
self,
beta_start=1e-4,
beta_end=0.02,
timesteps=1000,
clip_min=-1.0,
clip_max=1.0,
):
self.beta_start = beta_start
self.beta_end = beta_end
self.timesteps = timesteps
self.clip_min = clip_min
self.clip_max = clip_max
# Define the linear variance schedule
self.betas = betas = np.linspace(
beta_start,
beta_end,
timesteps,
dtype=np.float64, # Using float64 for better precision
)
self.num_timesteps = int(timesteps)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
self.betas = tf.constant(betas, dtype=tf.float32)
self.alphas_cumprod = tf.constant(alphas_cumprod, dtype=tf.float32)
self.alphas_cumprod_prev = tf.constant(alphas_cumprod_prev, dtype=tf.float32)
# Calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = tf.constant(
np.sqrt(alphas_cumprod), dtype=tf.float32
)
self.sqrt_one_minus_alphas_cumprod = tf.constant(
np.sqrt(1.0 - alphas_cumprod), dtype=tf.float32
)
self.log_one_minus_alphas_cumprod = tf.constant(
np.log(1.0 - alphas_cumprod), dtype=tf.float32
)
self.sqrt_recip_alphas_cumprod = tf.constant(
np.sqrt(1.0 / alphas_cumprod), dtype=tf.float32
)
self.sqrt_recipm1_alphas_cumprod = tf.constant(
np.sqrt(1.0 / alphas_cumprod - 1), dtype=tf.float32
)
# Calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (
betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
)
self.posterior_variance = tf.constant(posterior_variance, dtype=tf.float32)
# Log calculation clipped because the posterior variance is 0 at the beginning
# of the diffusion chain
self.posterior_log_variance_clipped = tf.constant(
np.log(np.maximum(posterior_variance, 1e-20)), dtype=tf.float32
)
self.posterior_mean_coef1 = tf.constant(
betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
dtype=tf.float32,
)
self.posterior_mean_coef2 = tf.constant(
(1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod),
dtype=tf.float32,
)
def _extract(self, a, t, x_shape):
"""Extract some coefficients at specified timesteps,
then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
Args:
a: Tensor to extract from
t: Timestep for which the coefficients are to be extracted
x_shape: Shape of the current batched samples
"""
batch_size = x_shape[0]
out = tf.gather(a, t)
return tf.reshape(out, [batch_size, 1, 1, 1])
def q_mean_variance(self, x_start, t):
"""Extracts the mean, and the variance at current timestep.
Args:
x_start: Initial sample (before the first diffusion step)
t: Current timestep
"""
x_start_shape = tf.shape(x_start)
mean = self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
variance = self._extract(1.0 - self.alphas_cumprod, t, x_start_shape)
log_variance = self._extract(
self.log_one_minus_alphas_cumprod, t, x_start_shape
)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise):
"""Diffuse the data.
Args:
x_start: Initial sample (before the first diffusion step)
t: Current timestep
noise: Gaussian noise to be added at the current timestep
Returns:
Diffused samples at timestep `t`
"""
x_start_shape = tf.shape(x_start)
return (
self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
+ self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start_shape)
* noise
)
def predict_start_from_noise(self, x_t, t, noise):
x_t_shape = tf.shape(x_t)
return (
self._extract(self.sqrt_recip_alphas_cumprod, t, x_t_shape) * x_t
- self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t_shape) * noise
)
def q_posterior(self, x_start, x_t, t):
"""Compute the mean and variance of the diffusion
posterior q(x_{t-1} | x_t, x_0).
Args:
x_start: Starting point (sample) for the posterior computation
x_t: Sample at timestep `t`
t: Current timestep
Returns:
Posterior mean and variance at current timestep
"""
x_t_shape = tf.shape(x_t)
posterior_mean = (
self._extract(self.posterior_mean_coef1, t, x_t_shape) * x_start
+ self._extract(self.posterior_mean_coef2, t, x_t_shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance, t, x_t_shape)
posterior_log_variance_clipped = self._extract(
self.posterior_log_variance_clipped, t, x_t_shape
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, pred_noise, x, t, clip_denoised=True):
x_recon = self.predict_start_from_noise(x, t=t, noise=pred_noise)
if clip_denoised:
x_recon = tf.clip_by_value(x_recon, self.clip_min, self.clip_max)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
x_start=x_recon, x_t=x, t=t
)
return model_mean, posterior_variance, posterior_log_variance
def p_sample(self, pred_noise, x, t, clip_denoised=True):
"""Sample from the diffusion model.
Args:
pred_noise: Noise predicted by the diffusion model
x: Samples at a given timestep for which the noise was predicted
t: Current timestep
clip_denoised (bool): Whether to clip the denoised reconstruction
within the specified range or not.
"""
model_mean, _, model_log_variance = self.p_mean_variance(
pred_noise, x=x, t=t, clip_denoised=clip_denoised
)
noise = tf.random.normal(shape=x.shape, dtype=x.dtype)
# No noise when t == 0
nonzero_mask = tf.reshape(
1 - tf.cast(tf.equal(t, 0), tf.float32), [tf.shape(x)[0], 1, 1, 1]
)
return model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise
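As an optional sanity check (an addition to the original example), you can visualize how q_sample progressively corrupts a training image across timesteps; the snippet below assumes train_ds and the constants defined earlier.
gdf = GaussianDiffusion(timesteps=total_timesteps)
sample = next(iter(train_ds))[:1]  # one image of shape (1, 64, 64, 3)
plt.figure(figsize=(12, 2))
for i, timestep in enumerate([0, 100, 300, 500, 700, 999]):
    t = tf.constant([timestep], dtype=tf.int64)
    noise = tf.random.normal(shape=tf.shape(sample), dtype=sample.dtype)
    noisy = gdf.q_sample(sample, t, noise)
    plt.subplot(1, 6, i + 1)
    # Map [-1, 1] back to [0, 1] for display
    plt.imshow(((noisy[0].numpy() + 1.0) / 2.0).clip(0.0, 1.0))
    plt.title(f"t={timestep}")
    plt.axis("off")
plt.show()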
A U-Net, originally developed for semantic segmentation, is an architecture that is widely used for implementing diffusion models, with some slight modifications: the network takes two inputs (the image and the time step), self-attention is applied between the convolution blocks at the lower resolutions, and GroupNormalization is used for normalization.
We implement most of the things as used in the original paper. We use the swish activation function throughout the network, and the variance-scaling kernel initializer.
The only difference here is the number of groups used for the GroupNormalization layer. For the flowers dataset, we found that a value of groups=8 produces better results compared to the default value of groups=32. Dropout is optional and should be used where the chances of overfitting are high. In the paper, the authors used dropout only when training on CIFAR10.
# Kernel initializer to use
def kernel_init(scale):
scale = max(scale, 1e-10)
return keras.initializers.VarianceScaling(
scale, mode="fan_avg", distribution="uniform"
)
class AttentionBlock(layers.Layer):
"""Applies self-attention.
Args:
units: Number of units in the dense layers
groups: Number of groups to be used for GroupNormalization layer
"""
def __init__(self, units, groups=8, **kwargs):
self.units = units
self.groups = groups
super().__init__(**kwargs)
self.norm = layers.GroupNormalization(groups=groups)
self.query = layers.Dense(units, kernel_initializer=kernel_init(1.0))
self.key = layers.Dense(units, kernel_initializer=kernel_init(1.0))
self.value = layers.Dense(units, kernel_initializer=kernel_init(1.0))
self.proj = layers.Dense(units, kernel_initializer=kernel_init(0.0))
def call(self, inputs):
batch_size = tf.shape(inputs)[0]
height = tf.shape(inputs)[1]
width = tf.shape(inputs)[2]
scale = tf.cast(self.units, tf.float32) ** (-0.5)
inputs = self.norm(inputs)
q = self.query(inputs)
k = self.key(inputs)
v = self.value(inputs)
attn_score = tf.einsum("bhwc, bHWc->bhwHW", q, k) * scale
attn_score = tf.reshape(attn_score, [batch_size, height, width, height * width])
attn_score = tf.nn.softmax(attn_score, -1)
attn_score = tf.reshape(attn_score, [batch_size, height, width, height, width])
proj = tf.einsum("bhwHW,bHWc->bhwc", attn_score, v)
proj = self.proj(proj)
return inputs + proj
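A quick way to see that the attention block preserves the spatial shape of the feature map is to run it on a dummy tensor; this snippet is illustrative and not part of the original example.
# The block attends over all spatial positions and returns a tensor of
# the same shape as its input.
block = AttentionBlock(units=128, groups=8)
dummy = tf.random.normal((2, 16, 16, 128))
print(block(dummy).shape)  # (2, 16, 16, 128)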
class TimeEmbedding(layers.Layer):
def __init__(self, dim, **kwargs):
super().__init__(**kwargs)
self.dim = dim
self.half_dim = dim // 2
self.emb = math.log(10000) / (self.half_dim - 1)
self.emb = tf.exp(tf.range(self.half_dim, dtype=tf.float32) * -self.emb)
def call(self, inputs):
inputs = tf.cast(inputs, dtype=tf.float32)
emb = inputs[:, None] * self.emb[None, :]
emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=-1)
return emb
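Likewise, the time embedding maps a batch of integer timesteps to sinusoidal vectors of size dim; a small illustrative check (not part of the original example):
temb_layer = TimeEmbedding(dim=256)
print(temb_layer(tf.constant([1, 250, 999])).shape)  # (3, 256)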
def ResidualBlock(width, groups=8, activation_fn=keras.activations.swish):
def apply(inputs):
x, t = inputs
input_width = x.shape[3]
if input_width == width:
residual = x
else:
residual = layers.Conv2D(
width, kernel_size=1, kernel_initializer=kernel_init(1.0)
)(x)
temb = activation_fn(t)
temb = layers.Dense(width, kernel_initializer=kernel_init(1.0))(temb)[
:, None, None, :
]
x = layers.GroupNormalization(groups=groups)(x)
x = activation_fn(x)
x = layers.Conv2D(
width, kernel_size=3, padding="same", kernel_initializer=kernel_init(1.0)
)(x)
x = layers.Add()([x, temb])
x = layers.GroupNormalization(groups=groups)(x)
x = activation_fn(x)
x = layers.Conv2D(
width, kernel_size=3, padding="same", kernel_initializer=kernel_init(0.0)
)(x)
x = layers.Add()([x, residual])
return x
return apply
def DownSample(width):
def apply(x):
x = layers.Conv2D(
width,
kernel_size=3,
strides=2,
padding="same",
kernel_initializer=kernel_init(1.0),
)(x)
return x
return apply
def UpSample(width, interpolation="nearest"):
def apply(x):
x = layers.UpSampling2D(size=2, interpolation=interpolation)(x)
x = layers.Conv2D(
width, kernel_size=3, padding="same", kernel_initializer=kernel_init(1.0)
)(x)
return x
return apply
def TimeMLP(units, activation_fn=keras.activations.swish):
def apply(inputs):
temb = layers.Dense(
units, activation=activation_fn, kernel_initializer=kernel_init(1.0)
)(inputs)
temb = layers.Dense(units, kernel_initializer=kernel_init(1.0))(temb)
return temb
return apply
def build_model(
img_size,
img_channels,
widths,
has_attention,
num_res_blocks=2,
norm_groups=8,
interpolation="nearest",
activation_fn=keras.activations.swish,
):
image_input = layers.Input(
shape=(img_size, img_size, img_channels), name="image_input"
)
time_input = keras.Input(shape=(), dtype=tf.int64, name="time_input")
x = layers.Conv2D(
first_conv_channels,
kernel_size=(3, 3),
padding="same",
kernel_initializer=kernel_init(1.0),
)(image_input)
temb = TimeEmbedding(dim=first_conv_channels * 4)(time_input)
temb = TimeMLP(units=first_conv_channels * 4, activation_fn=activation_fn)(temb)
skips = [x]
# DownBlock
for i in range(len(widths)):
for _ in range(num_res_blocks):
x = ResidualBlock(
widths[i], groups=norm_groups, activation_fn=activation_fn
)([x, temb])
if has_attention[i]:
x = AttentionBlock(widths[i], groups=norm_groups)(x)
skips.append(x)
if widths[i] != widths[-1]:
x = DownSample(widths[i])(x)
skips.append(x)
# MiddleBlock
x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)(
[x, temb]
)
x = AttentionBlock(widths[-1], groups=norm_groups)(x)
x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)(
[x, temb]
)
# UpBlock
for i in reversed(range(len(widths))):
for _ in range(num_res_blocks + 1):
x = layers.Concatenate(axis=-1)([x, skips.pop()])
x = ResidualBlock(
widths[i], groups=norm_groups, activation_fn=activation_fn
)([x, temb])
if has_attention[i]:
x = AttentionBlock(widths[i], groups=norm_groups)(x)
if i != 0:
x = UpSample(widths[i], interpolation=interpolation)(x)
# End block
x = layers.GroupNormalization(groups=norm_groups)(x)
x = activation_fn(x)
x = layers.Conv2D(3, (3, 3), padding="same", kernel_initializer=kernel_init(0.0))(x)
return keras.Model([image_input, time_input], x, name="unet")
We follow the same setup for training the diffusion model as described in the paper. We use the Adam optimizer with a learning rate of 2e-4. We use EMA on the model parameters with a decay factor of 0.999. We treat our model as a noise prediction network, i.e. at every training step, we input a batch of images and the corresponding time steps to our UNet, and the network outputs the noise as its prediction.
The only difference is that we are not using the Kernel Inception Distance (KID) or the Frechet Inception Distance (FID) to evaluate the quality of the generated samples during training. This is because both of these metrics are compute-heavy, and they are skipped for brevity of implementation.
**Note:** We use mean squared error as the loss function, which is aligned with the paper and makes sense theoretically. In practice, though, it is also common to use mean absolute error or Huber loss as the loss function.
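To see how little the loss choice changes in code, the sketch below compares the three candidates on dummy tensors. This comparison is an addition; the example itself compiles with MeanSquaredError further down, in line with the paper.
# Illustrative only: candidate loss functions on random tensors.
y_true = tf.random.normal((4, img_size, img_size, 3))
y_pred = tf.random.normal((4, img_size, img_size, 3))
for loss_fn in [
    keras.losses.MeanSquaredError(),
    keras.losses.MeanAbsoluteError(),
    keras.losses.Huber(),
]:
    print(loss_fn.name, float(loss_fn(y_true, y_pred)))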
class DiffusionModel(keras.Model):
def __init__(self, network, ema_network, timesteps, gdf_util, ema=0.999):
super().__init__()
self.network = network
self.ema_network = ema_network
self.timesteps = timesteps
self.gdf_util = gdf_util
self.ema = ema
def train_step(self, images):
# 1. Get the batch size
batch_size = tf.shape(images)[0]
# 2. Sample timesteps uniformly
t = tf.random.uniform(
minval=0, maxval=self.timesteps, shape=(batch_size,), dtype=tf.int64
)
with tf.GradientTape() as tape:
# 3. Sample random noise to be added to the images in the batch
noise = tf.random.normal(shape=tf.shape(images), dtype=images.dtype)
# 4. Diffuse the images with noise
images_t = self.gdf_util.q_sample(images, t, noise)
# 5. Pass the diffused images and time steps to the network
pred_noise = self.network([images_t, t], training=True)
# 6. Calculate the loss
loss = self.loss(noise, pred_noise)
# 7. Get the gradients
gradients = tape.gradient(loss, self.network.trainable_weights)
# 8. Update the weights of the network
self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))
# 9. Update the EMA network weights using an exponential moving average
for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
ema_weight.assign(self.ema * ema_weight + (1 - self.ema) * weight)
# 10. Return loss values
return {"loss": loss}
def generate_images(self, num_images=16):
# 1. Randomly sample noise (starting point for reverse process)
samples = tf.random.normal(
shape=(num_images, img_size, img_size, img_channels), dtype=tf.float32
)
# 2. Sample from the model iteratively
for t in reversed(range(0, self.timesteps)):
tt = tf.cast(tf.fill(num_images, t), dtype=tf.int64)
pred_noise = self.ema_network.predict(
[samples, tt], verbose=0, batch_size=num_images
)
samples = self.gdf_util.p_sample(
pred_noise, samples, tt, clip_denoised=True
)
# 3. Return generated samples
return samples
def plot_images(
self, epoch=None, logs=None, num_rows=2, num_cols=8, figsize=(12, 5)
):
"""Utility to plot images using the diffusion model during training."""
generated_samples = self.generate_images(num_images=num_rows * num_cols)
generated_samples = (
tf.clip_by_value(generated_samples * 127.5 + 127.5, 0.0, 255.0)
.numpy()
.astype(np.uint8)
)
_, ax = plt.subplots(num_rows, num_cols, figsize=figsize)
for i, image in enumerate(generated_samples):
if num_rows == 1:
ax[i].imshow(image)
ax[i].axis("off")
else:
ax[i // num_cols, i % num_cols].imshow(image)
ax[i // num_cols, i % num_cols].axis("off")
plt.tight_layout()
plt.show()
# Build the unet model
network = build_model(
img_size=img_size,
img_channels=img_channels,
widths=widths,
has_attention=has_attention,
num_res_blocks=num_res_blocks,
norm_groups=norm_groups,
activation_fn=keras.activations.swish,
)
ema_network = build_model(
img_size=img_size,
img_channels=img_channels,
widths=widths,
has_attention=has_attention,
num_res_blocks=num_res_blocks,
norm_groups=norm_groups,
activation_fn=keras.activations.swish,
)
ema_network.set_weights(network.get_weights()) # Initially the weights are the same
# Get an instance of the Gaussian Diffusion utilities
gdf_util = GaussianDiffusion(timesteps=total_timesteps)
# Get the model
model = DiffusionModel(
network=network,
ema_network=ema_network,
gdf_util=gdf_util,
timesteps=total_timesteps,
)
# Compile the model
model.compile(
loss=keras.losses.MeanSquaredError(),
optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
)
# Train the model
model.fit(
train_ds,
epochs=num_epochs,
batch_size=batch_size,
callbacks=[keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images)],
)
31/31 [==============================] - ETA: 0s - loss: 0.7746
31/31 [==============================] - 194s 4s/step - loss: 0.7668
<keras.callbacks.History at 0x7fc9e86ce610>
We trained this model for 800 epochs on a V100 GPU, with each epoch taking almost 8 seconds to complete. We load those weights here, and we generate a few samples starting from pure noise.
!curl -LO https://github.com/AakashKumarNain/ddpms/releases/download/v3.0.0/checkpoints.zip
!unzip -qq checkpoints.zip
# Load the model weights
model.ema_network.load_weights("checkpoints/diffusion_model_checkpoint")
# Generate and plot some samples
model.plot_images(num_rows=4, num_cols=8)
We successfully implemented and trained a diffusion model exactly in the same fashion as described by the authors of the DDPM paper. You can find the original implementation here.
There are a few things you can try to improve the model:
Increasing the width of each block. A bigger model can learn to denoise in fewer epochs, though you may have to take care of overfitting.
We implemented the linear schedule for the variance. You can implement other schemes, like the cosine schedule, and compare the performance; a sketch of the cosine schedule follows below.
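For example, a cosine schedule along the lines of Nichol & Dhariwal's "Improved Denoising Diffusion Probabilistic Models" could be sketched as follows. It is not part of the original example; you would swap it in for the np.linspace call inside GaussianDiffusion and retrain.
def cosine_beta_schedule(timesteps, s=0.008, max_beta=0.999):
    """Sketch of the cosine variance schedule (Nichol & Dhariwal, 2021)."""
    steps = np.arange(timesteps + 1, dtype=np.float64)
    # alpha_bar follows a squared-cosine curve from ~1 down to ~0
    f = np.cos(((steps / timesteps) + s) / (1.0 + s) * np.pi / 2.0) ** 2
    alphas_cumprod = f / f[0]
    betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return np.clip(betas, 0.0, max_beta)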