作者: Soon-Yau Cheong
创建日期 2021/07/01
最后修改日期 2021/12/20
描述:StyleGAN 用于图像生成的实现。
StyleGAN 的核心思想是逐步提高生成图像的分辨率,并在生成过程中融入风格特征。本 StyleGAN 实现基于书籍《Hands-on Image Generation with TensorFlow》(使用 TensorFlow 进行动手图像生成)。书中的代码来自其 GitHub 仓库,这里已重构为使用自定义的 train_step(),以便通过编译和分布式获得更快的训练速度。
pip install tensorflow_addons
import os
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow_addons.layers import InstanceNormalization
import gdown
from zipfile import ZipFile
在本例中,我们将使用 CelebA 数据集进行训练;下面的代码通过 gdown 下载数据压缩包并解压到本地目录。
def log2(x):
    """Return the base-2 logarithm of ``x`` converted to a Python ``int``.

    The conversion uses ``int()``, so any fractional part is truncated.
    """
    exponent = np.log2(x)
    return int(exponent)
# Batch size per resolution, keyed by log2(image resolution). Larger images
# get smaller batches so that they still fit into GPU memory.
batch_sizes = {
    2: 16,
    3: 16,
    4: 16,
    5: 16,
    6: 16,
    7: 8,
    8: 4,
    9: 2,
    10: 1,
}
# Number-of-steps multiplier per resolution, relative to the batch size stored
# under key 2: a batch that is k times smaller gets k times as many steps.
train_step_ratio = {
    res_log2: batch_sizes[2] / size for res_log2, size in batch_sizes.items()
}
# Download the CelebA archive into ./celeba_gan and unpack it there.
# exist_ok=True keeps a second run of the script from failing with
# FileExistsError when the directory was created by an earlier run.
os.makedirs("celeba_gan", exist_ok=True)

url = "https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684"
output = "celeba_gan/data.zip"
gdown.download(url, output, quiet=True)

with ZipFile(output, "r") as zipobj:
    zipobj.extractall("celeba_gan")

# Create an unlabeled dataset of 64x64 images, in batches of 32, from the
# folder. No rescaling is applied here; `resize_image` later maps the pixel
# values with `/ 127.5 - 1.0`.
ds_train = keras.utils.image_dataset_from_directory(
    "celeba_gan", label_mode=None, image_size=(64, 64), batch_size=32
)
def resize_image(res, image):
    """Resize ``image`` to ``res`` x ``res`` and rescale it with ``/ 127.5 - 1.0``.

    Args:
        res: Target height and width in pixels.
        image: Image tensor accepted by ``tf.image.resize``.

    Returns:
        The resized image as ``tf.float32``, divided by 127.5 and shifted by -1.
    """
    # The images are only ever downsampled, so the cheaper nearest-neighbour
    # method is used.
    resized = tf.image.resize(
        image, (res, res), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )
    as_float = tf.cast(resized, tf.float32)
    return as_float / 127.5 - 1.0
def create_dataloader(res):
    """Build the repeating training dataset for resolution ``res``.

    The batch size is looked up in ``batch_sizes`` under ``log2(res)``.

    Returns:
        ``ds_train`` resized via ``resize_image``, shuffled with a buffer of
        200, re-batched with ``drop_remainder=True``, prefetched and repeated.
    """
    batch_size = batch_sizes[log2(res)]
    resize_fn = partial(resize_image, res)

    # NOTE: the dataset is unbatched so that it can be batched again with
    # `drop_remainder=True`, because the model only supports a single batch size.
    samples = ds_train.map(resize_fn, num_parallel_calls=tf.data.AUTOTUNE).unbatch()
    shuffled = samples.shuffle(200)
    batched = shuffled.batch(batch_size, drop_remainder=True)
    return batched.prefetch(1).repeat()
