作者: fchollet
创建日期 2020/04/28
最后修改日期 2023/06/29
描述: 使用 TensorFlow 为 Keras 模型进行多 GPU 训练的指南。
通常有两种跨多个设备分配计算的方法
数据并行,即单个模型被复制到多个设备或多台机器上。每个设备处理不同的数据批次,然后合并结果。这种设置存在许多变体,它们在不同模型副本如何合并结果、它们是否在每个批次上保持同步,或者它们是否更松散地耦合等方面有所不同。
模型并行,即单个模型的不同部分在不同设备上运行,共同处理一个数据批次。这对于具有天然并行架构的模型(如具有多个分支的模型)最有效。
本指南侧重于数据并行,特别是同步数据并行,即模型的不同副本在处理每个批次后保持同步。同步性使模型的收敛行为与单设备训练时完全相同。
具体来说,本指南将教您如何使用 tf.distribute API 在多个 GPU 上训练 Keras 模型,只需对您的代码进行最少的更改,即可在单个机器上的多个 GPU(通常为 2 到 16 个)上进行训练(单主机、多设备训练)。这是研究人员和小型企业工作流程中最常见的设置。
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import keras
在此设置中,您有一台机器上安装了多个 GPU(通常为 2 到 16 个)。每个设备将运行您模型的一个副本(称为副本)。为简单起见,在下文中,我们将假设我们处理的是 8 个 GPU,这并不影响一般性。
工作原理
在训练的每一步
实际上,同步更新模型副本权重的过程在每个单独的权重变量的级别上处理。这是通过镜像变量对象实现的。
如何使用
要使用 Keras 模型进行单主机、多设备同步训练,您将使用 tf.distribute.MirroredStrategy API。其工作原理如下:
MirroredStrategy,可以选择配置您想要使用的特定设备(默认情况下,该策略将使用所有可用的 GPU)。fit() 也可能创建变量,因此将 fit() 调用放在作用域内也是个好主意。fit() 训练模型。重要的是,我们建议您使用 tf.data.Dataset 对象在多设备或分布式工作流中加载数据。
从概念上讲,它看起来像这样
# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
# Open a strategy scope.
with strategy.scope():
# Everything that creates variables should be under the strategy scope.
# In general this is only model construction & `compile()`.
model = Model(...)
model.compile(...)
# Train the model on all available devices.
model.fit(train_dataset, validation_data=val_dataset, ...)
# Test the model on all available devices.
model.evaluate(test_dataset)
这是一个简单的端到端可运行示例
def get_compiled_model():
# Make a simple 2-layer densely-connected neural network.
inputs = keras.Input(shape=(784,))
x = keras.layers.Dense(256, activation="relu")(inputs)
x = keras.layers.Dense(256, activation="relu")(x)
outputs = keras.layers.Dense(10)(x)
model = keras.Model(inputs, outputs)
model.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
return model
def get_dataset():
batch_size = 32
num_val_samples = 10000
# Return the MNIST dataset in the form of a [`tf.data.Dataset`](https://tensorflowcn.cn/api_docs/python/tf/data/Dataset).
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Preprocess the data (these are Numpy arrays)
x_train = x_train.reshape(-1, 784).astype("float32") / 255
x_test = x_test.reshape(-1, 784).astype("float32") / 255
y_train = y_train.astype("float32")
y_test = y_test.astype("float32")
# Reserve num_val_samples samples for validation
x_val = x_train[-num_val_samples:]
y_val = y_train[-num_val_samples:]
x_train = x_train[:-num_val_samples]
y_train = y_train[:-num_val_samples]
return (
tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),
tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),
tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),
)
# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print("Number of devices: {}".format(strategy.num_replicas_in_sync))
# Open a strategy scope.
with strategy.scope():
# Everything that creates variables should be under the strategy scope.
# In general this is only model construction & `compile()`.
model = get_compiled_model()
# Train the model on all available devices.
train_dataset, val_dataset, test_dataset = get_dataset()
model.fit(train_dataset, epochs=2, validation_data=val_dataset)
# Test the model on all available devices.
model.evaluate(test_dataset)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Number of devices: 1
Epoch 1/2
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - loss: 0.3830 - sparse_categorical_accuracy: 0.8884 - val_loss: 0.1361 - val_sparse_categorical_accuracy: 0.9574
Epoch 2/2
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 3ms/step - loss: 0.1068 - sparse_categorical_accuracy: 0.9671 - val_loss: 0.0894 - val_sparse_categorical_accuracy: 0.9724
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0988 - sparse_categorical_accuracy: 0.9673
在使用分布式训练时,您应始终确保有策略可以从故障中恢复(容错)。处理此问题的最简单方法是将 ModelCheckpoint 回调函数传递给 fit(),以定期(例如,每 100 个批次或每个 epoch)保存您的模型。然后,您可以从已保存的模型中恢复训练。
这是一个简单的例子:
# Prepare a directory to store all the checkpoints.
checkpoint_dir = "./ckpt"
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
def make_or_restore_model():
# Either restore the latest model, or create a fresh one
# if there is no checkpoint available.
checkpoints = [checkpoint_dir + "/" + name for name in os.listdir(checkpoint_dir)]
if checkpoints:
latest_checkpoint = max(checkpoints, key=os.path.getctime)
print("Restoring from", latest_checkpoint)
return keras.models.load_model(latest_checkpoint)
print("Creating a new model")
return get_compiled_model()
def run_training(epochs=1):
# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
# Open a strategy scope and create/restore the model
with strategy.scope():
model = make_or_restore_model()
callbacks = [
# This callback saves a SavedModel every epoch
# We include the current epoch in the folder name.
keras.callbacks.ModelCheckpoint(
filepath=checkpoint_dir + "/ckpt-{epoch}.keras",
save_freq="epoch",
)
]
model.fit(
train_dataset,
epochs=epochs,
callbacks=callbacks,
validation_data=val_dataset,
verbose=2,
)
# Running the first time creates the model
run_training(epochs=1)
# Calling the same function again will resume from where we left off
run_training(epochs=1)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Creating a new model
1563/1563 - 7s - 4ms/step - loss: 0.2275 - sparse_categorical_accuracy: 0.9320 - val_loss: 0.1373 - val_sparse_categorical_accuracy: 0.9571
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Restoring from ./ckpt/ckpt-1.keras
1563/1563 - 6s - 4ms/step - loss: 0.0944 - sparse_categorical_accuracy: 0.9717 - val_loss: 0.0972 - val_sparse_categorical_accuracy: 0.9710
tf.data 性能技巧在进行分布式训练时,加载数据的效率通常变得至关重要。以下是一些技巧,可确保您的 tf.data 管道尽可能快地运行。
关于数据集批处理的注意事项
在创建数据集时,请确保它已使用全局批次大小进行批处理。例如,如果您的 8 个 GPU 中的每个 GPU 能够处理 64 个样本的批次,则您可以使用 512 的全局批次大小。
调用 dataset.cache()
如果您对数据集调用 .cache(),那么在第一次迭代数据后,其数据将被缓存。每次后续迭代都将使用缓存的数据。缓存可以是在内存中(默认)或您指定的本地文件。
这可以提高性能,当:
调用 dataset.prefetch(buffer_size)
您几乎应该在创建数据集后调用 .prefetch(buffer_size)。这意味着您的数据管道将与模型异步运行,在当前批次样本用于训练模型时,新的样本将被预处理并存储在缓冲区中。当前批次结束后,下一个批次将由 GPU 内存预取。
就是这样!