作者: A_K_Nain
创建日期: 2020/08/12
最后修改日期: 2024/09/30
描述: CycleGAN 的实现。
CycleGAN 是一种旨在解决图像到图像转换问题的模型。图像到图像转换问题的目标是使用一组对齐的图像对来学习输入图像和输出图像之间的映射。然而,获得成对的示例并非总是可行的。CycleGAN 借助循环一致的对抗网络,在无需成对输入-输出图像的情况下学习这种映射。
import os

# Keras reads KERAS_BACKEND only once, when `keras` is first imported, so the
# variable must be set *before* that import; setting it afterwards (as the
# script originally did) has no effect.
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
from keras import layers, ops
import tensorflow_datasets as tfds

tfds.disable_progress_bar()
# Let tf.data pick parallelism / prefetch sizes dynamically.
autotune = tf.data.AUTOTUNE
在此示例中,我们将使用马到斑马(`horse2zebra`)数据集。
# Load the horse-zebra dataset using tensorflow-datasets.
# trainA/testA hold horses (domain X), trainB/testB hold zebras (domain Y).
dataset, _ = tfds.load(name="cycle_gan/horse2zebra", with_info=True, as_supervised=True)
train_horses, train_zebras = dataset["trainA"], dataset["trainB"]
test_horses, test_zebras = dataset["testA"], dataset["testB"]

# Training images are first resized to this size, then randomly cropped.
orig_img_size = (286, 286)
# Size of the random crops to be used during training (height, width, channels).
input_img_size = (256, 256, 3)

# In Keras 3, an initializer that is unseeded (or seeded with an int) returns
# the *same* values on every call. One shared instance would therefore give
# every same-shaped layer identical initial weights, e.g. all residual blocks.
# Seeding with a SeedGenerator advances the RNG state on each call, so every
# layer gets fresh values.
# Weights initializer for the layers: N(0, 0.02).
kernel_init = keras.initializers.RandomNormal(
    mean=0.0, stddev=0.02, seed=keras.random.SeedGenerator()
)
# Gamma (scale) initializer for the normalization layers.
# NOTE(review): the model uses GroupNormalization(groups=1), which normalizes
# each sample over H, W and C together (layer-norm style). This is not the
# per-channel instance normalization used in the CycleGAN paper, which would
# be groups=-1.
gamma_init = keras.initializers.RandomNormal(
    mean=0.0, stddev=0.02, seed=keras.random.SeedGenerator()
)
buffer_size = 256
batch_size = 1
def normalize_img(img):
    """Cast `img` to float32 and rescale it to the range [-1, 1].

    Assumes pixel values in [0, 255] (8-bit images); 127.5 is half that range,
    so 0 maps to -1.0 and 255 maps to 1.0.
    """
    as_float = ops.cast(img, dtype=tf.float32)
    return as_float / 127.5 - 1.0
def preprocess_train_image(img, label):
    """Augment and normalize one training image; `label` is ignored.

    Steps: random horizontal flip, resize to `orig_img_size`, random crop to
    `input_img_size`, then scale pixel values to [-1, 1].
    """
    flipped = tf.image.random_flip_left_right(img)
    resized = ops.image.resize(flipped, list(orig_img_size))
    cropped = tf.image.random_crop(resized, size=list(input_img_size))
    return normalize_img(cropped)
def preprocess_test_image(img, label):
    """Resize a test image to the network input size and scale it to [-1, 1].

    No augmentation is applied; `label` is ignored.
    """
    target_height, target_width = input_img_size[0], input_img_size[1]
    resized = ops.image.resize(img, [target_height, target_width])
    return normalize_img(resized)
