代码示例 / 生成式深度学习 / 去噪扩散隐式模型

去噪扩散隐式模型

作者: András Béres
创建日期 2022/06/24
最后修改日期 2022/06/24
描述: 使用去噪扩散隐式模型生成花卉图像。

ⓘ 此示例使用 Keras 3

在 Colab 中查看 GitHub 源代码


引言

什么是扩散模型?

最近,去噪扩散模型,包括基于分数的生成模型,作为一类强大的生成模型越来越受欢迎,它们在图像合成质量上甚至可以媲美生成对抗网络 (GANs)。它们倾向于生成更多样化的样本,同时训练稳定且易于扩展。最近的大型扩散模型,如DALL-E 2Imagen,展示了令人难以置信的文本到图像生成能力。然而,它们的一个缺点是采样速度较慢,因为生成一张图像需要多次正向传播。

扩散是指将结构化信号(图像)一步步转化为噪声的过程。通过模拟扩散,我们可以从训练图像生成噪声图像,并训练神经网络尝试对它们进行去噪。使用训练好的网络,我们可以模拟与扩散相反的过程,即逆扩散,这是图像从噪声中显现的过程。

diffusion process gif

一句话总结:扩散模型被训练用于去噪噪声图像,并通过迭代去噪纯噪声来生成图像。

本示例的目标

此代码示例旨在实现一个最小但功能完备(包含生成质量指标)的扩散模型,对计算资源要求不高,性能合理。我的实现选择和超参数调整都考虑了这些目标。

鉴于目前扩散模型的文献数学上相当复杂,包含多种理论框架(分数匹配微分方程马尔可夫链),有时甚至有冲突的符号表示(参见附录 C.2),理解它们可能令人望而却步。在本示例中,我将这些模型视为学习将噪声图像分解为其图像和高斯噪声分量。

在本示例中,我努力将所有长数学表达式分解为易于理解的部分,并给所有变量赋予解释性的名称。我还包含了大量相关文献的链接,以帮助感兴趣的读者深入研究该主题,希望此代码示例能成为实践者学习扩散模型的一个良好起点。

在接下来的部分中,我们将实现去噪扩散隐式模型 (DDIMs)的连续时间版本,并采用确定性采样。


设置

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import math
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

import keras
from keras import layers
from keras import ops

超参数

# data
dataset_name = "oxford_flowers102"
dataset_repetitions = 5
num_epochs = 1  # train for at least 50 epochs for good results
image_size = 64
# KID = Kernel Inception Distance, see related section
kid_image_size = 75
kid_diffusion_steps = 5
plot_diffusion_steps = 20

# sampling
min_signal_rate = 0.02
max_signal_rate = 0.95

# architecture
embedding_dims = 32
embedding_max_frequency = 1000.0
widths = [32, 64, 96, 128]
block_depth = 2

# optimization
batch_size = 64
ema = 0.999
learning_rate = 1e-3
weight_decay = 1e-4

数据管道

我们将使用牛津花卉 102数据集来生成花卉图像,这是一个包含约 8,000 张图像的多样化自然数据集。遗憾的是,官方的数据集划分不平衡,大部分图像都在测试集中。我们使用Tensorflow Datasets 切片 API创建新的划分(80% 训练,20% 验证)。我们应用中心裁剪作为预处理,并多次重复数据集(原因在下一节给出)。

def preprocess_image(data):
    # center crop image
    height = ops.shape(data["image"])[0]
    width = ops.shape(data["image"])[1]
    crop_size = ops.minimum(height, width)
    image = tf.image.crop_to_bounding_box(
        data["image"],
        (height - crop_size) // 2,
        (width - crop_size) // 2,
        crop_size,
        crop_size,
    )

    # resize and clip
    # for image downsampling it is important to turn on antialiasing
    image = tf.image.resize(image, size=[image_size, image_size], antialias=True)
    return ops.clip(image / 255.0, 0.0, 1.0)


