Author: fchollet
Date created: 2019/03/01
Last modified: 2023/06/25
Description: Writing low-level training and evaluation loops in TensorFlow.
import time
import os
# This guide can only be run with the TensorFlow backend.
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import keras
import numpy as np
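Keras provides default training and evaluation loops, fit() and evaluate(). Their usage is covered in the guide Training & evaluation with the built-in methods.
If you want to customize the learning algorithm of your model while still leveraging the convenience of fit() (for instance, to train a GAN using fit()), you can subclass the Model class and implement your own train_step() method, which is called repeatedly during fit().
If you need even lower-level control over training and evaluation, you can write your own training and evaluation loops from scratch. That is what this guide covers.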
Let's consider a simple MNIST model:
def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model
model = get_model()
# Instantiate an optimizer.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Load MNIST and flatten each 28x28 image into a 784-dim vector.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))
# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]
# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)
# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
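The step counts and "Seen so far" values in the logs below follow from simple batching arithmetic. Here it is in plain Python, as a minimal sketch. It assumes the standard MNIST training split of 60,000 images, and the variable and helper names are illustrative only.

```python
import math

# Sizes used in the guide: MNIST's 60,000 training images, minus
# the last 10,000 that are held out for validation.
num_train_total = 60_000
num_val = 10_000
num_train = num_train_total - num_val
batch_size = 32

# Dataset.batch() keeps the final partial batch by default
# (drop_remainder=False), so the number of steps per epoch rounds up.
steps_per_epoch = math.ceil(num_train / batch_size)
last_batch_size = num_train - (steps_per_epoch - 1) * batch_size


def seen_so_far(step):
    """Samples processed after 0-based `step`, assuming every batch is full."""
    return (step + 1) * batch_size


print(num_train, steps_per_epoch, last_batch_size)  # 50000 1563 16
print(seen_so_far(1500))  # 48032
```

Because the final batch holds only 16 samples, the "Seen so far" formula overcounts on the very last step of an epoch (50,016 instead of 50,000). The logs print only every 100 steps and stop at step 1500, so the discrepancy never appears there.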
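Calling a model inside a GradientTape scope enables you to retrieve the gradients of the layer's trainable weights with respect to a loss value. Using an optimizer instance, you can then use these gradients to update these variables (which you can retrieve using model.trainable_weights).
Here's our training loop, step by step:
- We open a for loop that iterates over epochs.
- For each epoch, we open a for loop that iterates over the dataset, in batches.
- For each batch, we open a GradientTape() scope.
- Inside this scope, we call the model (forward pass) and compute the loss.
- Outside the scope, we retrieve the gradients of the model's weights with respect to the loss.
- Finally, we use the optimizer to update the model's weights based on the gradients.
epochs = 3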
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Open a GradientTape to record the operations run
        # during the forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:
            # Run the forward pass of the model.
            # The operations that the model applies
            # to its inputs are going to be recorded
            # on the GradientTape.
            logits = model(x_batch_train, training=True)  # Logits for this minibatch
            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)
        # Use the gradient tape to automatically retrieve
        # the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model.trainable_weights)
        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply(grads, model.trainable_weights)
        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")
Start of epoch 0
Training loss (for 1 batch) at step 0: 95.3300
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 2.5622
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 3.1138
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.6748
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 1.3308
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.9813
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.8640
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 1.0696
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3662
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.9556
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.7459
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.0468
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.7392
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.8435
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.3859
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.4156
Seen so far: 48032 samples
Start of epoch 1
Training loss (for 1 batch) at step 0: 0.4045
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.5983
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.3154
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.7911
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.2607
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.2303
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.6048
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.7041
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3669
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.6389
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.7739
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.3888
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.8133
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.2034
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.0768
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.1544
Seen so far: 48032 samples
Start of epoch 2
Training loss (for 1 batch) at step 0: 0.1250
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.0152
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.0917
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.1330
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.0884
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.2656
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.4375
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.2246
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.0748
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.1765
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.0130
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.4030
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.0667
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 1.0553
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.6513
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.0599
Seen so far: 48032 samples
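The loop above follows a fixed pattern. For each epoch and each batch, it runs the forward pass, computes the loss, gets the gradients and applies an update. The sketch below shows that same structure without TensorFlow, on a toy one-feature linear regression. Two pieces are deliberately swapped in and are not what the guide uses:

- The gradients are derived by hand instead of being recorded with tf.GradientTape.
- The update is plain SGD instead of Adam.

All names are hypothetical.

```python
import random

# Toy, noise-free data: y = 2x + 1, with x drawn from [-1, 1].
random.seed(0)
data = [(x, 2.0 * x + 1.0) for x in (random.uniform(-1.0, 1.0) for _ in range(256))]

w, b = 0.0, 0.0  # the trainable "weights"
learning_rate = 0.1
batch_size = 32
epochs = 50


def batches(dataset, size):
    """Yield consecutive mini-batches, keeping a final partial one."""
    for start in range(0, len(dataset), size):
        yield dataset[start:start + size]


for epoch in range(epochs):
    random.shuffle(data)
    for batch in batches(data, batch_size):
        n = len(batch)
        # Forward pass: predictions for this mini-batch.
        preds = [w * x + b for x, _ in batch]
        # Mean squared error for this mini-batch.
        loss = sum((p - y) ** 2 for p, (_, y) in zip(preds, batch)) / n
        # Gradients of the loss w.r.t. w and b, derived by hand
        # (the job tf.GradientTape does automatically in the guide).
        grad_w = sum(2 * (p - y) * x for p, (x, y) in zip(preds, batch)) / n
        grad_b = sum(2 * (p - y) for p, (_, y) in zip(preds, batch)) / n
        # One step of plain gradient descent (the guide uses Adam instead).
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b

print(f"w={w:.3f}, b={b:.3f}")
```

Because every mini-batch is minimized by the same (w, b), the parameters should settle near w = 2 and b = 1.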
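You can readily reuse the built-in metrics (or custom ones you wrote) in training loops written from scratch. Here's the flow:
- Instantiate the metric at the start of the loop.
- Call metric.update_state() after each batch.
- Call metric.result() when you need to display the current value of the metric.
- Call metric.reset_state() when you need to clear the state of the metric (typically at the end of an epoch).
Let's use this knowledge to compute SparseCategoricalAccuracy on the training and validation data at the end of each epoch: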
# Get a fresh model
model = get_model()
# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Prepare the metrics.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()
epochs = 2
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    start_time = time.time()
    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply(grads, model.trainable_weights)
        # Update training metric.
        train_acc_metric.update_state(y_batch_train, logits)
        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")
    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")
    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()
    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
    print(f"Time taken: {time.time() - start_time:.2f}s")
Start of epoch 0
Training loss (for 1 batch) at step 0: 89.1303
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 1.0351
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 2.9143
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 1.7842
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.9583
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 1.1100
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 2.1144
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.6801
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.6202
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 1.2570
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.3638
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 1.8402
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.7836
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.5147
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.4798
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.1653
Seen so far: 48032 samples
Training acc over epoch: 0.7961
Validation acc: 0.8825
Time taken: 46.06s
Start of epoch 1
Training loss (for 1 batch) at step 0: 1.3917
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.2600
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.7206
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.4987
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.3410
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.6788
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 1.1355
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.1762
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.1801
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.3515
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.4344
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.2027
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.4649
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.6848
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.4594
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.3548
Seen so far: 48032 samples
Training acc over epoch: 0.8896
Validation acc: 0.9094
Time taken: 43.49s
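Keras metrics such as SparseCategoricalAccuracy are stateful:

- update_state() accumulates counts across batches.
- result() reads the running value.
- reset_state() clears it.

The plain-Python class below is a hypothetical stand-in that mimics this contract. It is not the Keras implementation. It makes it easy to see why a metric has to be reset between epochs.

```python
class RunningAccuracy:
    """Hypothetical stand-in for a stateful accuracy metric (not Keras code)."""

    def __init__(self):
        self.reset_state()

    def update_state(self, y_true, y_pred_scores):
        # y_true: integer labels; y_pred_scores: one list of class scores
        # (logits or probabilities) per sample. The prediction is the argmax.
        for label, scores in zip(y_true, y_pred_scores):
            predicted = max(range(len(scores)), key=scores.__getitem__)
            self.correct += int(predicted == label)
            self.total += 1

    def result(self):
        return self.correct / self.total if self.total else 0.0

    def reset_state(self):
        self.correct = 0
        self.total = 0


metric = RunningAccuracy()
# Batch 1: both predictions are correct.
metric.update_state([0, 2], [[2.0, 0.1, 0.3], [0.0, 0.5, 1.5]])
# Batch 2: one of the two predictions is correct.
metric.update_state([1, 1], [[0.2, 0.9, 0.1], [3.0, 0.4, 0.2]])
print(metric.result())  # 0.75: 3 correct out of 4, accumulated across batches
metric.reset_state()
print(metric.result())  # 0.0
```

If reset_state() were never called, the value reported for the second epoch would blend both epochs' predictions. That is why the metric should be cleared at the end of each epoch.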
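Speeding up your training step with tf.function
The default runtime in TensorFlow is eager execution. As such, our training loop above executes eagerly, one operation at a time. This is convenient for debugging, but compiling the computation into a static graph lets TensorFlow apply whole-graph performance optimizations.
You can compile into a static graph any function that takes tensors as input. Just add a @tf.function decorator to it, like this: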
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply(grads, model.trainable_weights)
    train_acc_metric.update_state(y, logits)
    return loss_value
@tf.function
def test_step(x, y):
    val_logits = model(x, training=False)
    val_acc_metric.update_state(y, val_logits)
epochs = 2
for epoch in range(epochs):
    print(f"\nStart of epoch {epoch}")
    start_time = time.time()
    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)
        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")
    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print(f"Training acc over epoch: {float(train_acc):.4f}")
    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_state()
    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print(f"Validation acc: {float(val_acc):.4f}")
    print(f"Time taken: {time.time() - start_time:.2f}s")
Start of epoch 0
Training loss (for 1 batch) at step 0: 0.5366
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.2732
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.2478
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.0263
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.4845
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.2239
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.2242
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.2122
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.2856
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.1957
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.2946
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.3080
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.2326
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.6514
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.2018
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.2812
Seen so far: 48032 samples
Training acc over epoch: 0.9104
Validation acc: 0.9199
Time taken: 5.73s
Start of epoch 1
Training loss (for 1 batch) at step 0: 0.3080
Seen so far: 32 samples
Training loss (for 1 batch) at step 100: 0.3943
Seen so far: 3232 samples
Training loss (for 1 batch) at step 200: 0.1657
Seen so far: 6432 samples
Training loss (for 1 batch) at step 300: 0.1463
Seen so far: 9632 samples
Training loss (for 1 batch) at step 400: 0.5359
Seen so far: 12832 samples
Training loss (for 1 batch) at step 500: 0.1894
Seen so far: 16032 samples
Training loss (for 1 batch) at step 600: 0.1801
Seen so far: 19232 samples
Training loss (for 1 batch) at step 700: 0.1724
Seen so far: 22432 samples
Training loss (for 1 batch) at step 800: 0.3997
Seen so far: 25632 samples
Training loss (for 1 batch) at step 900: 0.6017
Seen so far: 28832 samples
Training loss (for 1 batch) at step 1000: 0.1539
Seen so far: 32032 samples
Training loss (for 1 batch) at step 1100: 0.1078
Seen so far: 35232 samples
Training loss (for 1 batch) at step 1200: 0.8731
Seen so far: 38432 samples
Training loss (for 1 batch) at step 1300: 0.3110
Seen so far: 41632 samples
Training loss (for 1 batch) at step 1400: 0.6092
Seen so far: 44832 samples
Training loss (for 1 batch) at step 1500: 0.2046
Seen so far: 48032 samples
Training acc over epoch: 0.9189
Validation acc: 0.9358
Time taken: 3.17s
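Layers and models recursively track any losses created during the forward pass by layers that call self.add_loss(value). After the forward pass, the resulting list of scalar loss values is available via the model.losses property.
To use these loss components, sum them and add them to the main loss in your training step. Consider this layer, which creates an activity regularization loss: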
class ActivityRegularizationLayer(keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * tf.reduce_sum(inputs))
        return inputs
inputs = keras.Input(shape=(784,), name="digits")
x = keras.layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = keras.layers.Dense(64, activation="relu")(x)
outputs = keras.layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
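# A fresh optimizer: a Keras optimizer can only update the variables it was
# first built with, and this new model's variables are different.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
        # Add any extra losses created during the forward pass.
        loss_value += sum(model.losses)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply(grads, model.trainable_weights)
    train_acc_metric.update_state(y, logits)
    return loss_value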
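As a worked example of how the extra losses combine with the main loss, the snippet below reproduces the arithmetic in plain Python. ActivityRegularizationLayer contributes 1e-2 times the sum of its inputs, and train_step adds sum(model.losses) to the cross-entropy. The activation values and the main loss value are made up for illustration.

```python
# Made-up activations reaching ActivityRegularizationLayer for one batch
# (in the guide these are outputs of the first Dense(64, relu) layer).
activations = [[0.5, 0.0, 1.5], [2.0, 0.0, 1.0]]

# What the layer passes to self.add_loss(): 1e-2 * reduce_sum(inputs).
activity_loss = 1e-2 * sum(sum(row) for row in activations)

# model.losses holds one scalar per add_loss() call made in the forward pass.
extra_losses = [activity_loss]

main_loss = 0.30  # made-up cross-entropy value for the same batch
total_loss = main_loss + sum(extra_losses)

print(round(activity_loss, 4), round(total_loss, 4))  # 0.05 0.35
```

ReLU outputs are non-negative, so this penalty grows with the total activation. Minimizing total_loss therefore also pushes those activations toward zero.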
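Finally, here's a simple end-to-end example that ties together everything you've learned in this guide: a DCGAN trained on MNIST digits.
You may be familiar with Generative Adversarial Networks (GANs). GANs can generate new images that look almost real, by learning the latent distribution of a training dataset of images (the "latent space" of the images).
A GAN is made of two parts:
- a "generator" model, which maps points in the latent space to points in image space;
- a "discriminator" model, a classifier that can tell the difference between real images (from the training dataset) and fake images (the output of the generator network).
A GAN training loop looks like this:
1) Train the discriminator.
- Sample a batch of random points in the latent space.
- Turn the points into fake images via the "generator" model.
- Get a batch of real images and combine them with the generated images.
- Train the "discriminator" model to classify generated vs. real images.
2) Train the generator.
- Sample random points in the latent space.
- Turn the points into fake images via the "generator" network.
- Train the "generator" model to "fool" the discriminator and classify the fake images as real.
For a more detailed overview of how GANs work, see Deep Learning with Python.
Let's implement this training loop. First, create the discriminator, meant to classify fake vs. real digits: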
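discriminator = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
    keras.layers.LeakyReLU(negative_slope=0.2),
    keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
    keras.layers.LeakyReLU(negative_slope=0.2),
    keras.layers.GlobalMaxPooling2D(),
    keras.layers.Dense(1),
], name="discriminator")
discriminator.summary()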
Model: "discriminator"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape              ┃    Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ conv2d (Conv2D)                 │ (None, 14, 14, 64)        │        640 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu (LeakyReLU)         │ (None, 14, 14, 64)        │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_1 (Conv2D)               │ (None, 7, 7, 128)         │     73,856 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ leaky_re_lu_1 (LeakyReLU)       │ (None, 7, 7, 128)         │          0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ global_max_pooling2d            │ (None, 128)               │          0 │
│ (GlobalMaxPooling2D)            │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense_6 (Dense)                 │ (None, 1)                 │        129 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
Total params: 74,625 (291.50 KB)
Trainable params: 74,625 (291.50 KB)
Non-trainable params: 0 (0.00 B)
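The Param # column can be checked by hand:

- A Conv2D layer has kernel_h * kernel_w * in_channels * filters weights, plus one bias per filter.
- A Dense layer has in_features * units weights, plus one bias per unit.

The helper names below are hypothetical. The size in KB assumes 4 bytes per float32 parameter.

```python
def conv2d_params(kernel, in_channels, filters):
    kh, kw = kernel
    return kh * kw * in_channels * filters + filters  # weights + biases


def dense_params(in_features, units):
    return in_features * units + units  # weights + biases


layers = {
    "conv2d": conv2d_params((3, 3), 1, 64),
    "conv2d_1": conv2d_params((3, 3), 64, 128),
    "dense_6": dense_params(128, 1),  # GlobalMaxPooling2D leaves 128 features
}
total = sum(layers.values())  # LeakyReLU and pooling layers add no parameters
size_kb = total * 4 / 1024    # float32 = 4 bytes per parameter

print(layers)                   # {'conv2d': 640, 'conv2d_1': 73856, 'dense_6': 129}
print(total, f"{size_kb:.2f}")  # 74625 291.50
```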
Then let's create a generator network that turns latent vectors into outputs of shape (28, 28, 1) (representing MNIST digits):
latent_dim = 128
generator = keras.Sequential([
    keras.Input(shape=(latent_dim,)),
    # Generate 7 * 7 * 128 coefficients to reshape into a 7x7x128 feature map.
    keras.layers.Dense(7 * 7 * 128),
    keras.layers.LeakyReLU(negative_slope=0.2),
    keras.layers.Reshape((7, 7, 128)),
    keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
    keras.layers.LeakyReLU(negative_slope=0.2),
    keras.layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
    keras.layers.LeakyReLU(negative_slope=0.2),
    keras.layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
], name="generator")
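The spatial sizes work out as follows. With padding="same":

- A stride-2 Conv2D produces ceil(n / 2) rows and columns, so the discriminator goes 28 -> 14 -> 7, matching the summary above.
- A stride-2 Conv2DTranspose produces n * 2, so the generator goes 7 -> 14 -> 28.

Below is a plain-Python check of that arithmetic. The helper names are hypothetical.

```python
import math


def conv_same(size, stride):
    """Output size of a Conv2D with padding="same"."""
    return math.ceil(size / stride)


def conv_transpose_same(size, stride):
    """Output size of a Conv2DTranspose with padding="same"."""
    return size * stride


# Discriminator: two stride-2 Conv2D layers applied to a 28x28 input.
d_sizes = [28]
for _ in range(2):
    d_sizes.append(conv_same(d_sizes[-1], 2))

# Generator: Dense output reshaped to 7x7x128, two stride-2 Conv2DTranspose
# layers, then a stride-1 Conv2D(1, (7, 7), padding="same") that keeps 28x28.
g_sizes = [7]
for _ in range(2):
    g_sizes.append(conv_transpose_same(g_sizes[-1], 2))
g_sizes.append(conv_same(g_sizes[-1], 1))

print(d_sizes)      # [28, 14, 7]
print(g_sizes)      # [7, 14, 28, 28]
print(7 * 7 * 128)  # 6272 units in the generator's first Dense layer
```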
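Here's the key bit: the training loop. As you can see, it is quite straightforward. The training step function takes only about 17 lines of code.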
# Instantiate one optimizer for the discriminator and another for the generator.
d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)
g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)
# Instantiate a loss function.
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
@tf.function
def train_step(real_images):
    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Decode them to fake images
    generated_images = generator(random_latent_vectors)
    # Combine them with real images
    combined_images = tf.concat([generated_images, real_images], axis=0)
    # Assemble labels discriminating real from fake images
    labels = tf.concat(
        [tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0
    )
    # Add random noise to the labels - important trick!
    labels += 0.05 * tf.random.uniform(labels.shape)
    # Train the discriminator
    with tf.GradientTape() as tape:
        predictions = discriminator(combined_images)
        d_loss = loss_fn(labels, predictions)
    grads = tape.gradient(d_loss, discriminator.trainable_weights)
    d_optimizer.apply(grads, discriminator.trainable_weights)
    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Assemble labels that say "all real images"
    misleading_labels = tf.zeros((batch_size, 1))
    # Train the generator (note that we should *not* update the weights
    # of the discriminator)!
    with tf.GradientTape() as tape:
        predictions = discriminator(generator(random_latent_vectors))
        g_loss = loss_fn(misleading_labels, predictions)
    grads = tape.gradient(g_loss, generator.trainable_weights)
    g_optimizer.apply(grads, generator.trainable_weights)
    return d_loss, g_loss, generated_images
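The labeling convention in train_step is easy to misread, so here it is without TensorFlow:

- Generated images are labeled 1 and real images 0.
- A little uniform noise in [0, 0.05) is added to every label.
- The generator is trained against labels of 0 ("real") for its fake images.

The sketch also shows why the real-image labels are sized from real_images.shape[0] rather than batch_size. The 70,000 MNIST digits (train plus test) do not divide evenly into batches of 64, so the final batch of real images is smaller. Plain Python lists stand in for tensors, and the function name is hypothetical.

```python
import math
import random


def discriminator_labels(num_generated, num_real, noise=0.05, rng=random):
    """Labels for [generated..., real...]: 1 = fake, 0 = real, plus small noise."""
    labels = [1.0] * num_generated + [0.0] * num_real
    # Mirrors `labels += 0.05 * tf.random.uniform(labels.shape)`.
    return [label + noise * rng.random() for label in labels]


batch_size = 64
num_images = 60_000 + 10_000  # MNIST train + test digits
num_batches = math.ceil(num_images / batch_size)
last_real_batch = num_images - (num_batches - 1) * batch_size

random.seed(0)
# Labels for the final step: a full batch of fakes, a short batch of reals.
labels = discriminator_labels(batch_size, last_real_batch)
# The generator step scores its fakes against "real" (0) labels.
misleading_labels = [0.0] * batch_size

print(num_batches, last_real_batch, len(labels))  # 1094 48 112
```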
Let's train our GAN by repeatedly calling train_step on batches of images.
Since our discriminator and generator are convnets, you're going to want to run this code on a GPU.
# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
epochs = 1  # In practice you need at least 20 epochs to generate nice digits.
save_dir = "./"
for epoch in range(epochs):
    print(f"\nStart epoch {epoch}")
    for step, real_images in enumerate(dataset):
        # Train the discriminator & generator on one batch of real images.
        d_loss, g_loss, generated_images = train_step(real_images)
        # Logging.
        if step % 100 == 0:
            # Print metrics
            print(f"discriminator loss at step {step}: {float(d_loss):.2f}")
            print(f"adversarial loss at step {step}: {float(g_loss):.2f}")
            # Save one generated image
            img = keras.utils.array_to_img(generated_images[0] * 255.0, scale=False)
            img.save(os.path.join(save_dir, f"generated_img_{step}.png"))
        # To limit execution time, we stop after 12 steps (steps 0 through 11).
        # Remove the lines below to actually train the model!
        if step > 10:
            break
Start epoch 0
discriminator loss at step 0: 0.69
adversarial loss at step 0: 0.69
That's it! Once you remove the early stop, you'll get nice-looking fake MNIST digits after just ~30 seconds of training on the Colab GPU.