创建 `Dataset` 对象。
# Apply the preprocessing operations to the training data
# Training pipelines: cache the *raw* decoded images and apply the random
# flip/crop afterwards. Caching after the augmenting `map` (as before) stores
# the first epoch's random crops and replays them every epoch, which disables
# augmentation. Caching raw uint8 images is also smaller than caching the
# float32 results. Shuffling before the map keeps the shuffle buffer in uint8
# as well.
train_horses = (
    train_horses.cache()
    .shuffle(buffer_size)
    .map(preprocess_train_image, num_parallel_calls=autotune)
    .batch(batch_size)
    .prefetch(autotune)
)
train_zebras = (
    train_zebras.cache()
    .shuffle(buffer_size)
    .map(preprocess_train_image, num_parallel_calls=autotune)
    .batch(batch_size)
    .prefetch(autotune)
)
# Apply the preprocessing operations to the test data. Test preprocessing is
# deterministic (resize + normalize), so caching its output is safe.
test_horses = (
    test_horses.map(preprocess_test_image, num_parallel_calls=autotune)
    .cache()
    .shuffle(buffer_size)
    .batch(batch_size)
    .prefetch(autotune)
)
test_zebras = (
    test_zebras.map(preprocess_test_image, num_parallel_calls=autotune)
    .cache()
    .shuffle(buffer_size)
    .batch(batch_size)
    .prefetch(autotune)
)
# Show four training pairs: horses in the left column, zebras in the right.
_, ax = plt.subplots(4, 2, figsize=(10, 15))
sample_pairs = zip(train_horses.take(4), train_zebras.take(4))
for row, (horse_batch, zebra_batch) in enumerate(sample_pairs):
    # Undo the [-1, 1] normalization for the first image of each batch.
    for col, batch in enumerate((horse_batch, zebra_batch)):
        pixels = (((batch[0] * 127.5) + 127.5).numpy()).astype(np.uint8)
        ax[row, col].imshow(pixels)
plt.show()
class ReflectionPadding2D(layers.Layer):
    """Implements Reflection Padding as a layer.

    Pads the two spatial axes of a 4-D (batch, height, width, channels) tensor
    by mirroring values at the borders; batch and channel axes are untouched.

    Args:
        padding(tuple): Amount of padding for the spatial dimensions, given as
            ``(width_padding, height_padding)``; each amount is applied on
            both sides of its axis.

    Returns:
        A padded tensor with the same type as the input tensor.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        super().__init__(**kwargs)
        self.padding = tuple(padding)

    def call(self, input_tensor, mask=None):
        padding_width, padding_height = self.padding
        padding_tensor = [
            [0, 0],
            [padding_height, padding_height],
            [padding_width, padding_width],
            [0, 0],
        ]
        # `keras.ops.pad` documents lowercase mode names ("reflect"). The
        # TensorFlow backend accepts either case, but other backends such as
        # JAX only accept the lowercase form.
        return ops.pad(input_tensor, padding_tensor, mode="reflect")

    def get_config(self):
        # Include `padding` so the layer can be re-created from its config.
        config = super().get_config()
        config.update({"padding": self.padding})
        return config
def residual_block(
    x,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="valid",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Residual block: two (reflect-pad -> conv -> norm) stages plus a skip.

    The conv layers keep the channel count of `x`. `activation` is applied
    only after the first stage, and the block returns `x + stage2(...)`.
    The skip connection requires the output shape to match `x`, so this
    relies on the default stride of 1 together with the 1-pixel reflection
    padding and "valid" convolutions.

    NOTE(review): GroupNormalization(groups=1) normalizes over H, W and C
    jointly (layer-norm style), not per channel.
    """
    channels = x.shape[-1]
    shortcut = x

    def pad_conv_norm(tensor):
        # One reflect-pad -> conv -> normalization stage.
        tensor = ReflectionPadding2D()(tensor)
        tensor = layers.Conv2D(
            channels,
            kernel_size,
            strides=strides,
            kernel_initializer=kernel_initializer,
            padding=padding,
            use_bias=use_bias,
        )(tensor)
        return keras.layers.GroupNormalization(
            groups=1, gamma_initializer=gamma_initializer
        )(tensor)

    hidden = activation(pad_conv_norm(shortcut))
    hidden = pad_conv_norm(hidden)
    return layers.add([shortcut, hidden])