def prepare_dataset(split):
    # the validation dataset is shuffled as well, because data order matters
    # for the KID estimation
    return (
        tfds.load(dataset_name, split=split, shuffle_files=True)
        .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
        .cache()
        .repeat(dataset_repetitions)
        .shuffle(10 * batch_size)
        .batch(batch_size, drop_remainder=True)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )


# load dataset
train_dataset = prepare_dataset("train[:80%]+validation[:80%]+test[:80%]")
val_dataset = prepare_dataset("train[80%:]+validation[80%:]+test[80%:]")

核 Inception 距离

核 Inception 距离 (KID) 是一种图像质量指标,被提议用于替代流行的Frechet Inception 距离 (FID)。我更喜欢 KID 而非 FID,因为它实现更简单,可以按批次估计,计算量更轻。更多细节请参见此处

在本示例中,图像在 Inception 网络可能的最小分辨率(75x75 而不是 299x299)下进行评估,并且出于计算效率考虑,度量仅在验证集上测量。出于同样的原因,我们在评估时将采样步数限制为 5。

由于数据集相对较小,我们每个 epoch 会多次遍历训练集和验证集,因为 KID 估计有噪声且计算密集,所以我们希望在许多迭代后进行评估,但要进行多次迭代。

@keras.saving.register_keras_serializable()
class KID(keras.metrics.Metric):
    def __init__(self, name, **kwargs):
        super().__init__(name=name, **kwargs)

        # KID is estimated per batch and is averaged across batches
        self.kid_tracker = keras.metrics.Mean(name="kid_tracker")

        # a pretrained InceptionV3 is used without its classification layer
        # transform the pixel values to the 0-255 range, then use the same
        # preprocessing as during pretraining
        self.encoder = keras.Sequential(
            [
                keras.Input(shape=(image_size, image_size, 3)),
                layers.Rescaling(255.0),
                layers.Resizing(height=kid_image_size, width=kid_image_size),
                layers.Lambda(keras.applications.inception_v3.preprocess_input),
                keras.applications.InceptionV3(
                    include_top=False,
                    input_shape=(kid_image_size, kid_image_size, 3),
                    weights="imagenet",
                ),
                layers.GlobalAveragePooling2D(),
            ],
            name="inception_encoder",
        )

    def polynomial_kernel(self, features_1, features_2):
        feature_dimensions = ops.cast(ops.shape(features_1)[1], dtype="float32")
        return (
            features_1 @ ops.transpose(features_2) / feature_dimensions + 1.0
        ) ** 3.0

    def update_state(self, real_images, generated_images, sample_weight=None):
        real_features = self.encoder(real_images, training=False)
        generated_features = self.encoder(generated_images, training=False)

        # compute polynomial kernels using the two sets of features
        kernel_real = self.polynomial_kernel(real_features, real_features)
        kernel_generated = self.polynomial_kernel(
            generated_features, generated_features
        )
        kernel_cross = self.polynomial_kernel(real_features, generated_features)

        # estimate the squared maximum mean discrepancy using the average kernel values
        batch_size = real_features.shape[0]
        batch_size_f = ops.cast(batch_size, dtype="float32")
        mean_kernel_real = ops.sum(kernel_real * (1.0 - ops.eye(batch_size))) / (
            batch_size_f * (batch_size_f - 1.0)
        )
        mean_kernel_generated = ops.sum(
            kernel_generated * (1.0 - ops.eye(batch_size))
        ) / (batch_size_f * (batch_size_f - 1.0))
        mean_kernel_cross = ops.mean(kernel_cross)
        kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross

        # update the average KID estimate
        self.kid_tracker.update_state(kid)

    def result(self):
        return self.kid_tracker.result()

    def reset_state(self):
        self.kid_tracker.reset_state()

网络架构

