开发者指南 / TensorFlow 多GPU 分布式训练

TensorFlow 多GPU 分布式训练

作者: fchollet
创建日期 2020/04/28
最后修改日期 2023/06/29
描述:使用 TensorFlow 对 Keras 模型进行多 GPU 训练的指南。

在 Colab 中查看 GitHub 源代码


简介

通常有两种方法可以将计算分布到多个设备上

数据并行,其中单个模型在多个设备或多台机器上进行复制。每个设备处理不同的数据批次,然后合并其结果。此设置存在许多变体,这些变体在不同的模型副本合并结果的方式、它们是否在每个批次都保持同步或它们是否更松散耦合等方面有所不同。

模型并行,其中单个模型的不同部分在不同的设备上运行,共同处理单个数据批次。这最适合具有自然并行架构的模型,例如具有多个分支的模型。

本指南重点介绍数据并行,特别是同步数据并行,其中模型的不同副本在处理每个批次后保持同步。同步性使模型收敛行为与您在单设备训练中看到的行为相同。

具体来说,本指南将教您如何使用 tf.distribute API 在多台 GPU 上训练 Keras 模型,对代码进行最少的更改,在安装在一台机器上的多台 GPU(通常为 2 到 16 个)上进行训练(单主机、多设备训练)。对于研究人员和小规模行业工作流程来说,这是最常见的设置。


设置

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras

单主机,多设备同步训练

在此设置中,您有一台机器在其上安装了多个 GPU(通常为 2 到 16 个)。每个设备都将运行您的模型副本(称为副本)。为简单起见,在下文中,我们将假设我们正在处理 8 个 GPU,并且不失一般性。

工作原理

在训练的每个步骤中

  • 当前的数据批次(称为全局批次)被分成 8 个不同的子批次(称为局部批次)。例如,如果全局批次有 512 个样本,则 8 个局部批次中的每一个都将有 64 个样本。
  • 8 个副本中的每一个都独立地处理一个局部批次:它们运行前向传递,然后运行后向传递,输出模型在局部批次上的损失相对于权重的梯度。
  • 源自局部梯度的权重更新在 8 个副本之间有效地合并。因为这是在每个步骤结束时完成的,所以副本始终保持同步。

在实践中,同步更新模型副本权重的过程是在每个单独的权重变量级别处理的。这是通过镜像变量对象完成的。

如何使用

要使用 Keras 模型进行单主机、多设备同步训练,可以使用 tf.distribute.MirroredStrategy API。以下是其工作原理

  • 实例化一个 MirroredStrategy,可以选择配置要使用的特定设备(默认情况下,该策略将使用所有可用的 GPU)。
  • 使用策略对象打开一个作用域,并在该作用域内创建所有包含变量的所需 Keras 对象。通常,这意味着在分布式作用域内创建和编译模型。在某些情况下,第一次调用 fit() 也可能创建变量,因此最好将 fit() 调用也放在作用域内。
  • 照常通过 fit() 训练模型。

重要的是,我们建议您使用 tf.data.Dataset 对象在多设备或分布式工作流中加载数据。

从结构上讲,它看起来像这样

# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

# Open a strategy scope.
with strategy.scope():
    # Everything that creates variables should be under the strategy scope.
    # In general this is only model construction & `compile()`.
    model = Model(...)
    model.compile(...)

    # Train the model on all available devices.
    model.fit(train_dataset, validation_data=val_dataset, ...)

    # Test the model on all available devices.
    model.evaluate(test_dataset)

这是一个简单的端到端可运行示例

def get_compiled_model():
    # Make a simple 2-layer densely-connected neural network.
    inputs = keras.Input(shape=(784,))
    x = keras.layers.Dense(256, activation="relu")(inputs)
    x = keras.layers.Dense(256, activation="relu")(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    return model


def get_dataset():
    batch_size = 32
    num_val_samples = 10000

    # Return the MNIST dataset in the form of a [`tf.data.Dataset`](https://tensorflowcn.cn/api_docs/python/tf/data/Dataset).
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Preprocess the data (these are Numpy arrays)
    x_train = x_train.reshape(-1, 784).astype("float32") / 255
    x_test = x_test.reshape(-1, 784).astype("float32") / 255
    y_train = y_train.astype("float32")
    y_test = y_test.astype("float32")

    # Reserve num_val_samples samples for validation
    x_val = x_train[-num_val_samples:]
    y_val = y_train[-num_val_samples:]
    x_train = x_train[:-num_val_samples]
    y_train = y_train[:-num_val_samples]
    return (
        tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),
    )


# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print("Number of devices: {}".format(strategy.num_replicas_in_sync))

# Open a strategy scope.
with strategy.scope():
    # Everything that creates variables should be under the strategy scope.
    # In general this is only model construction & `compile()`.
    model = get_compiled_model()

    # Train the model on all available devices.
    train_dataset, val_dataset, test_dataset = get_dataset()
    model.fit(train_dataset, epochs=2, validation_data=val_dataset)

    # Test the model on all available devices.
    model.evaluate(test_dataset)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Number of devices: 1
Epoch 1/2
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - loss: 0.3830 - sparse_categorical_accuracy: 0.8884 - val_loss: 0.1361 - val_sparse_categorical_accuracy: 0.9574
Epoch 2/2
 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 3ms/step - loss: 0.1068 - sparse_categorical_accuracy: 0.9671 - val_loss: 0.0894 - val_sparse_categorical_accuracy: 0.9724
 313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0988 - sparse_categorical_accuracy: 0.9673

使用回调确保容错

使用分布式训练时,应始终确保您有策略可以从故障中恢复(容错)。处理此问题的最简单方法是将 ModelCheckpoint 回调传递给 fit(),以便定期保存模型(例如,每 100 个批次或每个 epoch 保存一次)。然后,您可以从保存的模型重新开始训练。

这是一个简单的示例

# Prepare a directory to store all the checkpoints.
checkpoint_dir = "./ckpt"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)


def make_or_restore_model():
    # Either restore the latest model, or create a fresh one
    # if there is no checkpoint available.
    checkpoints = [checkpoint_dir + "/" + name for name in os.listdir(checkpoint_dir)]
    if checkpoints:
        latest_checkpoint = max(checkpoints, key=os.path.getctime)
        print("Restoring from", latest_checkpoint)
        return keras.models.load_model(latest_checkpoint)
    print("Creating a new model")
    return get_compiled_model()


def run_training(epochs=1):
    # Create a MirroredStrategy.
    strategy = tf.distribute.MirroredStrategy()

    # Open a strategy scope and create/restore the model
    with strategy.scope():
        model = make_or_restore_model()

        callbacks = [
            # This callback saves a SavedModel every epoch
            # We include the current epoch in the folder name.
            keras.callbacks.ModelCheckpoint(
                filepath=checkpoint_dir + "/ckpt-{epoch}.keras",
                save_freq="epoch",
            )
        ]
        model.fit(
            train_dataset,
            epochs=epochs,
            callbacks=callbacks,
            validation_data=val_dataset,
            verbose=2,
        )


# Running the first time creates the model
run_training(epochs=1)

# Calling the same function again will resume from where we left off
run_training(epochs=1)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Creating a new model
1563/1563 - 7s - 4ms/step - loss: 0.2275 - sparse_categorical_accuracy: 0.9320 - val_loss: 0.1373 - val_sparse_categorical_accuracy: 0.9571
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Restoring from ./ckpt/ckpt-1.keras
1563/1563 - 6s - 4ms/step - loss: 0.0944 - sparse_categorical_accuracy: 0.9717 - val_loss: 0.0972 - val_sparse_categorical_accuracy: 0.9710

tf.data 性能提示

进行分布式训练时,加载数据的效率通常会变得至关重要。以下是一些确保 tf.data 管道运行速度尽可能快的提示。

关于数据集批处理的说明

创建数据集时,请确保使用全局批次大小进行批处理。例如,如果您的 8 个 GPU 中的每一个都能够运行一个包含 64 个样本的批次,则可以使用 512 的全局批次大小。

调用 dataset.cache()

如果对数据集调用 .cache(),则其数据将在第一次遍历数据后被缓存。每次后续迭代都将使用缓存的数据。缓存可以位于内存中(默认)或您指定的本地文件中。

当以下情况发生时,这可以提高性能:

  • 您的数据预计不会在迭代之间发生变化
  • 您正在从远程分布式文件系统读取数据
  • 您正在从本地磁盘读取数据,但您的数据可以放入内存中,并且您的工作流程严重依赖 I/O(例如,读取和解码图像文件)。

调用 dataset.prefetch(buffer_size)

您几乎始终应该在创建数据集后调用 .prefetch(buffer_size)。这意味着您的数据管道将与您的模型异步运行,新的样本将在预处理并存储在缓冲区中,而当前批次的样本将用于训练模型。下一个批次将在当前批次结束后被预取到 GPU 内存中。

就是这样!