def downsample(
    x,
    filters,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Strided conv -> GroupNormalization(groups=1) -> optional activation.

    With the default stride of 2 and "same" padding this halves the spatial
    size. `activation` may be falsy (e.g. None) to skip the activation.
    """
    conv = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )
    norm = keras.layers.GroupNormalization(
        groups=1, gamma_initializer=gamma_initializer
    )
    normalized = norm(conv(x))
    return activation(normalized) if activation else normalized
def upsample(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    kernel_initializer=kernel_init,
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Transposed conv -> GroupNormalization(groups=1) -> optional activation.

    With the default stride of 2 and "same" padding this doubles the spatial
    size. `activation` may be falsy (e.g. None) to skip the activation.
    """
    deconv = layers.Conv2DTranspose(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=kernel_initializer,
        use_bias=use_bias,
    )
    norm = keras.layers.GroupNormalization(
        groups=1, gamma_initializer=gamma_initializer
    )
    normalized = norm(deconv(x))
    return activation(normalized) if activation else normalized
生成器由下采样块、九个残差块和上采样块组成。生成器的结构如下:

    c7s1-64 ==> Conv block with `relu` activation, filter size of 7
    d128 ====|
             |-> 2 downsampling blocks
    d256 ====|
    R256 ====|
    R256     |
    R256     |
    R256     |
    R256     |-> 9 residual blocks
    R256     |
    R256     |
    R256     |
    R256 ====|
    u128 ====|
             |-> 2 upsampling blocks
    u64  ====|
    c7s1-3 => Last conv block with `tanh` activation, filter size of 7.
def get_resnet_generator(
    filters=64,
    num_downsampling_blocks=2,
    num_residual_blocks=9,
    num_upsample_blocks=2,
    gamma_initializer=gamma_init,
    name=None,
):
    """Build the ResNet-style CycleGAN generator.

    Layout: 7x7 conv with `filters` filters -> `num_downsampling_blocks`
    stride-2 downsampling blocks (doubling the filters each time) ->
    `num_residual_blocks` residual blocks -> `num_upsample_blocks` stride-2
    upsampling blocks (halving the filters each time) -> 7x7 conv to 3
    channels with a `tanh` output.

    Args:
        filters: Number of filters of the first convolution.
        num_downsampling_blocks: Number of stride-2 downsampling blocks.
        num_residual_blocks: Number of residual blocks.
        num_upsample_blocks: Number of stride-2 upsampling blocks.
        gamma_initializer: Initializer for the scale (gamma) of *every*
            normalization layer in the generator.
        name: Model name, also used as the prefix of the input layer name.
            If None, Keras assigns default names.

    Returns:
        A `keras.Model` taking images of shape `input_img_size` and producing
        3-channel images with values in [-1, 1] (tanh).
    """
    input_name = None if name is None else name + "_img_input"
    img_input = layers.Input(shape=input_img_size, name=input_name)
    x = ReflectionPadding2D(padding=(3, 3))(img_input)
    x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(
        x
    )
    x = keras.layers.GroupNormalization(groups=1, gamma_initializer=gamma_initializer)(
        x
    )
    x = layers.Activation("relu")(x)

    # Downsampling. Forward gamma_initializer so the caller's choice applies
    # to every normalization layer, not just the first one.
    for _ in range(num_downsampling_blocks):
        filters *= 2
        x = downsample(
            x,
            filters=filters,
            activation=layers.Activation("relu"),
            gamma_initializer=gamma_initializer,
        )

    # Residual blocks
    for _ in range(num_residual_blocks):
        x = residual_block(
            x,
            activation=layers.Activation("relu"),
            gamma_initializer=gamma_initializer,
        )

    # Upsampling
    for _ in range(num_upsample_blocks):
        filters //= 2
        x = upsample(
            x,
            filters,
            activation=layers.Activation("relu"),
            gamma_initializer=gamma_initializer,
        )

    # Final block: reflection padding keeps the 7x7 "valid" conv
    # size-preserving. The kernel uses the same initializer as every other
    # conv in the network.
    x = ReflectionPadding2D(padding=(3, 3))(x)
    x = layers.Conv2D(3, (7, 7), padding="valid", kernel_initializer=kernel_init)(x)
    x = layers.Activation("tanh")(x)

    model = keras.models.Model(img_input, x, name=name)
    return model
判别器实现以下架构:C64->C128->C256->C512
def get_discriminator(
    filters=64, kernel_initializer=kernel_init, num_downsampling=3, name=None
):
    """Build the CycleGAN discriminator (C64->C128->C256->C512 by default).

    A stride-2 4x4 conv with LeakyReLU(0.2) and no normalization comes first.
    It is followed by `num_downsampling` conv blocks, each doubling the filter
    count: all but the last use stride 2, the last uses stride 1. A final 4x4
    conv produces a single-channel score map. With the defaults and a 256x256
    input, that map is 1/8 of the input resolution.

    Args:
        filters: Number of filters of the first convolution.
        kernel_initializer: Initializer for every conv kernel, including the
            ones inside the downsampling blocks.
        num_downsampling: Number of conv blocks after the first convolution.
        name: Model name, also used as the prefix of the input layer name.
            If None, Keras assigns default names.

    Returns:
        A `keras.Model` mapping images of shape `input_img_size` to a
        single-channel score map (no output activation).
    """
    input_name = None if name is None else name + "_img_input"
    img_input = layers.Input(shape=input_img_size, name=input_name)
    x = layers.Conv2D(
        filters,
        (4, 4),
        strides=(2, 2),
        padding="same",
        kernel_initializer=kernel_initializer,
    )(img_input)
    x = layers.LeakyReLU(0.2)(x)

    num_filters = filters
    # Honour `num_downsampling` (the loop previously hard-coded 3 blocks).
    for block_index in range(num_downsampling):
        num_filters *= 2
        # Every block but the last halves the resolution.
        is_last = block_index == num_downsampling - 1
        x = downsample(
            x,
            filters=num_filters,
            activation=layers.LeakyReLU(0.2),
            kernel_initializer=kernel_initializer,
            kernel_size=(4, 4),
            strides=(1, 1) if is_last else (2, 2),
        )

    x = layers.Conv2D(
        1, (4, 4), strides=(1, 1), padding="same", kernel_initializer=kernel_initializer
    )(x)

    model = keras.models.Model(inputs=img_input, outputs=x, name=name)
    return model
# Generators: G translates horses (X) to zebras (Y), F translates zebras back
# to horses. Built in the order G, then F.
gen_G, gen_F = (
    get_resnet_generator(name=f"generator_{suffix}") for suffix in ("G", "F")
)
# Discriminators: one per domain, built in the order X, then Y.
disc_X, disc_Y = (
    get_discriminator(name=f"discriminator_{suffix}") for suffix in ("X", "Y")
)
我们将覆盖 `Model` 类的 `train_step()` 方法,以便通过 `fit()` 进行训练。
class CycleGan(keras.Model):
    """CycleGAN trainer that wraps two generators and two discriminators.

    `generator_G` maps domain X (horses) to Y (zebras), and `generator_F` maps
    Y back to X. `discriminator_X` and `discriminator_Y` score images of their
    own domain. The custom `train_step` lets all four sub-models be trained
    with `fit()` on batches of `(real_x, real_y)`.

    Args:
        generator_G: Generator mapping X images to Y images.
        generator_F: Generator mapping Y images to X images.
        discriminator_X: Discriminator for domain X.
        discriminator_Y: Discriminator for domain Y.
        lambda_cycle: Weight of the cycle-consistency loss.
        lambda_identity: Relative weight of the identity loss. The identity
            term is scaled by `lambda_cycle * lambda_identity`.
    """

    def __init__(
        self,
        generator_G,
        generator_F,
        discriminator_X,
        discriminator_Y,
        lambda_cycle=10.0,
        lambda_identity=0.5,
    ):
        super().__init__()
        self.gen_G = generator_G
        self.gen_F = generator_F
        self.disc_X = discriminator_X
        self.disc_Y = discriminator_Y
        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity

    def call(self, inputs):
        # Runs every sub-model on the same input; `train_step` does not use
        # this (presumably it exists so the wrapper model can be called or
        # built directly - TODO confirm).
        return (
            self.disc_X(inputs),
            self.disc_Y(inputs),
            self.gen_G(inputs),
            self.gen_F(inputs),
        )

    def compile(
        self,
        gen_G_optimizer,
        gen_F_optimizer,
        disc_X_optimizer,
        disc_Y_optimizer,
        gen_loss_fn,
        disc_loss_fn,
    ):
        """Attach one optimizer per sub-model plus the adversarial loss functions.

        Args:
            gen_G_optimizer / gen_F_optimizer: Optimizers for the generators.
            disc_X_optimizer / disc_Y_optimizer: Optimizers for the discriminators.
            gen_loss_fn: Callable(disc_output_on_fake) -> generator adversarial loss.
            disc_loss_fn: Callable(disc_output_on_real, disc_output_on_fake)
                -> discriminator loss.

        Cycle and identity losses are fixed to mean absolute error (L1).
        """
        super().compile()
        self.gen_G_optimizer = gen_G_optimizer
        self.gen_F_optimizer = gen_F_optimizer
        self.disc_X_optimizer = disc_X_optimizer
        self.disc_Y_optimizer = disc_Y_optimizer
        self.generator_loss_fn = gen_loss_fn
        self.discriminator_loss_fn = disc_loss_fn
        self.cycle_loss_fn = keras.losses.MeanAbsoluteError()
        self.identity_loss_fn = keras.losses.MeanAbsoluteError()

    def train_step(self, batch_data):
        """Run one optimization step for all four sub-models.

        Steps:
        1. Translate real images with both generators.
        2. Translate the fakes back to get the cycle reconstructions.
        3. Compute identity mappings (each generator applied to its own
           target domain).
        4. Score real and fake images with the discriminators.
        5. Combine adversarial, cycle and identity terms into each
           generator's total loss, and compute each discriminator's loss.
        6. Apply gradients to all four sub-models.

        Args:
            batch_data: Tuple `(real_x, real_y)` of horse and zebra batches.

        Returns:
            Dict with the generator totals ("G_loss", "F_loss") and the
            discriminator losses ("D_X_loss", "D_Y_loss").
        """
        # x is Horse and y is zebra
        real_x, real_y = batch_data

        # persistent=True because four separate gradients are taken from the
        # same tape below.
        with tf.GradientTape(persistent=True) as tape:
            # Horse to fake zebra
            fake_y = self.gen_G(real_x, training=True)
            # Zebra to fake horse -> y2x
            fake_x = self.gen_F(real_y, training=True)

            # Cycle (Horse to fake zebra to fake horse): x -> y -> x
            cycled_x = self.gen_F(fake_y, training=True)
            # Cycle (Zebra to fake horse to fake zebra) y -> x -> y
            cycled_y = self.gen_G(fake_x, training=True)

            # Identity mapping
            same_x = self.gen_F(real_x, training=True)
            same_y = self.gen_G(real_y, training=True)

            # Discriminator output
            disc_real_x = self.disc_X(real_x, training=True)
            disc_fake_x = self.disc_X(fake_x, training=True)
            disc_real_y = self.disc_Y(real_y, training=True)
            disc_fake_y = self.disc_Y(fake_y, training=True)

            # Generator adversarial loss
            gen_G_loss = self.generator_loss_fn(disc_fake_y)
            gen_F_loss = self.generator_loss_fn(disc_fake_x)

            # Generator cycle loss
            cycle_loss_G = self.cycle_loss_fn(real_y, cycled_y) * self.lambda_cycle
            cycle_loss_F = self.cycle_loss_fn(real_x, cycled_x) * self.lambda_cycle

            # Generator identity loss
            id_loss_G = (
                self.identity_loss_fn(real_y, same_y)
                * self.lambda_cycle
                * self.lambda_identity
            )
            id_loss_F = (
                self.identity_loss_fn(real_x, same_x)
                * self.lambda_cycle
                * self.lambda_identity
            )

            # Total generator loss
            total_loss_G = gen_G_loss + cycle_loss_G + id_loss_G
            total_loss_F = gen_F_loss + cycle_loss_F + id_loss_F

            # Discriminator loss
            disc_X_loss = self.discriminator_loss_fn(disc_real_x, disc_fake_x)
            disc_Y_loss = self.discriminator_loss_fn(disc_real_y, disc_fake_y)

        # Get the gradients for the generators
        grads_G = tape.gradient(total_loss_G, self.gen_G.trainable_variables)
        grads_F = tape.gradient(total_loss_F, self.gen_F.trainable_variables)

        # Get the gradients for the discriminators
        disc_X_grads = tape.gradient(disc_X_loss, self.disc_X.trainable_variables)
        disc_Y_grads = tape.gradient(disc_Y_loss, self.disc_Y.trainable_variables)

        # A persistent tape keeps its recorded operations alive until it is
        # deleted; drop it now that all gradients have been computed.
        del tape

        # Update the weights of the generators
        self.gen_G_optimizer.apply_gradients(
            zip(grads_G, self.gen_G.trainable_variables)
        )
        self.gen_F_optimizer.apply_gradients(
            zip(grads_F, self.gen_F.trainable_variables)
        )

        # Update the weights of the discriminators
        self.disc_X_optimizer.apply_gradients(
            zip(disc_X_grads, self.disc_X.trainable_variables)
        )
        self.disc_Y_optimizer.apply_gradients(
            zip(disc_Y_grads, self.disc_Y.trainable_variables)
        )

        return {
            "G_loss": total_loss_G,
            "F_loss": total_loss_F,
            "D_X_loss": disc_X_loss,
            "D_Y_loss": disc_Y_loss,
        }
class GANMonitor(keras.callbacks.Callback):
    """A callback to generate and save images after each epoch.

    At the end of every epoch it runs `gen_G` on the first `num_img` batches
    of `test_horses` and plots each input next to its translation. Each
    translated image is saved as ``generated_img_{i}_{epoch}.png``, with
    1-based epoch numbers.
    """

    def __init__(self, num_img=4):
        # Initialise the base Callback state before adding our own.
        super().__init__()
        self.num_img = num_img

    def on_epoch_end(self, epoch, logs=None):
        # One row per sample. squeeze=False keeps `ax` 2-D even when
        # num_img == 1. A height of 3 per row reproduces the original 12x12
        # figure for the default of 4 images.
        _, ax = plt.subplots(
            self.num_img, 2, figsize=(12, 3 * self.num_img), squeeze=False
        )
        for i, img in enumerate(test_horses.take(self.num_img)):
            prediction = self.model.gen_G(img, training=False)[0].numpy()
            # Undo the [-1, 1] normalization for display and saving.
            prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
            img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

            ax[i, 0].imshow(img)
            ax[i, 1].imshow(prediction)
            ax[i, 0].set_title("Input image")
            ax[i, 1].set_title("Translated image")
            ax[i, 0].axis("off")
            ax[i, 1].axis("off")

            prediction = keras.utils.array_to_img(prediction)
            prediction.save(
                "generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch + 1)
            )
        plt.show()
        plt.close()
# Adversarial objective: mean squared error against all-ones (real) or
# all-zeros (fake) targets, i.e. a least-squares GAN loss.
adv_loss_fn = keras.losses.MeanSquaredError()


def generator_loss_fn(fake):
    """Generator adversarial loss: push discriminator outputs on fakes toward 1."""
    return adv_loss_fn(ops.ones_like(fake), fake)


def discriminator_loss_fn(real, fake):
    """Discriminator loss: real outputs toward 1, fake outputs toward 0.

    The two terms are averaged, so the result is half their sum.
    """
    loss_on_real = adv_loss_fn(ops.ones_like(real), real)
    loss_on_fake = adv_loss_fn(ops.zeros_like(fake), fake)
    return (loss_on_real + loss_on_fake) * 0.5
# Wrap the four sub-models in the CycleGAN trainer (default loss weights).
cycle_gan_model = CycleGan(
    generator_G=gen_G,
    generator_F=gen_F,
    discriminator_X=disc_X,
    discriminator_Y=disc_Y,
)


def _make_adam():
    """Return a new Adam optimizer (learning_rate=2e-4, beta_1=0.5)."""
    return keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)