在这里,我们指定用于去噪的神经网络架构。我们构建了一个输入和输出维度相同的 U-Net。U-Net 是一种流行的语义分割架构,其主要思想是它逐步下采样然后上采样输入图像,并在具有相同分辨率的层之间添加跳跃连接。这些有助于梯度流动并避免引入表示瓶颈,这与通常的自编码器不同。基于此,可以将扩散模型视为没有瓶颈的去噪自编码器

该网络有两个输入:噪声图像和其噪声分量的方差。后者是必需的,因为在不同噪声水平下去噪信号需要不同的操作。我们使用正弦嵌入来转换噪声方差,这类似于transformerNeRF中使用的位置编码。这有助于网络对噪声水平高度敏感,这对于良好的性能至关重要。我们使用Lambda 层实现正弦嵌入。

其他注意事项

  • 我们使用Keras 函数式 API构建网络,并使用闭包以一致的风格构建层块。
  • 扩散模型将扩散过程的时间步索引嵌入,而不是噪声方差,而基于分数的模型(表 1)通常使用噪声水平的某种函数。我更喜欢后者,这样我们可以在推理时更改采样调度,而无需重新训练网络。
  • 扩散模型将嵌入输入到每个卷积块中。为了简单起见,我们只在网络开始时输入它,根据我的经验,这几乎不会降低性能,因为跳跃连接和残差连接有助于信息在网络中正确传播。
  • 在文献中,通常在较低分辨率下使用注意力层以获得更好的全局一致性。为了简单起见,我省略了它。
  • 我们禁用了批量归一化层的可学习中心和尺度参数,因为后面的卷积层使得它们变得多余。
  • 我们习惯性地将最后一个卷积层的核初始化为全零,这使得网络在初始化后只预测零,这是其目标的平均值。这将改善训练开始时的行为,并使均方误差损失从恰好 1 开始。
@keras.saving.register_keras_serializable()
def sinusoidal_embedding(x):
    embedding_min_frequency = 1.0
    frequencies = ops.exp(
        ops.linspace(
            ops.log(embedding_min_frequency),
            ops.log(embedding_max_frequency),
            embedding_dims // 2,
        )
    )
    angular_speeds = ops.cast(2.0 * math.pi * frequencies, "float32")
    embeddings = ops.concatenate(
        [ops.sin(angular_speeds * x), ops.cos(angular_speeds * x)], axis=3
    )
    return embeddings


def ResidualBlock(width):
    def apply(x):
        input_width = x.shape[3]
        if input_width == width:
            residual = x
        else:
            residual = layers.Conv2D(width, kernel_size=1)(x)
        x = layers.BatchNormalization(center=False, scale=False)(x)
        x = layers.Conv2D(width, kernel_size=3, padding="same", activation="swish")(x)
        x = layers.Conv2D(width, kernel_size=3, padding="same")(x)
        x = layers.Add()([x, residual])
        return x

    return apply


def DownBlock(width, block_depth):
    def apply(x):
        x, skips = x
        for _ in range(block_depth):
            x = ResidualBlock(width)(x)
            skips.append(x)
        x = layers.AveragePooling2D(pool_size=2)(x)
        return x

    return apply


def UpBlock(width, block_depth):
    def apply(x):
        x, skips = x
        x = layers.UpSampling2D(size=2, interpolation="bilinear")(x)
        for _ in range(block_depth):
            x = layers.Concatenate()([x, skips.pop()])
            x = ResidualBlock(width)(x)
        return x

    return apply


