Author: András Béres
Date created: 2021/10/28
Last modified: 2021/10/28
Description: Generating images from limited data using the Caltech Birds dataset.
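Generative Adversarial Networks (GANs) are a popular class of generative deep learning models, commonly used for image generation. They consist of a pair of dueling neural networks, called the discriminator and the generator. The discriminator's task is to distinguish real images from generated (fake) ones, while the generator network tries to fool the discriminator by generating more and more realistic images. If the discriminator is too easy or too hard to fool, however, it might fail to provide a useful learning signal for the generator. This is why training GANs is usually considered a difficult task.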
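The sketch below illustrates the "too hard to fool" case. It is plain Python operating on a single scalar logit and is not part of the training code; the function names are illustrative. It compares the gradient the generator receives with respect to the discriminator's logit on a generated image under two generator losses:

- the original minimax term log(1 - sigmoid(l));
- the non-saturating term -log(sigmoid(l)), which the adversarial_loss method later in this example uses.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def saturating_grad(logit: float) -> float:
    # derivative of log(1 - sigmoid(l)) with respect to l, i.e. -sigmoid(l);
    # the generator minimizes this term in the original minimax formulation
    return -sigmoid(logit)


def non_saturating_grad(logit: float) -> float:
    # derivative of -log(sigmoid(l)) with respect to l, i.e. -(1 - sigmoid(l));
    # this is binary cross-entropy of the generated logit against the "real" label
    return -(1.0 - sigmoid(logit))


# a discriminator that confidently rejects a generated image outputs a very
# negative logit for it
for logit in (-10.0, 0.0, 10.0):
    print(logit, saturating_grad(logit), non_saturating_grad(logit))
```

At a logit of -10, the saturating gradient has a magnitude below 1e-4, while the non-saturating one is close to -1. The generator therefore still receives a usable signal when the discriminator wins easily.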
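The authors of StyleGAN2-ADA show that discriminator overfitting can be an issue in GANs, especially when only a small amount of training data is available. They propose Adaptive Discriminator Augmentation to mitigate this issue.
Applying data augmentation to GANs, however, is not straightforward. The generator is updated using the discriminator's gradients, so if the generated images are augmented, the augmentation pipeline has to be differentiable. It also has to be GPU-compatible for computational efficiency. Luckily, the Keras image augmentation layers fulfill both requirements, and are therefore well suited for this task.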
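A frequent difficulty when using data augmentation in generative models is the issue of "leaky augmentations" (section 2.2): the model generates images that are already augmented. This would mean that it was not able to separate the augmentation from the underlying data distribution, which can be caused by using non-invertible data transformations. For example, if either 0, 90, 180 or 270 degree rotations are performed with equal probability, the original orientation of the images is impossible to infer, and this information is destroyed.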
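One way to keep such a transformation invertible is to apply the augmentation only with some probability p, which is what the AdaptiveAugmenter later in this example does. The pure-Python sketch below assumes the idealized setup from the paragraph above: an augmented image is rotated by 0, 90, 180 or 270 degrees with equal probability. It computes the distribution of orientations the discriminator would observe. The function names are hypothetical.

```python
from fractions import Fraction


def observed_rotation_distribution(p):
    """Probability of each observed orientation (degrees) in the augmented data.

    Assumes an image is augmented with probability p, and an augmented image is
    rotated by 0, 90, 180 or 270 degrees uniformly at random.
    """
    p = Fraction(p)
    quarter = p / 4
    # un-augmented images, plus augmented ones that happened to get 0 degrees
    return {0: (1 - p) + quarter, 90: quarter, 180: quarter, 270: quarter}


def orientation_recoverable(p):
    # the original orientation can be identified only if it is strictly the
    # most frequent orientation in the augmented data
    dist = observed_rotation_distribution(p)
    return all(dist[0] > dist[angle] for angle in (90, 180, 270))


print(observed_rotation_distribution(1))               # all four equal to 1/4
print(observed_rotation_distribution(Fraction(1, 2)))  # 0 degrees: 5/8, others: 1/8
```

With p = 1, the four orientations are equally likely, which is exactly the information loss described above. For any p < 1, 0 degrees has probability 1 - 3p/4 versus p/4 for each of the others. It therefore stays the most frequent orientation, and the original distribution can still be inferred in principle.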
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from tensorflow.keras import layers
# data
num_epochs = 10 # train for 400 epochs for good results
image_size = 64
# resolution of Kernel Inception Distance measurement, see related section
kid_image_size = 75
padding = 0.25
dataset_name = "caltech_birds2011"
# adaptive discriminator augmentation
max_translation = 0.125
max_rotation = 0.125
max_zoom = 0.25
target_accuracy = 0.85
integration_steps = 1000
# architecture
noise_size = 64
depth = 4
width = 128
leaky_relu_slope = 0.2
dropout_rate = 0.4
# optimization
batch_size = 128
learning_rate = 2e-4
beta_1 = 0.5 # not using the default value of 0.9 is important
ema = 0.99
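In this example, we will use the Caltech Birds (2011) dataset for generating images of birds. It is a diverse natural dataset containing fewer than 6000 images for training. When working with so little data, one has to take extra care to retain as high data quality as possible. Here, we use the provided bounding boxes to cut out the birds with padded square crops, preserving their aspect ratios when possible.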
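The crop geometry used by preprocess_image below can be checked without TensorFlow. The plain-Python sketch below mirrors that arithmetic. It assumes that bbox holds normalized (y_min, x_min, y_max, x_max) coordinates, which is the order the code indexes them in. square_crop_box is a hypothetical helper, not part of the example. Python's round, like tf.math.round, rounds half to even.

```python
def square_crop_box(height, width, bbox, padding=0.25):
    """Return (offset_y, offset_x, crop_height, crop_width) in pixels.

    bbox is (y_min, x_min, y_max, x_max), normalized to [0, 1].
    """
    y_min, x_min = bbox[0] * height, bbox[1] * width
    y_max, x_max = bbox[2] * height, bbox[3] * width
    center_y = 0.5 * (y_min + y_max)
    center_x = 0.5 * (x_min + x_max)
    # side of a square around the longer side of the box, plus padding
    size = (1.0 + padding) * max(y_max - y_min, x_max - x_min)
    # shrink each side so that the centered crop stays inside the image
    crop_h = min(size, 2.0 * center_y, 2.0 * (height - center_y))
    crop_w = min(size, 2.0 * center_x, 2.0 * (width - center_x))
    return (
        round(center_y - 0.5 * crop_h),
        round(center_x - 0.5 * crop_w),
        round(crop_h),
        round(crop_w),
    )


print(square_crop_box(400, 600, (0.1, 0.2, 0.5, 0.4)))   # (20, 80, 200, 200)
print(square_crop_box(100, 200, (0.0, 0.25, 0.2, 0.45)))  # (0, 45, 20, 50)
```

In the first call, the box is far from the borders, so the result is a square 200 x 200 crop. In the second, the box touches the top edge, so the crop height is clamped to 20 pixels while the width keeps the padded 50. This is the "when possible" part: such a crop is no longer square and gets stretched by the final resize.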
def round_to_int(float_value):
    return tf.cast(tf.math.round(float_value), dtype=tf.int32)


def preprocess_image(data):
    # unnormalize bounding box coordinates
    height = tf.cast(tf.shape(data["image"])[0], dtype=tf.float32)
    width = tf.cast(tf.shape(data["image"])[1], dtype=tf.float32)
    bounding_box = data["bbox"] * tf.stack([height, width, height, width])

    # calculate center and length of longer side, add padding
    target_center_y = 0.5 * (bounding_box[0] + bounding_box[2])
    target_center_x = 0.5 * (bounding_box[1] + bounding_box[3])
    target_size = tf.maximum(
        (1.0 + padding) * (bounding_box[2] - bounding_box[0]),
        (1.0 + padding) * (bounding_box[3] - bounding_box[1]),
    )

    # modify crop size to fit into image
    target_height = tf.reduce_min(
        [target_size, 2.0 * target_center_y, 2.0 * (height - target_center_y)]
    )
    target_width = tf.reduce_min(
        [target_size, 2.0 * target_center_x, 2.0 * (width - target_center_x)]
    )

    # crop image
    image = tf.image.crop_to_bounding_box(
        data["image"],
        offset_height=round_to_int(target_center_y - 0.5 * target_height),
        offset_width=round_to_int(target_center_x - 0.5 * target_width),
        target_height=round_to_int(target_height),
        target_width=round_to_int(target_width),
    )

    # resize and clip
    # for image downsampling, area interpolation is the preferred method
    image = tf.image.resize(
        image, size=[image_size, image_size], method=tf.image.ResizeMethod.AREA
    )
    return tf.clip_by_value(image / 255.0, 0.0, 1.0)


def prepare_dataset(split):
    # the validation dataset is shuffled as well, because data order matters
    # for the KID calculation
    return (
        tfds.load(dataset_name, split=split, shuffle_files=True)
        .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
        .shuffle(10 * batch_size)
        .batch(batch_size, drop_remainder=True)
    )


train_dataset = prepare_dataset("train")
val_dataset = prepare_dataset("test")
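Kernel Inception Distance (KID) was proposed as a replacement for the popular Fréchet Inception Distance (FID) metric for measuring image generation quality. Both metrics measure the difference between the generated and training distributions in the representation space of an InceptionV3 network pretrained on ImageNet.
According to the paper, KID was proposed because FID has no unbiased estimator: its expected value is higher when it is measured on fewer images. KID is more suitable for small datasets because its expected value does not depend on the number of samples it is measured on. In my experience, it is also computationally lighter, numerically more stable, and simpler to implement, because it can be estimated in a per-batch manner.
In this example, the images are evaluated at the minimal possible resolution of the Inception network (75x75 instead of 299x299). For computational efficiency, the metric is only measured on the validation set.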
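For reference, the per-batch estimator that the KID metric class below computes can be written in a few lines of NumPy. This sketch uses random feature vectors as stand-ins for InceptionV3 features, an assumption made purely for illustration. It uses the same cubic polynomial kernel as the class. It also excludes the diagonal terms k(x, x) from the within-set averages, which is what makes the squared-MMD estimate unbiased.

```python
import numpy as np


def polynomial_kernel(a, b):
    # cubic polynomial kernel k(x, y) = (x . y / d + 1) ** 3, as in the KID class
    d = a.shape[1]
    return (a @ b.T / d + 1.0) ** 3


def kid_estimate(real_features, generated_features):
    """Unbiased estimate of the squared MMD between two sets of feature vectors."""
    m = real_features.shape[0]
    n = generated_features.shape[0]
    k_rr = polynomial_kernel(real_features, real_features)
    k_gg = polynomial_kernel(generated_features, generated_features)
    k_rg = polynomial_kernel(real_features, generated_features)
    # drop the diagonal (an element paired with itself) from within-set means
    mean_rr = (k_rr.sum() - np.trace(k_rr)) / (m * (m - 1))
    mean_gg = (k_gg.sum() - np.trace(k_gg)) / (n * (n - 1))
    return mean_rr + mean_gg - 2.0 * k_rg.mean()


rng = np.random.default_rng(0)
real = rng.normal(size=(128, 16))
same_distribution = rng.normal(size=(128, 16))
shifted = rng.normal(loc=1.0, size=(128, 16))
print(kid_estimate(real, same_distribution))  # same distribution
print(kid_estimate(real, shifted))            # shifted distribution
```

Because the estimate is unbiased, it can come out slightly negative for two samples of the same distribution. This is expected and not a bug.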
class KID(keras.metrics.Metric):
    def __init__(self, name="kid", **kwargs):
        super().__init__(name=name, **kwargs)

        # KID is estimated per batch and is averaged across batches
        self.kid_tracker = keras.metrics.Mean()

        # a pretrained InceptionV3 is used without its classification layer
        # transform the pixel values to the 0-255 range, then use the same
        # preprocessing as during pretraining
        self.encoder = keras.Sequential(
            [
                layers.InputLayer(input_shape=(image_size, image_size, 3)),
                layers.Rescaling(255.0),
                layers.Resizing(height=kid_image_size, width=kid_image_size),
                layers.Lambda(keras.applications.inception_v3.preprocess_input),
                keras.applications.InceptionV3(
                    include_top=False,
                    input_shape=(kid_image_size, kid_image_size, 3),
                    weights="imagenet",
                ),
                layers.GlobalAveragePooling2D(),
            ],
            name="inception_encoder",
        )

    def polynomial_kernel(self, features_1, features_2):
        feature_dimensions = tf.cast(tf.shape(features_1)[1], dtype=tf.float32)
        return (features_1 @ tf.transpose(features_2) / feature_dimensions + 1.0) ** 3.0

    def update_state(self, real_images, generated_images, sample_weight=None):
        real_features = self.encoder(real_images, training=False)
        generated_features = self.encoder(generated_images, training=False)

        # compute polynomial kernels using the two sets of features
        kernel_real = self.polynomial_kernel(real_features, real_features)
        kernel_generated = self.polynomial_kernel(
            generated_features, generated_features
        )
        kernel_cross = self.polynomial_kernel(real_features, generated_features)

        # estimate the squared maximum mean discrepancy using the average kernel values
        batch_size = tf.shape(real_features)[0]
        batch_size_f = tf.cast(batch_size, dtype=tf.float32)
        mean_kernel_real = tf.reduce_sum(kernel_real * (1.0 - tf.eye(batch_size))) / (
            batch_size_f * (batch_size_f - 1.0)
        )
        mean_kernel_generated = tf.reduce_sum(
            kernel_generated * (1.0 - tf.eye(batch_size))
        ) / (batch_size_f * (batch_size_f - 1.0))
        mean_kernel_cross = tf.reduce_mean(kernel_cross)
        kid = mean_kernel_real + mean_kernel_generated - 2.0 * mean_kernel_cross

        # update the average KID estimate
        self.kid_tracker.update_state(kid)

    def result(self):
        return self.kid_tracker.result()

    def reset_state(self):
        self.kid_tracker.reset_state()
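The authors of StyleGAN2-ADA propose to change the augmentation probability adaptively during training. Though it is explained differently in the paper, they use integral control on the augmentation probability to keep the discriminator's accuracy on real images close to a target value. Note that their controlled variable is actually the average sign of the discriminator logits (r_t in the paper), which corresponds to 2 * accuracy - 1.
The method has two hyperparameters:
target_accuracy: the target value for the discriminator's accuracy on real images. I recommend selecting its value from the 80-90% range.
integration_steps: the number of update steps required for an accuracy error of 100% to translate into an augmentation probability increase of 100%. Intuitively, this defines how slowly the augmentation probability changes. I recommend setting it to a relatively high value (1000 in this case) so that the augmentation strength is only adjusted slowly.
The main motivation for this procedure is that the optimal value of the target accuracy is similar across different dataset sizes (see figures 4 and 5 in the paper). It therefore does not have to be re-tuned, because the process automatically applies stronger data augmentation when it is needed.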
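The control rule described above fits in a few lines of plain Python. The sketch below mirrors step() and AdaptiveAugmenter.update from the code that follows, using lists of floats instead of tensors; the helper names are hypothetical. It also checks the stated relation r_t = 2 * accuracy - 1.

```python
def sign(x):
    return (x > 0) - (x < 0)


def real_accuracy(real_logits):
    # mean of step(): logit > 0 counts as 1, logit < 0 as 0, logit == 0 as 0.5
    return sum(0.5 * (1 + sign(logit)) for logit in real_logits) / len(real_logits)


def r_t(real_logits):
    # the paper's heuristic: average sign of the discriminator logits on real images
    return sum(sign(logit) for logit in real_logits) / len(real_logits)


def update_probability(p, accuracy, target_accuracy=0.85, integration_steps=1000):
    # integral control: add the accuracy error, scaled down by integration_steps,
    # and keep the probability inside [0, 1]
    p = p + (accuracy - target_accuracy) / integration_steps
    return min(max(p, 0.0), 1.0)


logits = [2.0, -1.0, 0.5, 3.0, 0.0]
print(real_accuracy(logits), r_t(logits))  # 0.7 0.4

# a discriminator that is right on 95% of real images (above the 0.85 target)
# raises the augmentation probability by 0.1 / 1000 per step
p = 0.0
for _ in range(1000):
    p = update_probability(p, accuracy=0.95)
print(p)  # approximately 0.1
```

Because the error is accuracy minus target, the probability rises while the discriminator is more accurate on real images than the target (a sign of overfitting), and falls otherwise.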
# "hard sigmoid", useful for binary accuracy calculation from logits
def step(values):
    # negative values -> 0.0, positive values -> 1.0
    return 0.5 * (1.0 + tf.sign(values))


# augments images with a probability that is dynamically updated during training
class AdaptiveAugmenter(keras.Model):
    def __init__(self):
        super().__init__()

        # stores the current probability of an image being augmented
        self.probability = tf.Variable(0.0)

        # the corresponding augmentation names from the paper are shown above each layer
        # the authors show (see figure 4), that the blitting and geometric augmentations
        # are the most helpful in the low-data regime
        self.augmenter = keras.Sequential(
            [
                layers.InputLayer(input_shape=(image_size, image_size, 3)),
                # blitting/x-flip:
                layers.RandomFlip("horizontal"),
                # blitting/integer translation:
                layers.RandomTranslation(
                    height_factor=max_translation,
                    width_factor=max_translation,
                    interpolation="nearest",
                ),
                # geometric/rotation:
                layers.RandomRotation(factor=max_rotation),
                # geometric/isotropic and anisotropic scaling:
                layers.RandomZoom(
                    height_factor=(-max_zoom, 0.0), width_factor=(-max_zoom, 0.0)
                ),
            ],
            name="adaptive_augmenter",
        )

    def call(self, images, training):
        if training:
            augmented_images = self.augmenter(images, training)

            # during training either the original or the augmented images are selected
            # based on self.probability
            augmentation_values = tf.random.uniform(
                shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0
            )
            augmentation_bools = tf.math.less(augmentation_values, self.probability)

            images = tf.where(augmentation_bools, augmented_images, images)
        return images

    def update(self, real_logits):
        current_accuracy = tf.reduce_mean(step(real_logits))

        # the augmentation probability is updated based on the discriminator's
        # accuracy on real images
        accuracy_error = current_accuracy - target_accuracy
        self.probability.assign(
            tf.clip_by_value(
                self.probability + accuracy_error / integration_steps, 0.0, 1.0
            )
        )
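GANs tend to be sensitive to the network architecture. I implemented a DCGAN architecture in this example, because it is relatively stable during training while being simple to implement. As further simplifications, we:
- use a constant number of filters throughout the network;
- use a sigmoid instead of tanh in the last layer of the generator;
- use default initialization instead of random normal.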
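A quick way to sanity-check the architecture is to recompute its output size and parameter counts by hand. The plain-Python sketch below assumes the layer configuration used in get_generator and get_discriminator:

- 4x4 kernels with stride 2;
- no biases, except in the last layer of each network;
- BatchNormalization(scale=False), which stores three values per channel: offset, moving mean and moving variance.

The function name is hypothetical.

```python
def dcgan_sizes(noise_size=64, depth=4, width=128, channels=3):
    """Return (generator output side, generator params, discriminator params)."""
    kernel = 4 * 4

    def batch_norm(c):
        # offset (beta) + moving mean + moving variance; no scale (gamma)
        return 3 * c

    # generator: dense layer to a 4x4 map, then `depth` stride-2 transposed
    # convolutions, each doubling the spatial size
    output_side = 4 * 2**depth
    generator = noise_size * 4 * 4 * width + batch_norm(4 * 4 * width)
    generator += (depth - 1) * (kernel * width * width + batch_norm(width))
    generator += kernel * width * channels + channels  # last layer keeps a bias

    # discriminator: `depth` stride-2 convolutions halve 4 * 2**depth down to 4x4,
    # then a dense layer with bias maps the flattened 4*4*width features to 1 logit
    discriminator = kernel * channels * width + batch_norm(width)
    discriminator += (depth - 1) * (kernel * width * width + batch_norm(width))
    discriminator += 4 * 4 * width + 1
    return output_side, generator, discriminator


print(dcgan_sizes())  # (64, 930947, 796161)
```

With the default hyperparameters, these totals match the 930,947 and 796,161 parameters reported in the model summaries printed during training. The output side also shows that, in this design, image_size has to equal 4 * 2**depth, so changing one means changing the other.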
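As a good practice, we disable the learnable scale parameter in the batch normalization layers, for two reasons:
- The following ReLU + convolutional layers make it redundant (as noted in the documentation).
- According to theory, it should be disabled when using spectral normalization (section 4.1), which is common in GANs, though it is not used here.

We also disable the bias in the fully connected and convolutional layers, because the following batch normalization makes it redundant.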
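Both redundancy arguments can be checked numerically. The NumPy sketch below uses random data and hypothetical helper names. It shows that:

- a positive per-channel scale applied before a ReLU can be folded into the weights of the following linear layer, because relu(g * z) = g * relu(z) for g > 0;
- a per-channel bias added before training-mode batch normalization is removed by the mean subtraction.

The sketch assumes the scale stays positive, which a learned parameter is not guaranteed to do.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def batch_norm(x, epsilon=1e-3):
    # training-mode normalization with batch statistics, without scale or offset
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + epsilon)


def scale_is_foldable(seed=0):
    rng = np.random.default_rng(seed)
    z = rng.normal(size=(8, 5))       # normalized activations: 8 samples, 5 channels
    gamma = rng.uniform(0.5, 2.0, 5)  # a positive per-channel scale
    w = rng.normal(size=(5, 3))       # weights of the next linear layer
    with_scale = relu(z * gamma) @ w
    folded = relu(z) @ (gamma[:, None] * w)  # scale moved into the weights
    return np.allclose(with_scale, folded)


def bias_is_cancelled(seed=0):
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(8, 5))
    bias = rng.normal(size=5)         # per-channel bias of the preceding layer
    return np.allclose(batch_norm(x), batch_norm(x + bias))


print(scale_is_foldable(), bias_is_cancelled())  # True True
```

The same folding works per input channel of a convolution. It also applies to the LeakyReLU in the discriminator, which is likewise positively homogeneous.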
# DCGAN generator
def get_generator():
    noise_input = keras.Input(shape=(noise_size,))
    x = layers.Dense(4 * 4 * width, use_bias=False)(noise_input)
    x = layers.BatchNormalization(scale=False)(x)
    x = layers.ReLU()(x)
    x = layers.Reshape(target_shape=(4, 4, width))(x)
    for _ in range(depth - 1):
        x = layers.Conv2DTranspose(
            width, kernel_size=4, strides=2, padding="same", use_bias=False,
        )(x)
        x = layers.BatchNormalization(scale=False)(x)
        x = layers.ReLU()(x)
    image_output = layers.Conv2DTranspose(
        3, kernel_size=4, strides=2, padding="same", activation="sigmoid",
    )(x)

    return keras.Model(noise_input, image_output, name="generator")


# DCGAN discriminator
def get_discriminator():
    image_input = keras.Input(shape=(image_size, image_size, 3))
    x = image_input
    for _ in range(depth):
        x = layers.Conv2D(
            width, kernel_size=4, strides=2, padding="same", use_bias=False,
        )(x)
        x = layers.BatchNormalization(scale=False)(x)
        x = layers.LeakyReLU(alpha=leaky_relu_slope)(x)
    x = layers.Flatten()(x)
    x = layers.Dropout(dropout_rate)(x)
    output_score = layers.Dense(1)(x)

    return keras.Model(image_input, output_score, name="discriminator")
class GAN_ADA(keras.Model):
    def __init__(self):
        super().__init__()

        self.augmenter = AdaptiveAugmenter()
        self.generator = get_generator()
        self.ema_generator = keras.models.clone_model(self.generator)
        self.discriminator = get_discriminator()

        self.generator.summary()
        self.discriminator.summary()

    def compile(self, generator_optimizer, discriminator_optimizer, **kwargs):
        super().compile(**kwargs)

        # separate optimizers for the two networks
        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer

        self.generator_loss_tracker = keras.metrics.Mean(name="g_loss")
        self.discriminator_loss_tracker = keras.metrics.Mean(name="d_loss")
        self.real_accuracy = keras.metrics.BinaryAccuracy(name="real_acc")
        self.generated_accuracy = keras.metrics.BinaryAccuracy(name="gen_acc")
        self.augmentation_probability_tracker = keras.metrics.Mean(name="aug_p")
        self.kid = KID()

    @property
    def metrics(self):
        return [
            self.generator_loss_tracker,
            self.discriminator_loss_tracker,
            self.real_accuracy,
            self.generated_accuracy,
            self.augmentation_probability_tracker,
            self.kid,
        ]

    def generate(self, batch_size, training):
        latent_samples = tf.random.normal(shape=(batch_size, noise_size))
        # use ema_generator during inference
        if training:
            generated_images = self.generator(latent_samples, training)
        else:
            generated_images = self.ema_generator(latent_samples, training)
        return generated_images

    def adversarial_loss(self, real_logits, generated_logits):
        # this is usually called the non-saturating GAN loss
        real_labels = tf.ones(shape=(batch_size, 1))
        generated_labels = tf.zeros(shape=(batch_size, 1))

        # the generator tries to produce images that the discriminator considers as real
        generator_loss = keras.losses.binary_crossentropy(
            real_labels, generated_logits, from_logits=True
        )
        # the discriminator tries to determine if images are real or generated
        discriminator_loss = keras.losses.binary_crossentropy(
            tf.concat([real_labels, generated_labels], axis=0),
            tf.concat([real_logits, generated_logits], axis=0),
            from_logits=True,
        )

        return tf.reduce_mean(generator_loss), tf.reduce_mean(discriminator_loss)

    def train_step(self, real_images):
        real_images = self.augmenter(real_images, training=True)

        # use persistent gradient tape because gradients will be calculated twice
        with tf.GradientTape(persistent=True) as tape:
            generated_images = self.generate(batch_size, training=True)
            # gradient is calculated through the image augmentation
            generated_images = self.augmenter(generated_images, training=True)

            # separate forward passes for the real and generated images, meaning
            # that batch normalization is applied separately
            real_logits = self.discriminator(real_images, training=True)
            generated_logits = self.discriminator(generated_images, training=True)

            generator_loss, discriminator_loss = self.adversarial_loss(
                real_logits, generated_logits
            )

        # calculate gradients and update weights
        generator_gradients = tape.gradient(
            generator_loss, self.generator.trainable_weights
        )
        discriminator_gradients = tape.gradient(
            discriminator_loss, self.discriminator.trainable_weights
        )
        self.generator_optimizer.apply_gradients(
            zip(generator_gradients, self.generator.trainable_weights)
        )
        self.discriminator_optimizer.apply_gradients(
            zip(discriminator_gradients, self.discriminator.trainable_weights)
        )

        # update the augmentation probability based on the discriminator's performance
        self.augmenter.update(real_logits)

        self.generator_loss_tracker.update_state(generator_loss)
        self.discriminator_loss_tracker.update_state(discriminator_loss)
        self.real_accuracy.update_state(1.0, step(real_logits))
        self.generated_accuracy.update_state(0.0, step(generated_logits))
        self.augmentation_probability_tracker.update_state(self.augmenter.probability)

        # track the exponential moving average of the generator's weights to decrease
        # variance in the generation quality
        for weight, ema_weight in zip(
            self.generator.weights, self.ema_generator.weights
        ):
            ema_weight.assign(ema * ema_weight + (1 - ema) * weight)

        # KID is not measured during the training phase for computational efficiency
        return {m.name: m.result() for m in self.metrics[:-1]}

    def test_step(self, real_images):
        generated_images = self.generate(batch_size, training=False)

        self.kid.update_state(real_images, generated_images)

        # only KID is measured during the evaluation phase for computational efficiency
        return {self.kid.name: self.kid.result()}

    def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6, interval=5):
        # plot random generated images for visual evaluation of generation quality
        if epoch is None or (epoch + 1) % interval == 0:
            num_images = num_rows * num_cols
            generated_images = self.generate(num_images, training=False)

            plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
            for row in range(num_rows):
                for col in range(num_cols):
                    index = row * num_cols + col
                    plt.subplot(num_rows, num_cols, index + 1)
                    plt.imshow(generated_images[index])
                    plt.axis("off")
            plt.tight_layout()
            plt.show()
            plt.close()
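As the training metrics show, the augmentation probability is increased when the real accuracy (the discriminator's accuracy on real images) is above the target accuracy, and decreased otherwise. In the log below, for instance, real_acc stays between 0.93 and 0.98, above the 0.85 target, and aug_p grows every epoch. In my experience, the discriminator accuracy should stay in the 80-95% range during healthy GAN training. Below that range, the discriminator is too weak; above it, the discriminator is too strong.
Note that we track the exponential moving average of the generator's weights, and use it for image generation and KID evaluation.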
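The sketch below applies the update rule from train_step, ema_weight = ema * ema_weight + (1 - ema) * weight, to a single scalar weight; the helper names are hypothetical.

If the weight stopped changing, the gap between the averaged and the current weight would shrink by a factor of ema per step. Roughly 1 / (1 - ema) steps (100 for ema = 0.99) therefore close about 63% of the gap.

This matters at the start of training. keras.models.clone_model creates the EMA generator with freshly initialized weights rather than copies, and this is the gap the average has to close.

```python
def ema_update(ema_weight, weight, ema=0.99):
    # same rule as the loop at the end of GAN_ADA.train_step
    return ema * ema_weight + (1 - ema) * weight


def ema_after_steps(initial_ema_weight, constant_weight, steps, ema=0.99):
    # repeatedly apply the update while the tracked weight stays constant
    value = initial_ema_weight
    for _ in range(steps):
        value = ema_update(value, constant_weight, ema)
    return value


# closed form: constant + (initial - constant) * ema ** steps
print(ema_after_steps(0.0, 1.0, steps=100))  # about 0.634 (= 1 - 0.99 ** 100)
print(ema_after_steps(0.0, 1.0, steps=460))  # 10 epochs of 46 steps: about 0.99
```

With ema = 0.999, which was used for the smoother animation mentioned below, the same reasoning gives a horizon of roughly 1000 steps.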
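# create and compile the model
model = GAN_ADA()
model.compile(
    generator_optimizer=keras.optimizers.Adam(learning_rate, beta_1),
    discriminator_optimizer=keras.optimizers.Adam(learning_rate, beta_1),
)
# save the best model based on the validation KID metric
checkpoint_path = "gan_model"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, save_weights_only=True,
    monitor="val_kid", mode="min", save_best_only=True,
)
# run training and plot generated images periodically
plot_callback = keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images)
model.fit(train_dataset, epochs=num_epochs, validation_data=val_dataset,
          callbacks=[plot_callback, checkpoint_callback])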
Model: "generator"
Layer (type) Output Shape Param #
input_2 (InputLayer) [(None, 64)] 0
dense (Dense) (None, 2048) 131072
batch_normalization (BatchNo (None, 2048) 6144
re_lu (ReLU) (None, 2048) 0
reshape (Reshape) (None, 4, 4, 128) 0
conv2d_transpose (Conv2DTran (None, 8, 8, 128) 262144
batch_normalization_1 (Batch (None, 8, 8, 128) 384
re_lu_1 (ReLU) (None, 8, 8, 128) 0
conv2d_transpose_1 (Conv2DTr (None, 16, 16, 128) 262144
batch_normalization_2 (Batch (None, 16, 16, 128) 384
re_lu_2 (ReLU) (None, 16, 16, 128) 0
conv2d_transpose_2 (Conv2DTr (None, 32, 32, 128) 262144
batch_normalization_3 (Batch (None, 32, 32, 128) 384
re_lu_3 (ReLU) (None, 32, 32, 128) 0
conv2d_transpose_3 (Conv2DTr (None, 64, 64, 3) 6147
Total params: 930,947
Trainable params: 926,083
Non-trainable params: 4,864
Model: "discriminator"
Layer (type) Output Shape Param #
input_3 (InputLayer) [(None, 64, 64, 3)] 0
conv2d (Conv2D) (None, 32, 32, 128) 6144
batch_normalization_4 (Batch (None, 32, 32, 128) 384
leaky_re_lu (LeakyReLU) (None, 32, 32, 128) 0
conv2d_1 (Conv2D) (None, 16, 16, 128) 262144
batch_normalization_5 (Batch (None, 16, 16, 128) 384
leaky_re_lu_1 (LeakyReLU) (None, 16, 16, 128) 0
conv2d_2 (Conv2D) (None, 8, 8, 128) 262144
batch_normalization_6 (Batch (None, 8, 8, 128) 384
leaky_re_lu_2 (LeakyReLU) (None, 8, 8, 128) 0
conv2d_3 (Conv2D) (None, 4, 4, 128) 262144
batch_normalization_7 (Batch (None, 4, 4, 128) 384
leaky_re_lu_3 (LeakyReLU) (None, 4, 4, 128) 0
flatten (Flatten) (None, 2048) 0
dropout (Dropout) (None, 2048) 0
dense_1 (Dense) (None, 1) 2049
Total params: 796,161
Trainable params: 795,137
Non-trainable params: 1,024
Epoch 1/10
46/46 [==============================] - 36s 307ms/step - g_loss: 3.3293 - d_loss: 0.1576 - real_acc: 0.9387 - gen_acc: 0.9579 - aug_p: 0.0020 - val_kid: 9.0999
Epoch 2/10
46/46 [==============================] - 10s 215ms/step - g_loss: 4.9824 - d_loss: 0.0912 - real_acc: 0.9704 - gen_acc: 0.9798 - aug_p: 0.0077 - val_kid: 8.3523
Epoch 3/10
46/46 [==============================] - 10s 218ms/step - g_loss: 5.0587 - d_loss: 0.1248 - real_acc: 0.9530 - gen_acc: 0.9625 - aug_p: 0.0131 - val_kid: 6.8116
Epoch 4/10
46/46 [==============================] - 10s 221ms/step - g_loss: 4.2580 - d_loss: 0.1002 - real_acc: 0.9686 - gen_acc: 0.9740 - aug_p: 0.0179 - val_kid: 5.2327
Epoch 5/10
46/46 [==============================] - 10s 225ms/step - g_loss: 4.6022 - d_loss: 0.0847 - real_acc: 0.9655 - gen_acc: 0.9852 - aug_p: 0.0234 - val_kid: 3.9004
Epoch 6/10
46/46 [==============================] - 10s 224ms/step - g_loss: 4.9362 - d_loss: 0.0671 - real_acc: 0.9791 - gen_acc: 0.9895 - aug_p: 0.0291 - val_kid: 6.6020
Epoch 7/10
46/46 [==============================] - 10s 222ms/step - g_loss: 4.4272 - d_loss: 0.1184 - real_acc: 0.9570 - gen_acc: 0.9657 - aug_p: 0.0345 - val_kid: 3.3644
Epoch 8/10
46/46 [==============================] - 10s 220ms/step - g_loss: 4.5060 - d_loss: 0.1635 - real_acc: 0.9421 - gen_acc: 0.9594 - aug_p: 0.0392 - val_kid: 3.1381
Epoch 9/10
46/46 [==============================] - 10s 219ms/step - g_loss: 3.8264 - d_loss: 0.1667 - real_acc: 0.9383 - gen_acc: 0.9484 - aug_p: 0.0433 - val_kid: 2.9423
Epoch 10/10
46/46 [==============================] - 10s 219ms/step - g_loss: 3.4063 - d_loss: 0.1757 - real_acc: 0.9314 - gen_acc: 0.9475 - aug_p: 0.0473 - val_kid: 2.9112
# load the best model and generate images
model.load_weights(checkpoint_path)
model.plot_images()
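By running the training for 400 epochs (which takes 2-3 hours in a Colab notebook), one can get high-quality image generations using this code example.
Evolution of a random batch of images over a 400-epoch training (ema=0.999 for animation smoothness):
I also recommend trying out training on other datasets, such as CelebA. In my experience, good results can be achieved without changing any hyperparameters (though discriminator augmentation might not be necessary).
My goal with this example was to find a good trade-off between ease of implementation and generation quality for GANs. During preparation, I ran numerous ablations using this repository.
I recommend checking out the DCGAN paper, this NeurIPS talk, and this large-scale GAN study for others' views on this subject.
Other GAN-related Keras code examples
Modern GAN architecture lines
Overview of recent literature on GANs: talks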