# Each sub-model gets its own optimizer instance, created in the order
# G, F, X, Y.
cycle_gan_model.compile(
    gen_G_optimizer=_make_adam(),
    gen_F_optimizer=_make_adam(),
    disc_X_optimizer=_make_adam(),
    disc_Y_optimizer=_make_adam(),
    gen_loss_fn=generator_loss_fn,
    disc_loss_fn=discriminator_loss_fn,
)
# Callbacks: GANMonitor plots and saves translated test images after each
# epoch; ModelCheckpoint saves the model weights to `checkpoint_filepath`.
plotter = GANMonitor()
checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.weights.h5"
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath, save_weights_only=True
)

# Each epoch takes around 7 minutes on a single P100 backed machine, so the
# full 90-epoch run below takes roughly 10.5 hours. Lower `num_epochs`
# (e.g. to 1) for a quick run.
num_epochs = 90
# Zipping pairs one horse batch with one zebra batch per step; the zipped
# dataset ends when the shorter of the two is exhausted.
cycle_gan_model.fit(
    tf.data.Dataset.zip((train_horses, train_zebras)),
    epochs=num_epochs,
    callbacks=[plotter, model_checkpoint_callback],
)
测试模型的性能。
# Once the weights are loaded, we will take a few samples from the test data and check the model's performance.
# Load the checkpoints
cycle_gan_model.load_weights(checkpoint_filepath)
print("Weights loaded successfully")

# Number of test batches to translate; also sets the number of subplot rows.
num_samples = 4
_, ax = plt.subplots(num_samples, 2, figsize=(10, 15), squeeze=False)
for i, img in enumerate(test_horses.take(num_samples)):
    prediction = cycle_gan_model.gen_G(img, training=False)[0].numpy()
    # Undo the [-1, 1] normalization for display and saving.
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

    ax[i, 0].imshow(img)
    ax[i, 1].imshow(prediction)
    ax[i, 0].set_title("Input image")
    ax[i, 1].set_title("Translated image")
    ax[i, 0].axis("off")
    ax[i, 1].axis("off")

    prediction = keras.utils.array_to_img(prediction)
    prediction.save("predicted_img_{i}.png".format(i=i))
plt.tight_layout()
plt.show()
Weights loaded successfully