def get_network(image_size, widths, block_depth):
    noisy_images = keras.Input(shape=(image_size, image_size, 3))
    noise_variances = keras.Input(shape=(1, 1, 1))

    e = layers.Lambda(sinusoidal_embedding, output_shape=(1, 1, 32))(noise_variances)
    e = layers.UpSampling2D(size=image_size, interpolation="nearest")(e)

    x = layers.Conv2D(widths[0], kernel_size=1)(noisy_images)
    x = layers.Concatenate()([x, e])

    skips = []
    for width in widths[:-1]:
        x = DownBlock(width, block_depth)([x, skips])

    for _ in range(block_depth):
        x = ResidualBlock(widths[-1])(x)

    for width in reversed(widths[:-1]):
        x = UpBlock(width, block_depth)([x, skips])

    x = layers.Conv2D(3, kernel_size=1, kernel_initializer="zeros")(x)

    return keras.Model([noisy_images, noise_variances], x, name="residual_unet")

这展示了函数式 API 的强大之处。请注意,我们如何在 80 行代码中构建了一个相对复杂的 U-Net,它包含跳跃连接、残差块、多个输入和正弦嵌入!


扩散模型

扩散调度

假设扩散过程从时间 = 0 开始,在时间 = 1 结束。这个变量称为扩散时间,可以是离散的(在扩散模型中常见)或连续的(在基于分数的模型中常见)。我选择后者,这样可以在推理时更改采样步数。

我们需要一个函数,它能在扩散过程的每个点上告诉我们与实际扩散时间对应的噪声图像的噪声水平和信号水平。这称为扩散调度(参见 diffusion_schedule())。

这个调度输出两个量:noise_ratesignal_rate(分别对应于 DDIM 论文中的 sqrt(1 - alpha) 和 sqrt(alpha))。我们通过将随机噪声和训练图像按其对应的速率加权并相加来生成噪声图像。

由于(标准正态)随机噪声和(归一化)图像都具有零均值和单位方差,噪声率和信号率可以解释为噪声图像中其分量的标准差,而它们的平方率可以解释为它们的方差(或信号处理意义上的功率)。速率将始终设置为其平方和为 1,这意味着噪声图像将始终具有单位方差,就像其未缩放的分量一样。

我们将使用余弦调度(第 3.2 节)的一个简化连续版本,这在文献中相当常用。这个调度是对称的,在扩散过程的开始和结束时变化缓慢,并且具有很好的几何解释,利用了单位圆的三角性质

diffusion schedule gif

训练过程

去噪扩散模型的训练过程(参见 train_step()denoise())如下:我们均匀采样随机扩散时间,并根据扩散时间对应的速率将训练图像与随机高斯噪声混合。然后,我们训练模型将噪声图像分离为其两个分量。

通常,神经网络被训练来预测未缩放的噪声分量,从中可以使用信号和噪声率计算预测的图像分量。理论上应使用像素级均方误差,但我建议使用均绝对误差(类似于实现),这在此数据集上产生了更好的结果。

采样(逆扩散)

采样时(参见 reverse_diffusion()),每一步我们都取噪声图像的先前估计,并使用我们的网络将其分离为图像和噪声。然后,我们使用下一步的信号和噪声率重新组合这些分量。

尽管在DDIMs 的公式 12中显示了类似的观点,但我认为上述对采样方程的解释并不广为人知。

本示例仅实现了 DDIM 的确定性采样过程,这对应于论文中的 eta = 0。也可以使用随机采样(在这种情况下,模型成为去噪扩散概率模型 (DDPM)),其中部分预测噪声被相同或更大数量的随机噪声替换(参见公式 16 及以下)。

随机采样可以在不重新训练网络的情况下使用(因为两个模型的训练方式相同),并且可以提高样本质量,但另一方面通常需要更多的采样步骤。