def plot_images(images, log2_res, fname=""):
    """Show a single row of images and optionally save the figure.

    Args:
        images: Indexable batch of images; ``images.shape[0]`` is the number of
            images available and ``images[i]`` is passed to ``imshow``.
        log2_res: Key into the internal ``scales`` table (2..10) that selects
            the size of each grid cell in the figure.
        fname: If non-empty, the figure is also written to this path.

    Raises:
        KeyError: If ``log2_res`` is not a key of the ``scales`` table.
    """
    scales = {2: 0.5, 3: 1, 4: 2, 5: 3, 6: 4, 7: 5, 8: 6, 9: 7, 10: 8}
    scale = scales[log2_res]

    grid_col = min(images.shape[0], int(32 // scale))
    grid_row = 1

    # squeeze=False makes subplots() always return a 2-D array of axes, so the
    # indexing below also works when the grid has a single column.
    f, axarr = plt.subplots(
        grid_row,
        grid_col,
        figsize=(grid_col * scale, grid_row * scale),
        squeeze=False,
    )

    for row in range(grid_row):
        for col in range(grid_col):
            ax = axarr[row][col]
            ax.imshow(images[row * grid_col + col])
            ax.axis("off")

    # Save before showing: once plt.show() returns, the figure may already have
    # been released by the backend, in which case savefig writes a blank image.
    if fname:
        f.savefig(fname)
    plt.show()
以下是将用于构建 StyleGAN 模型的生成器和鉴别器的构建块。
def fade_in(alpha, a, b):
    """Blend ``a`` and ``b``: ``alpha`` weights ``a`` and ``1 - alpha`` weights ``b``."""
    weighted_a = alpha * a
    weighted_b = (1.0 - alpha) * b
    return weighted_a + weighted_b
def wasserstein_loss(y_true, y_pred):
    """Return the negated mean of the element-wise product ``y_true * y_pred``."""
    signed_scores = y_true * y_pred
    return -tf.reduce_mean(signed_scores)
def pixel_norm(x, epsilon=1e-8):
    """Divide ``x`` by the root of its mean square over the last axis.

    ``epsilon`` is added under the square root to avoid division by zero.
    """
    mean_square = tf.reduce_mean(x ** 2, axis=-1, keepdims=True)
    return x / tf.math.sqrt(mean_square + epsilon)
def minibatch_std(input_tensor, epsilon=1e-8):
    """Append a minibatch standard-deviation feature map to ``input_tensor``.

    The batch is reshaped into groups of ``min(4, n)``; the standard deviation
    across each group is averaged over the remaining three axes, tiled back to
    the input's leading three dimensions and concatenated as one extra channel.

    Args:
        input_tensor: Rank-4 tensor (its shape is unpacked into four values).
        epsilon: Added to the variance before the square root.

    Returns:
        ``input_tensor`` with one additional channel on the last axis.
    """
    n, h, w, c = tf.shape(input_tensor)
    group_size = tf.minimum(4, n)

    grouped = tf.reshape(input_tensor, [group_size, -1, h, w, c])
    # Only the variance across the group axis is needed; the mean is discarded.
    _, variance = tf.nn.moments(grouped, axes=0, keepdims=False)
    stddev = tf.sqrt(variance + epsilon)

    # One scalar per group member position, kept as [*, 1, 1, 1] for tiling.
    mean_stddev = tf.reduce_mean(stddev, axis=[1, 2, 3], keepdims=True)
    feature_map = tf.tile(mean_stddev, [group_size, h, w, 1])
    return tf.concat([input_tensor, feature_map], axis=-1)
class EqualizedConv(layers.Layer):
    """2-D convolution whose kernel is rescaled at call time.

    The kernel is initialised from N(0, 1) and multiplied by
    ``sqrt(gain / fan_in)`` on every call, where
    ``fan_in = kernel * kernel * in_channels``.

    Args:
        out_channels: Number of output channels.
        kernel: Side length of the square kernel.
        gain: Numerator of the scaling factor.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, out_channels, kernel=3, gain=2, **kwargs):
        super().__init__(**kwargs)
        self.kernel = kernel
        self.out_channels = out_channels
        self.gain = gain
        # A 1x1 kernel needs no padding.
        self.pad = kernel != 1
        # Reflect-padding width per side. kernel // 2 preserves the spatial size
        # for odd kernels (it equals the previously hard-coded 1 for kernel=3).
        self.pad_size = kernel // 2

    def build(self, input_shape):
        """Create the kernel and bias and compute the runtime scale."""
        self.in_channels = input_shape[-1]
        initializer = keras.initializers.RandomNormal(mean=0.0, stddev=1.0)
        self.w = self.add_weight(
            shape=[self.kernel, self.kernel, self.in_channels, self.out_channels],
            initializer=initializer,
            trainable=True,
            name="kernel",
        )
        self.b = self.add_weight(
            shape=(self.out_channels,), initializer="zeros", trainable=True, name="bias"
        )
        fan_in = self.kernel * self.kernel * self.in_channels
        self.scale = tf.sqrt(self.gain / fan_in)

    def call(self, inputs):
        """Reflect-pad (if needed), convolve with the scaled kernel, add bias."""
        if self.pad:
            p = self.pad_size
            x = tf.pad(inputs, [[0, 0], [p, p], [p, p], [0, 0]], mode="REFLECT")
        else:
            x = inputs
        output = (
            tf.nn.conv2d(x, self.scale * self.w, strides=1, padding="VALID") + self.b
        )
        return output
class EqualizedDense(layers.Layer):
    """Dense layer whose kernel is rescaled by ``sqrt(gain / fan_in)`` at call time.

    The kernel is drawn from a normal distribution with standard deviation
    ``1 / learning_rate_multiplier`` and the layer output is multiplied by
    ``learning_rate_multiplier``.
    """

    def __init__(self, units, gain=2, learning_rate_multiplier=1, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.gain = gain
        self.learning_rate_multiplier = learning_rate_multiplier

    def build(self, input_shape):
        """Create the kernel and bias and compute the runtime scale."""
        self.in_channels = input_shape[-1]
        kernel_init = keras.initializers.RandomNormal(
            mean=0.0, stddev=1.0 / self.learning_rate_multiplier
        )
        self.w = self.add_weight(
            shape=[self.in_channels, self.units],
            initializer=kernel_init,
            trainable=True,
            name="kernel",
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="zeros", trainable=True, name="bias"
        )
        # For a dense layer the fan-in is simply the number of input channels.
        self.scale = tf.sqrt(self.gain / self.in_channels)

    def call(self, inputs):
        """Apply the scaled affine map, then the learning-rate multiplier."""
        scaled_kernel = self.scale * self.w
        projected = tf.matmul(inputs, scaled_kernel)
        output = tf.add(projected, self.b)
        return output * self.learning_rate_multiplier
class AddNoise(layers.Layer):
    """Add a noise input to a feature map, scaled by a learned per-channel weight.

    The layer is called with a pair ``[x, noise]`` and returns
    ``x + b * noise`` where ``b`` has shape ``[1, 1, 1, channels]``.
    """

    def build(self, input_shape):
        """Create the per-channel scale from the first input's shape."""
        # input_shape[0] is the shape of ``x``; it is unpacked into four values
        # and only the channel count is needed.
        _, _, _, channels = input_shape[0]
        noise_init = keras.initializers.RandomNormal(mean=0.0, stddev=1.0)
        self.b = self.add_weight(
            shape=[1, 1, 1, channels],
            initializer=noise_init,
            trainable=True,
            name="kernel",
        )

    def call(self, inputs):
        """Return ``x + b * noise`` for ``inputs == (x, noise)``."""
        x, noise = inputs
        scaled_noise = self.b * noise
        return x + scaled_noise
class AdaIN(layers.Layer):
    """Apply a style-dependent per-channel scale and bias to a feature map.

    The layer is called with ``[x, w]``. Two ``EqualizedDense`` layers map ``w``
    to a scale and a bias with one value per channel of ``x``; the result is
    ``scale * x + bias``. Normalisation of ``x`` is not performed here.
    """

    def __init__(self, gain=1, **kwargs):
        super().__init__(**kwargs)
        # Stored for reference; the visible code does not read it again.
        self.gain = gain

    def build(self, input_shapes):
        """Record channel counts and create the two style projections."""
        x_shape, w_shape = input_shapes[0], input_shapes[1]

        self.w_channels = w_shape[-1]
        self.x_channels = x_shape[-1]

        self.dense_1 = EqualizedDense(self.x_channels, gain=1)
        self.dense_2 = EqualizedDense(self.x_channels, gain=1)

    def call(self, inputs):
        """Return ``ys * x + yb`` with ``ys``/``yb`` projected from ``w``."""
        x, w = inputs
        per_channel = (-1, 1, 1, self.x_channels)
        ys = tf.reshape(self.dense_1(w), per_channel)
        yb = tf.reshape(self.dense_2(w), per_channel)
        return ys * x + yb
接下来,我们将构建以下内容:
对于生成器,我们在多个分辨率(例如 4x4、8x8,…,直至 1024x1024)构建生成器块。我们只在开头使用 4x4,并在训练过程中使用分辨率越来越大的块。鉴别器也是如此。
def Mapping(num_stages, input_shape=512):
    """Build the mapping network that turns a latent ``z`` into style codes.

    ``z`` is passed through ``pixel_norm`` and eight
    ``EqualizedDense(512, learning_rate_multiplier=0.01)`` + ``LeakyReLU(0.2)``
    pairs; the result is then repeated ``num_stages`` times along a new axis 1.

    Args:
        num_stages: Number of copies of the style code in the output.
        input_shape: Size of the latent vector, as an int or a shape tuple.

    Returns:
        A ``keras.Model`` named ``"mapping"``.
    """
    # Input expects a shape tuple; `(input_shape)` is just the bare int, which
    # only works where Keras silently accepts an int.
    if isinstance(input_shape, int):
        shape = (input_shape,)
    else:
        shape = tuple(input_shape)

    z = layers.Input(shape=shape)
    w = pixel_norm(z)
    for _ in range(8):
        w = EqualizedDense(512, learning_rate_multiplier=0.01)(w)
        w = layers.LeakyReLU(0.2)(w)
    w = tf.tile(tf.expand_dims(w, 1), (1, num_stages, 1))
    return keras.Model(z, w, name="mapping")
class Generator:
    """Builder for progressively grown generator models.

    One generator block, one to-RGB layer and one noise input are created for
    every resolution from ``2 ** start_res_log2`` to ``2 ** target_res_log2``.
    ``grow`` wires the blocks needed for a requested resolution into a
    ``keras.Model``.
    """

    def __init__(self, start_res_log2, target_res_log2):
        self.start_res_log2 = start_res_log2
        self.target_res_log2 = target_res_log2
        self.num_stages = target_res_log2 - start_res_log2 + 1
        # list of generator blocks at increasing resolution
        self.g_blocks = []
        # list of layers to convert g_block activation to RGB
        self.to_rgb = []
        # list of noise input of different resolutions into g_blocks
        self.noise_inputs = []
        # filter size to use at each stage, keys are log2(resolution)
        self.filter_nums = {
            0: 512,
            1: 512,
            2: 512,  # 4x4
            3: 512,  # 8x8
            4: 512,  # 16x16
            5: 512,  # 32x32
            6: 256,  # 64x64
            7: 128,  # 128x128
            8: 64,  # 256x256
            9: 32,  # 512x512
            10: 16,  # 1024x1024
        }

        start_res = 2 ** start_res_log2
        self.input_shape = (start_res, start_res, self.filter_nums[start_res_log2])
        self.g_input = layers.Input(self.input_shape, name="generator_input")

        for i in range(start_res_log2, target_res_log2 + 1):
            filter_num = self.filter_nums[i]
            res = 2 ** i
            self.noise_inputs.append(
                layers.Input(shape=(res, res, 1), name=f"noise_{res}x{res}")
            )
            to_rgb = Sequential(
                [
                    layers.InputLayer(input_shape=(res, res, filter_num)),
                    EqualizedConv(3, 1, gain=1),
                ],
                name=f"to_rgb_{res}x{res}",
            )
            self.to_rgb.append(to_rgb)

            is_base = i == self.start_res_log2
            if is_base:
                # The base block is fed `g_input`, so it must accept exactly
                # that shape. Using filter_nums[i - 1] here only matched while
                # both entries happened to be equal.
                input_shape = self.input_shape
            else:
                # Later blocks receive the previous block's output and upsample it.
                input_shape = (2 ** (i - 1), 2 ** (i - 1), self.filter_nums[i - 1])
            g_block = self.build_block(
                filter_num, res=res, input_shape=input_shape, is_base=is_base
            )
            self.g_blocks.append(g_block)

    def build_block(self, filter_num, res, input_shape, is_base):
        """Build one generator block as a ``keras.Model``.

        The block takes ``[input_tensor, w, noise]``. Non-base blocks first
        upsample by 2 and convolve; every block then runs
        AddNoise -> LeakyReLU -> InstanceNormalization -> AdaIN, a convolution,
        and the same four steps again.
        """
        input_tensor = layers.Input(shape=input_shape, name=f"g_{res}")
        noise = layers.Input(shape=(res, res, 1), name=f"noise_{res}")
        w = layers.Input(shape=(512,))
        x = input_tensor

        if not is_base:
            x = layers.UpSampling2D((2, 2))(x)
            x = EqualizedConv(filter_num, 3)(x)

        x = AddNoise()([x, noise])
        x = layers.LeakyReLU(0.2)(x)
        x = InstanceNormalization()(x)
        x = AdaIN()([x, w])

        x = EqualizedConv(filter_num, 3)(x)
        x = AddNoise()([x, noise])
        x = layers.LeakyReLU(0.2)(x)
        x = InstanceNormalization()(x)
        x = AdaIN()([x, w])
        return keras.Model([input_tensor, w, noise], x, name=f"genblock_{res}x{res}")

    def grow(self, res_log2):
        """Assemble the generator that outputs ``2 ** res_log2`` sized images.

        For more than one stage, the RGB output of the previous resolution is
        upsampled and blended with the new resolution's RGB output through
        ``fade_in`` using the ``g_alpha`` input.

        Args:
            res_log2: log2 of the output resolution; must lie between
                ``start_res_log2`` and ``target_res_log2``.

        Returns:
            A ``keras.Model`` taking ``[g_input, w, noise_inputs, alpha]``.

        Raises:
            ValueError: If ``res_log2`` is outside the range built in ``__init__``.
        """
        if not self.start_res_log2 <= res_log2 <= self.target_res_log2:
            raise ValueError(
                f"res_log2 must be in [{self.start_res_log2}, "
                f"{self.target_res_log2}], got {res_log2}"
            )

        res = 2 ** res_log2

        num_stages = res_log2 - self.start_res_log2 + 1
        w = layers.Input(shape=(self.num_stages, 512), name="w")

        alpha = layers.Input(shape=(1,), name="g_alpha")
        x = self.g_blocks[0]([self.g_input, w[:, 0], self.noise_inputs[0]])

        if num_stages == 1:
            rgb = self.to_rgb[0](x)
        else:
            # Run every block below the two highest resolutions.
            for i in range(1, num_stages - 1):
                x = self.g_blocks[i]([x, w[:, i], self.noise_inputs[i]])

            # RGB of the previous resolution, upsampled to the new size.
            old_rgb = self.to_rgb[num_stages - 2](x)
            old_rgb = layers.UpSampling2D((2, 2))(old_rgb)

            i = num_stages - 1
            x = self.g_blocks[i]([x, w[:, i], self.noise_inputs[i]])

            new_rgb = self.to_rgb[i](x)

            rgb = fade_in(alpha[0], new_rgb, old_rgb)

        return keras.Model(
            [self.g_input, w, self.noise_inputs, alpha],
            rgb,
            name=f"generator_{res}_x_{res}",
        )
class Discriminator:
    """Builder for progressively grown discriminator models.

    One from-RGB layer and one discriminator block are created for every
    resolution from ``2 ** start_res_log2`` to ``2 ** target_res_log2``.
    ``grow`` wires the blocks needed for a requested resolution into a
    ``keras.Model``.
    """

    def __init__(self, start_res_log2, target_res_log2):
        self.start_res_log2 = start_res_log2
        self.target_res_log2 = target_res_log2
        self.num_stages = target_res_log2 - start_res_log2 + 1
        # filter size to use at each stage, keys are log2(resolution)
        self.filter_nums = {
            0: 512,
            1: 512,
            2: 512,  # 4x4
            3: 512,  # 8x8
            4: 512,  # 16x16
            5: 512,  # 32x32
            6: 256,  # 64x64
            7: 128,  # 128x128
            8: 64,  # 256x256
            9: 32,  # 512x512
            10: 16,  # 1024x1024
        }
        # list of discriminator blocks at increasing resolution
        self.d_blocks = []
        # list of layers to convert RGB into activation for d_blocks inputs
        self.from_rgb = []

        for res_log2 in range(self.start_res_log2, self.target_res_log2 + 1):
            res = 2 ** res_log2
            filter_num = self.filter_nums[res_log2]
            from_rgb = Sequential(
                [
                    layers.InputLayer(
                        input_shape=(res, res, 3), name=f"from_rgb_input_{res}"
                    ),
                    EqualizedConv(filter_num, 1),
                    layers.LeakyReLU(0.2),
                ],
                name=f"from_rgb_{res}",
            )

            self.from_rgb.append(from_rgb)

            # The first (lowest-resolution) block is the base that ends in the
            # scalar output; all later blocks halve the resolution.
            if not self.d_blocks:
                d_block = self.build_base(filter_num, res)
            else:
                d_block = self.build_block(
                    filter_num, self.filter_nums[res_log2 - 1], res
                )

            self.d_blocks.append(d_block)

    def build_base(self, filter_num, res):
        """Build the lowest-resolution block, which produces one output unit.

        Pipeline: minibatch_std -> conv -> LeakyReLU -> Flatten -> dense ->
        LeakyReLU -> dense(1).
        """
        input_tensor = layers.Input(shape=(res, res, filter_num), name=f"d_{res}")
        x = minibatch_std(input_tensor)
        x = EqualizedConv(filter_num, 3)(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.Flatten()(x)
        x = EqualizedDense(filter_num)(x)
        x = layers.LeakyReLU(0.2)(x)
        x = EqualizedDense(1)(x)
        return keras.Model(input_tensor, x, name=f"d_{res}")

    def build_block(self, filter_num_1, filter_num_2, res):
        """Build a block of two convolutions followed by 2x2 average pooling.

        Args:
            filter_num_1: Input channels and channels of the first convolution.
            filter_num_2: Channels of the second convolution (the block output).
            res: Input height and width.
        """
        input_tensor = layers.Input(shape=(res, res, filter_num_1), name=f"d_{res}")
        x = EqualizedConv(filter_num_1, 3)(input_tensor)
        x = layers.LeakyReLU(0.2)(x)
        x = EqualizedConv(filter_num_2)(x)
        x = layers.LeakyReLU(0.2)(x)
        x = layers.AveragePooling2D((2, 2))(x)
        return keras.Model(input_tensor, x, name=f"d_{res}")

    def grow(self, res_log2):
        """Assemble the discriminator for ``2 ** res_log2`` sized images.

        Above the base resolution, the newest block's output is blended via
        ``fade_in`` (using the ``d_alpha`` input) with the from-RGB features of
        the average-pooled image, then passed through all lower blocks.

        Args:
            res_log2: log2 of the input resolution; must lie between
                ``start_res_log2`` and ``target_res_log2``.

        Returns:
            A ``keras.Model`` taking ``[input_image, alpha]``.

        Raises:
            ValueError: If ``res_log2`` is outside the range built in ``__init__``.
        """
        if not self.start_res_log2 <= res_log2 <= self.target_res_log2:
            raise ValueError(
                f"res_log2 must be in [{self.start_res_log2}, "
                f"{self.target_res_log2}], got {res_log2}"
            )

        res = 2 ** res_log2
        idx = res_log2 - self.start_res_log2
        alpha = layers.Input(shape=(1,), name="d_alpha")
        input_image = layers.Input(shape=(res, res, 3), name="input_image")
        x = self.from_rgb[idx](input_image)
        x = self.d_blocks[idx](x)
        if idx > 0:
            idx -= 1
            downsized_image = layers.AveragePooling2D((2, 2))(input_image)
            y = self.from_rgb[idx](downsized_image)
            x = fade_in(alpha[0], x, y)

            # Remaining blocks, from the previous resolution down to the base.
            for i in range(idx, -1, -1):
                x = self.d_blocks[i](x)
        return keras.Model([input_image, alpha], x, name=f"discriminator_{res}_x_{res}")
class StyleGAN(tf.keras.Model):
    """Progressively grown StyleGAN with a custom ``train_step``.

    The mapping network is built once; generator and discriminator are
    (re)built for a given resolution by ``grow_model``. ``compile`` selects the
    resolution and the training phase (``"TRANSITION"`` or ``"STABLE"``).

    Args:
        z_dim: Size of the latent vector ``z``.
        target_res: Highest image resolution the builders prepare blocks for.
        start_res: Lowest image resolution.
    """

    def __init__(self, z_dim=512, target_res=64, start_res=4):
        super().__init__()
        self.z_dim = z_dim
        self.target_res_log2 = log2(target_res)
        self.start_res_log2 = log2(start_res)
        self.current_res_log2 = self.target_res_log2
        self.num_stages = self.target_res_log2 - self.start_res_log2 + 1

        # Fade-in weight shared by generator and discriminator.
        self.alpha = tf.Variable(1.0, dtype=tf.float32, trainable=False, name="alpha")

        self.mapping = Mapping(num_stages=self.num_stages)
        self.d_builder = Discriminator(self.start_res_log2, self.target_res_log2)
        self.g_builder = Generator(self.start_res_log2, self.target_res_log2)
        self.g_input_shape = self.g_builder.input_shape

        self.phase = None
        self.train_step_counter = tf.Variable(0, dtype=tf.int32, trainable=False)

        self.loss_weights = {"gradient_penalty": 10, "drift": 0.001}

    def grow_model(self, res):
        """Build generator and discriminator for resolution ``res``."""
        tf.keras.backend.clear_session()
        res_log2 = log2(res)
        self.generator = self.g_builder.grow(res_log2)
        self.discriminator = self.d_builder.grow(res_log2)
        self.current_res_log2 = res_log2
        print(f"\nModel resolution:{res}x{res}")

    def compile(
        self, steps_per_epoch, phase, res, d_optimizer, g_optimizer, *args, **kwargs
    ):
        """Configure resolution, phase, optimizers and metrics for training.

        Args:
            steps_per_epoch: Number of steps over which ``alpha`` ramps up
                during the ``"TRANSITION"`` phase.
            phase: ``"TRANSITION"`` or ``"STABLE"``.
            res: Image resolution to train at.
            d_optimizer: Optimizer for the discriminator.
            g_optimizer: Optimizer for the mapping network and generator.
            **kwargs: ``loss_weights`` is consumed here; the rest is forwarded
                to ``keras.Model.compile``.
        """
        self.loss_weights = kwargs.pop("loss_weights", self.loss_weights)
        self.steps_per_epoch = steps_per_epoch
        # Also build when nothing has been built yet: `current_res_log2` starts
        # at the target resolution, so a first compile at `res == target_res`
        # would otherwise leave `generator`/`discriminator` undefined.
        if res != 2 ** self.current_res_log2 or not hasattr(self, "generator"):
            self.grow_model(res)
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer

        self.train_step_counter.assign(0)
        self.phase = phase
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")
        super().compile(*args, **kwargs)

    @property
    def metrics(self):
        """Loss trackers created in ``compile``."""
        return [self.d_loss_metric, self.g_loss_metric]

    def generate_noise(self, batch_size):
        """Return one normal noise tensor per resolution, start to target."""
        noise = [
            tf.random.normal((batch_size, 2 ** res, 2 ** res, 1))
            for res in range(self.start_res_log2, self.target_res_log2 + 1)
        ]
        return noise

    def gradient_loss(self, grad):
        """Return ``mean((norm - 1) ** 2)`` of the gradient norms.

        The norm is taken over every axis except the first, so ``grad`` should
        be a tensor whose first axis is the batch axis.
        """
        loss = tf.square(grad)
        loss = tf.reduce_sum(loss, axis=tf.range(1, tf.size(tf.shape(loss))))
        loss = tf.sqrt(loss)
        loss = tf.reduce_mean(tf.square(loss - 1))
        return loss

    def train_step(self, real_images):
        """Run one generator update followed by one discriminator update.

        Returns:
            Dict with the running means ``"d_loss"`` and ``"g_loss"``.

        Raises:
            NotImplementedError: If the phase set in ``compile`` is neither
                ``"TRANSITION"`` nor ``"STABLE"``.
        """
        self.train_step_counter.assign_add(1)

        if self.phase == "TRANSITION":
            # Ramp alpha linearly with the step count.
            self.alpha.assign(
                tf.cast(self.train_step_counter / self.steps_per_epoch, tf.float32)
            )
        elif self.phase == "STABLE":
            self.alpha.assign(1.0)
        else:
            raise NotImplementedError
        alpha = tf.expand_dims(self.alpha, 0)
        batch_size = tf.shape(real_images)[0]
        real_labels = tf.ones(batch_size)
        fake_labels = -tf.ones(batch_size)

        z = tf.random.normal((batch_size, self.z_dim))
        const_input = tf.ones(tuple([batch_size] + list(self.g_input_shape)))
        noise = self.generate_noise(batch_size)

        # generator
        with tf.GradientTape() as g_tape:
            w = self.mapping(z)
            fake_images = self.generator([const_input, w, noise, alpha])
            pred_fake = self.discriminator([fake_images, alpha])
            g_loss = wasserstein_loss(real_labels, pred_fake)

        trainable_weights = (
            self.mapping.trainable_weights + self.generator.trainable_weights
        )
        gradients = g_tape.gradient(g_loss, trainable_weights)
        self.g_optimizer.apply_gradients(zip(gradients, trainable_weights))

        # discriminator
        with tf.GradientTape() as gradient_tape, tf.GradientTape() as total_tape:
            # forward pass
            pred_fake = self.discriminator([fake_images, alpha])
            pred_real = self.discriminator([real_images, alpha])

            epsilon = tf.random.uniform((batch_size, 1, 1, 1))
            interpolates = epsilon * real_images + (1 - epsilon) * fake_images
            gradient_tape.watch(interpolates)
            pred_fake_grad = self.discriminator([interpolates, alpha])

            # calculate losses
            loss_fake = wasserstein_loss(fake_labels, pred_fake)
            loss_real = wasserstein_loss(real_labels, pred_real)

            # gradient penalty
            # Differentiate the raw predictions (GradientTape sums a non-scalar
            # target) and pass the tensor itself. Differentiating the batch
            # mean scaled every gradient by 1/batch_size, and wrapping the
            # result in a list added a leading axis, so the norm was also
            # summed across the batch.
            gradients_fake = gradient_tape.gradient(pred_fake_grad, interpolates)
            gradient_penalty = self.loss_weights[
                "gradient_penalty"
            ] * self.gradient_loss(gradients_fake)

            # drift loss
            all_pred = tf.concat([pred_fake, pred_real], axis=0)
            drift_loss = self.loss_weights["drift"] * tf.reduce_mean(all_pred ** 2)

            d_loss = loss_fake + loss_real + gradient_penalty + drift_loss

        gradients = total_tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(gradients, self.discriminator.trainable_weights)
        )

        # Update metrics
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }

    def call(self, inputs: dict):
        """Generate images as a ``uint8`` array.

        Args:
            inputs: Dict with optional keys ``"style_code"``, ``"z"``,
                ``"noise"``, ``"batch_size"`` and ``"alpha"`` (default 1.0).
                A missing style code is computed from ``z`` (sampled if absent);
                missing noise is sampled. If ``"batch_size"`` is absent it is
                taken from ``style_code`` or ``z`` when one of them is given,
                and is 1 otherwise.

        Returns:
            Generator output mapped with ``(x * 0.5 + 0.5) * 255``, clipped to
            [0, 255] and converted to ``np.uint8``.
        """
        style_code = inputs.get("style_code", None)
        z = inputs.get("z", None)
        noise = inputs.get("noise", None)
        batch_size = inputs.get("batch_size", None)
        alpha = inputs.get("alpha", 1.0)
        alpha = tf.expand_dims(alpha, 0)

        if batch_size is None:
            # Follow the batch of the supplied codes instead of assuming 1, so
            # the constant input and any sampled noise match them.
            if style_code is not None:
                batch_size = tf.shape(style_code)[0]
            elif z is not None:
                batch_size = tf.shape(z)[0]
            else:
                batch_size = 1

        if style_code is None:
            if z is None:
                z = tf.random.normal((batch_size, self.z_dim))
            style_code = self.mapping(z)

        if noise is None:
            noise = self.generate_noise(batch_size)

        const_input = tf.ones(tuple([batch_size] + list(self.g_input_shape)))
        images = self.generator([const_input, style_code, noise, alpha])
        images = np.clip((images * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8)

        return images
我们首先在最小分辨率(例如 4x4 或 8x8)构建 StyleGAN。然后,我们通过附加新的生成器和鉴别器块,逐步将模型扩展到更高的分辨率。
# Resolution of the first training stage and of the final stage.
START_RES = 4
TARGET_RES = 128

# Model wrapper; generator and discriminator are built later, per resolution.
style_gan = StyleGAN(target_res=TARGET_RES, start_res=START_RES)
每个新分辨率的训练分为两个阶段:“过渡”和“稳定”。在过渡阶段,来自先前分辨率的特征与当前分辨率的特征混合,从而在扩展到更高分辨率时实现更平滑的过渡。我们将 model.fit() 中的每个 epoch 视为一个阶段。
def train(
    start_res=START_RES,
    target_res=TARGET_RES,
    steps_per_epoch=5000,
    display_images=True,
):
    """Train the module-level ``style_gan`` progressively.

    For every resolution from ``start_res`` to ``target_res`` (doubling each
    time) one epoch is run per phase. The starting resolution is trained in
    the ``"STABLE"`` phase only; all others run ``"TRANSITION"`` then
    ``"STABLE"``.

    Args:
        start_res: First resolution to train at.
        target_res: Last resolution to train at.
        steps_per_epoch: Base number of steps, multiplied per resolution by
            ``train_step_ratio``.
        display_images: If true, plot samples from fixed latents after each phase.
    """
    opt_cfg = dict(learning_rate=1e-3, beta_1=0.0, beta_2=0.99, epsilon=1e-8)

    # Fixed latents and noise so that plotted samples are comparable.
    val_batch_size = 16
    val_z = tf.random.normal((val_batch_size, style_gan.z_dim))
    val_noise = style_gan.generate_noise(val_batch_size)

    first_log2 = int(np.log2(start_res))
    last_log2 = int(np.log2(target_res))

    for res_log2 in range(first_log2, last_log2 + 1):
        res = 2 ** res_log2
        # The starting resolution skips the transition phase.
        phases = ["STABLE"] if res == start_res else ["TRANSITION", "STABLE"]

        for phase in phases:
            train_dl = create_dataloader(res)
            steps = int(train_step_ratio[res_log2] * steps_per_epoch)

            style_gan.compile(
                steps_per_epoch=steps,
                phase=phase,
                res=res,
                d_optimizer=tf.keras.optimizers.legacy.Adam(**opt_cfg),
                g_optimizer=tf.keras.optimizers.legacy.Adam(**opt_cfg),
                loss_weights={"gradient_penalty": 10, "drift": 0.001},
                run_eagerly=False,
            )

            checkpoint = keras.callbacks.ModelCheckpoint(
                f"checkpoints/stylegan_{res}x{res}.ckpt",
                save_weights_only=True,
                verbose=0,
            )

            print(phase)
            style_gan.fit(
                train_dl, epochs=1, steps_per_epoch=steps, callbacks=[checkpoint]
            )

            if display_images:
                samples = style_gan({"z": val_z, "noise": val_noise, "alpha": 1.0})
                plot_images(samples, res_log2)
StyleGAN 的训练可能需要很长时间。在下面的代码中,steps_per_epoch 的值设置为 1,仅用于检查代码能否正常运行。在实践中,需要更大的 steps_per_epoch 值(超过 10000)才能获得良好的结果。
# Smoke run: one step per phase, from 4x4 up to 16x16, without plotting.
train(
    start_res=4,
    target_res=16,
    steps_per_epoch=1,
    display_images=False,
)
Model resolution:4x4
STABLE
1/1 [==============================] - 3s 3s/step - d_loss: 2.0971 - g_loss: 2.5965
Model resolution:8x8
TRANSITION
1/1 [==============================] - 5s 5s/step - d_loss: 6.6954 - g_loss: 0.3432
STABLE
1/1 [==============================] - 4s 4s/step - d_loss: 3.3558 - g_loss: 3.7813
Model resolution:16x16
TRANSITION
1/1 [==============================] - 10s 10s/step - d_loss: 3.3166 - g_loss: 6.6047
STABLE
WARNING:tensorflow:5 out of the last 5 calls to <function Model.make_train_function.<locals>.train_function at 0x7f7f0e7005e0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://tensorflowcn.cn/guide/function#controlling_retracing and https://tensorflowcn.cn/api_docs/python/tf/function for more details.
WARNING:tensorflow:5 out of the last 5 calls to <function Model.make_train_function.<locals>.train_function at 0x7f7f0e7005e0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://tensorflowcn.cn/guide/function#controlling_retracing and https://tensorflowcn.cn/api_docs/python/tf/function for more details.
1/1 [==============================] - 8s 8s/step - d_loss: -6.1128 - g_loss: 17.0095
现在,我们可以使用预训练的 128x128 检查点运行一些推理。一般来说,图像保真度会随着分辨率的提高而提高。您可以尝试使用 CelebA HQ 数据集将此 StyleGAN 训练到高于 128x128 的分辨率。
# Download and extract the pretrained 128x128 checkpoint into ./pretrained.
url = "https://github.com/soon-yau/stylegan_keras/releases/download/keras_example_v1.0/stylegan_128x128.ckpt.zip"

weights_path = keras.utils.get_file(
    "stylegan_128x128.ckpt.zip",
    url,
    extract=True,
    cache_dir=os.path.abspath("."),
    cache_subdir="pretrained",
)

style_gan.grow_model(128)
# Locate the checkpoint next to the downloaded archive instead of assuming a
# path relative to the current working directory.
# NOTE(review): assumes get_file returns the archive path — confirm for the
# installed Keras version.
ckpt_path = os.path.join(os.path.dirname(weights_path), "stylegan_128x128.ckpt")
style_gan.load_weights(ckpt_path)

# Fixed seed so the sampled latents and noise are reproducible.
tf.random.set_seed(196)
batch_size = 2
z = tf.random.normal((batch_size, style_gan.z_dim))
w = style_gan.mapping(z)
noise = style_gan.generate_noise(batch_size=batch_size)
images = style_gan({"style_code": w, "noise": noise, "alpha": 1.0})
# The second argument only selects the per-image figure size in plot_images.
plot_images(images, 5)
Downloading data from https://github.com/soon-yau/stylegan_keras/releases/download/keras_example_v1.0/stylegan_128x128.ckpt.zip
540540928/540534982 [==============================] - 30s 0us/step
我们还可以混合来自两张图像的风格来创建新的图像。
# Blend the two style codes: weight `alpha` on the first image's code and
# `1 - alpha` on the second's, then restore the batch axis.
alpha = 0.4
w_mix = np.expand_dims(alpha * w[0] + (1 - alpha) * w[1], axis=0)

# Reuse the first image's noise at every resolution.
noise_a = [np.expand_dims(level[0], axis=0) for level in noise]
mix_images = style_gan({"style_code": w_mix, "noise": noise_a})

# Show both source images and the mixed result side by side.
image_row = np.hstack((images[0], images[1], mix_images[0]))
plt.figure(figsize=(9, 3))
plt.imshow(image_row)
plt.axis("off")
(-0.5, 383.5, 127.5, -0.5)