@keras.saving.register_keras_serializable()
class DiffusionModel(keras.Model):
    def __init__(self, image_size, widths, block_depth):
        super().__init__()

        self.normalizer = layers.Normalization()
        self.network = get_network(image_size, widths, block_depth)
        self.ema_network = keras.models.clone_model(self.network)

    def compile(self, **kwargs):
        super().compile(**kwargs)

        self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
        self.image_loss_tracker = keras.metrics.Mean(name="i_loss")
        self.kid = KID(name="kid")

    @property
    def metrics(self):
        return [self.noise_loss_tracker, self.image_loss_tracker, self.kid]

    def denormalize(self, images):
        # convert the pixel values back to 0-1 range
        images = self.normalizer.mean + images * self.normalizer.variance**0.5
        return ops.clip(images, 0.0, 1.0)

    def diffusion_schedule(self, diffusion_times):
        # diffusion times -> angles
        start_angle = ops.cast(ops.arccos(max_signal_rate), "float32")
        end_angle = ops.cast(ops.arccos(min_signal_rate), "float32")

        diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)

        # angles -> signal and noise rates
        signal_rates = ops.cos(diffusion_angles)
        noise_rates = ops.sin(diffusion_angles)
        # note that their squared sum is always: sin^2(x) + cos^2(x) = 1

        return noise_rates, signal_rates

    def denoise(self, noisy_images, noise_rates, signal_rates, training):
        # the exponential moving average weights are used at evaluation
        if training:
            network = self.network
        else:
            network = self.ema_network

        # predict noise component and calculate the image component using it
        pred_noises = network([noisy_images, noise_rates**2], training=training)
        pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates

        return pred_noises, pred_images

    def reverse_diffusion(self, initial_noise, diffusion_steps):
        # reverse diffusion = sampling
        num_images = initial_noise.shape[0]
        step_size = 1.0 / diffusion_steps

        # important line:
        # at the first sampling step, the "noisy image" is pure noise
        # but its signal rate is assumed to be nonzero (min_signal_rate)
        next_noisy_images = initial_noise
        for step in range(diffusion_steps):
            noisy_images = next_noisy_images

            # separate the current noisy image to its components
            diffusion_times = ops.ones((num_images, 1, 1, 1)) - step * step_size
            noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, training=False
            )
            # network used in eval mode

            # remix the predicted components using the next signal and noise rates
            next_diffusion_times = diffusion_times - step_size
            next_noise_rates, next_signal_rates = self.diffusion_schedule(
                next_diffusion_times
            )
            next_noisy_images = (
                next_signal_rates * pred_images + next_noise_rates * pred_noises
            )
            # this new noisy image will be used in the next step

        return pred_images

    def generate(self, num_images, diffusion_steps):
        # noise -> images -> denormalized images
        initial_noise = keras.random.normal(
            shape=(num_images, image_size, image_size, 3)
        )
        generated_images = self.reverse_diffusion(initial_noise, diffusion_steps)
        generated_images = self.denormalize(generated_images)
        return generated_images

    def train_step(self, images):
        # normalize images to have standard deviation of 1, like the noises
        images = self.normalizer(images, training=True)
        noises = keras.random.normal(shape=(batch_size, image_size, image_size, 3))

        # sample uniform random diffusion times
        diffusion_times = keras.random.uniform(
            shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
        )
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        # mix the images with noises accordingly
        noisy_images = signal_rates * images + noise_rates * noises

        with tf.GradientTape() as tape:
            # train the network to separate noisy images to their components
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, training=True
            )

            noise_loss = self.loss(noises, pred_noises)  # used for training
            image_loss = self.loss(images, pred_images)  # only used as metric

        gradients = tape.gradient(noise_loss, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        self.noise_loss_tracker.update_state(noise_loss)
        self.image_loss_tracker.update_state(image_loss)

        # track the exponential moving averages of weights
        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(ema * ema_weight + (1 - ema) * weight)

        # KID is not measured during the training phase for computational efficiency
        return {m.name: m.result() for m in self.metrics[:-1]}

    def test_step(self, images):
        # normalize images to have standard deviation of 1, like the noises
        images = self.normalizer(images, training=False)
        noises = keras.random.normal(shape=(batch_size, image_size, image_size, 3))

        # sample uniform random diffusion times
        diffusion_times = keras.random.uniform(
            shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
        )
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        # mix the images with noises accordingly
        noisy_images = signal_rates * images + noise_rates * noises

        # use the network to separate noisy images to their components
        pred_noises, pred_images = self.denoise(
            noisy_images, noise_rates, signal_rates, training=False
        )

        noise_loss = self.loss(noises, pred_noises)
        image_loss = self.loss(images, pred_images)

        self.image_loss_tracker.update_state(image_loss)
        self.noise_loss_tracker.update_state(noise_loss)

        # measure KID between real and generated images
        # this is computationally demanding, kid_diffusion_steps has to be small
        images = self.denormalize(images)
        generated_images = self.generate(
            num_images=batch_size, diffusion_steps=kid_diffusion_steps
        )
        self.kid.update_state(images, generated_images)

        return {m.name: m.result() for m in self.metrics}

    def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6):
        # plot random generated images for visual evaluation of generation quality
        generated_images = self.generate(
            num_images=num_rows * num_cols,
            diffusion_steps=plot_diffusion_steps,
        )

        plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
        for row in range(num_rows):
            for col in range(num_cols):
                index = row * num_cols + col
                plt.subplot(num_rows, num_cols, index + 1)
                plt.imshow(generated_images[index])
                plt.axis("off")
        plt.tight_layout()
        plt.show()
        plt.close()

训练

# create and compile the model
model = DiffusionModel(image_size, widths, block_depth)
# below tensorflow 2.9:
# pip install tensorflow_addons
# import tensorflow_addons as tfa
# optimizer=tfa.optimizers.AdamW
model.compile(
    optimizer=keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    loss=keras.losses.mean_absolute_error,
)
# pixelwise mean absolute error is used as loss

# save the best model based on the validation KID metric
checkpoint_path = "checkpoints/diffusion_model.weights.h5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor="val_kid",
    mode="min",
    save_best_only=True,
)

# calculate mean and variance of training dataset for normalization
model.normalizer.adapt(train_dataset)

# run training and plot generated images periodically
model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=val_dataset,
    callbacks=[
        keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
        checkpoint_callback,
    ],
)
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5
 87910968/87910968 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step

 511/511 ━━━━━━━━━━━━━━━━━━━━ 0s 48ms/step - i_loss: 0.6896 - n_loss: 0.2961

png

 511/511 ━━━━━━━━━━━━━━━━━━━━ 110s 138ms/step - i_loss: 0.6891 - n_loss: 0.2959 - kid: 0.0000e+00 - val_i_loss: 2.5650 - val_kid: 2.0372 - val_n_loss: 0.7914

<keras.src.callbacks.history.History at 0x7f521b149870>

推理

# load the best model and generate images
model.load_weights(checkpoint_path)
model.plot_images()

png


结果

通过运行训练至少 50 个 epoch(在 T4 GPU 上需要 2 小时,在 A100 GPU 上需要 30 分钟),可以使用此代码示例获得高质量的图像生成。

一批图像在 80 个 epoch 训练过程中的演变(颜色伪影是由于 GIF 压缩造成的)

flowers training gif

使用 1 到 20 个采样步骤从相同初始噪声生成的图像

flowers sampling steps gif

初始噪声样本之间的插值(球形)

flowers interpolation gif

确定性采样过程(顶部是噪声图像,底部是预测图像,40 步)

flowers deterministic generation gif

随机采样过程(顶部是噪声图像,底部是预测图像,80 步)

flowers stochastic generation gif


经验教训

在准备此代码示例期间,我使用此仓库进行了大量实验。在本节中,我将我学到的经验和建议按主观重要性顺序列出。

算法技巧

  • 最小和最大信号率:我发现最小信号率是一个重要的超参数。将其设置得太低会使生成的图像过饱和,而设置得太高会使它们欠饱和。我建议仔细调整它。此外,将其设置为 0 将导致除以零错误。最大信号率可以设置为 1,但我发现将其设置得稍微低一些可以略微提高生成质量。
  • 损失函数:虽然大型模型倾向于使用均方误差 (MSE) 损失,但我建议在此数据集上使用均绝对误差 (MAE)。根据我的经验,MSE 损失会生成更多样化的样本(它似乎也更容易出现幻觉第 3 节),而 MAE 损失会导致更平滑的图像。我建议尝试两者。
  • 权重衰减:当模型放大时,我偶尔会遇到训练发散的情况,发现权重衰减有助于避免不稳定性,且性能损失很低。这就是为什么在本示例中我使用AdamW而不是Adam
  • 权重的指数移动平均:这有助于降低 KID 指标的方差,并有助于在训练期间平均短期变化。
  • 图像增强:虽然我在本示例中没有使用图像增强,但根据我的经验,在训练中添加水平翻转可以提高生成性能,而随机裁剪则不能。由于我们使用监督去噪损失,过拟合可能是一个问题,因此在小数据集上图像增强可能很重要。还应该注意不要使用泄漏的增强,这可以遵循此方法(第 5 节末尾)等方式完成。
  • 数据归一化:在文献中,图像的像素值通常转换为 -1 到 1 的范围。为了理论上的正确性,我将图像归一化使其具有零均值和单位方差,就像随机噪声一样。
  • 噪声水平输入:我选择将噪声方差输入到网络中,因为在我们的采样调度下它是对称的。也可以输入噪声率(性能相似)、信号率(性能较低),甚至对数信噪比(附录 B.1)(未尝试,因为其范围高度依赖于最小和最大信号率,并且需要相应调整最小嵌入频率)。
  • 梯度裁剪:对于大型模型,使用全局梯度裁剪(值为 1)有助于训练稳定性,但根据我的经验会显著降低性能。
  • 残差连接下采样:对于更深的模型(附录 B),将残差连接乘以 1/sqrt(2) 可能有用,但在我的情况下没有帮助。
  • 学习率:对我来说,Adam 优化器的默认学习率 1e-3 工作得很好,但文献(表 11-13)中更常见较低的学习率。

架构技巧

  • 正弦嵌入:在网络的噪声水平输入上使用正弦嵌入对于良好的性能至关重要。我建议将最小嵌入频率设置为此输入的范围的倒数,并且由于在此示例中我们使用噪声方差,因此可以始终保持为 1。最大嵌入频率控制网络对噪声方差的最小变化敏感度,嵌入维度设置嵌入中的频率分量数量。根据我的经验,性能对这些值不太敏感。
  • 跳跃连接:在网络架构中使用跳跃连接是绝对关键的,没有它们,模型将无法以良好的性能学习去噪。
  • 残差连接:根据我的经验,残差连接也显著提高了性能,但这可能是因为我们只将噪声水平嵌入输入到网络的第一层,而不是所有层。
  • 归一化:当放大模型时,我偶尔会遇到训练发散的情况,使用归一化层有助于缓解这个问题。在文献中,通常在网络中使用GroupNormalization(例如 8 组)或LayerNormalization,但我选择使用BatchNormalization,因为它在我的实验中提供了类似的好处,但计算量更轻。
  • 激活函数:激活函数的选择对生成质量的影响比我预期的要大。在我的实验中,使用非单调激活函数(如ReLU)优于单调激活函数,其中Swish表现最好(这也是Imagen 使用的,第 41 页)。
  • 注意力:如前所述,在文献中,通常在较低分辨率下使用注意力层以获得更好的全局一致性。为了简单起见,我省略了它们。
  • 上采样:网络中的双线性插值和最近邻上采样表现相似,但我没有尝试转置卷积

有关 GAN 的类似列表,请查看此 Keras 教程


接下来尝试什么?

如果您想深入研究该主题,我建议您查看我为准备此代码示例而创建的此仓库,它以类似风格实现了更广泛的功能,例如