开发者指南 / 使用 Keras 3 进行分布式训练

使用 Keras 3 进行分布式训练

作者: 朱乾利
创建日期 2023/11/07
最后修改日期 2023/11/07
描述:多后端 Keras 分布式 API 的完整指南。

在 Colab 中查看 GitHub 源代码


简介

Keras 分布式 API 是一个新的接口,旨在促进跨各种后端(如 JAX、TensorFlow 和 PyTorch)的分布式深度学习。这个强大的 API 引入了一套工具,可以实现数据和模型并行,从而能够在多个加速器和主机上有效地扩展深度学习模型。无论利用 GPU 还是 TPU 的强大功能,该 API 都提供了一种简化的方式来初始化分布式环境、定义设备网格以及协调张量在计算资源上的布局。通过 `DataParallel` 和 `ModelParallel` 等类,它抽象了并行计算中涉及的复杂性,使开发人员更容易加速其机器学习工作流程。


工作原理

Keras 分布式 API 提供了一个全局编程模型,允许开发人员构建在全局上下文中操作张量的应用程序(就像使用单个设备一样),同时自动管理跨多个设备的分布。该 API 利用底层框架(例如 JAX)根据分片指令(通过称为单程序多数据 (SPMD) 扩展的过程)来分布程序和张量。

通过将应用程序与分片指令解耦,API 能够在单个设备、多个设备甚至多个客户端上运行相同的应用程序,同时保留其全局语义。


设置

import os

# The distribution API is only implemented for the JAX backend for now.
os.environ["KERAS_BACKEND"] = "jax"

import keras
from keras import layers
import jax
import numpy as np
from tensorflow import data as tf_data  # For dataset input.

DeviceMeshTensorLayout

Keras 分布式 API 中的 keras.distribution.DeviceMesh 类表示为分布式计算配置的计算设备集群。它与 jax.sharding.Meshtf.dtensor.Mesh 中的类似概念相一致,在这些概念中,它用于将物理设备映射到逻辑网格结构。

然后,TensorLayout 类指定了张量如何在 DeviceMesh 上分布,详细说明了沿指定轴对张量进行分片,这些轴对应于 DeviceMesh 中轴的名称。

您可以在 TensorFlow DTensor 指南 中找到更详细的概念解释。

# Retrieve the local available gpu devices.
devices = jax.devices("gpu")  # Assume it has 8 local GPUs.

# Define a 2x4 device mesh with data and model parallel axes
mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)

# A 2D layout, which describes how a tensor is distributed across the
# mesh. The layout can be visualized as a 2D grid with "model" as rows and
# "data" as columns, and it is a [4, 2] grid when it mapped to the physical
# devices on the mesh.
layout_2d = keras.distribution.TensorLayout(axes=("model", "data"), device_mesh=mesh)

# A 4D layout which could be used for data parallel of a image input.
replicated_layout_4d = keras.distribution.TensorLayout(
    axes=("data", None, None, None), device_mesh=mesh
)

分布

Keras 中的 Distribution 类充当基础抽象类,用于开发自定义分布策略。它封装了在设备网格上分布模型变量、输入数据和中间计算所需的核心逻辑。作为最终用户,您无需直接与该类交互,但可以使用它的子类,如 `DataParallel` 或 `ModelParallel`。


DataParallel

Keras 分布式 API 中的 DataParallel 类专为分布式训练中的数据并行策略而设计,其中模型权重在 DeviceMesh 中的所有设备上进行复制,每个设备处理一部分输入数据。

以下是如何使用该类的示例。

# Create DataParallel with list of devices.
# As a shortcut, the devices can be skipped,
# and Keras will detect all local available devices.
# E.g. data_parallel = DataParallel()
data_parallel = keras.distribution.DataParallel(devices=devices)

# Or you can choose to create DataParallel with a 1D `DeviceMesh`.
mesh_1d = keras.distribution.DeviceMesh(
    shape=(8,), axis_names=["data"], devices=devices
)
data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)

inputs = np.random.normal(size=(128, 28, 28, 1))
labels = np.random.normal(size=(128, 10))
dataset = tf_data.Dataset.from_tensor_slices((inputs, labels)).batch(16)

# Set the global distribution.
keras.distribution.set_distribution(data_parallel)

# Note that all the model weights from here on are replicated to
# all the devices of the `DeviceMesh`. This includes the RNG
# state, optimizer states, metrics, etc. The dataset fed into `model.fit` or
# `model.evaluate` will be split evenly on the batch dimension, and sent to
# all the devices. You don't have to do any manual aggregration of losses,
# since all the computation happens in a global context.
inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax")(y)
model = keras.Model(inputs=inputs, outputs=y)

model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
Epoch 1/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - loss: 1.0116
Epoch 2/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.9237
Epoch 3/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.8736
 8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - loss: 0.8349

0.842325747013092

ModelParallelLayoutMap

当模型权重过大而无法容纳在单个加速器上时,ModelParallel 将非常有用。此设置允许您将模型权重或激活张量分布到 DeviceMesh 上的所有设备上,并为大型模型启用横向扩展。

与所有权重都完全复制的 DataParallel 模型不同,ModelParallel 下的权重布局通常需要一些自定义才能获得最佳性能。我们引入 LayoutMap 以便您可以从全局角度指定任何权重和中间张量的 TensorLayout

LayoutMap 是一个类似字典的对象,它将字符串映射到 TensorLayout 实例。它的行为与普通的 Python 字典不同,因为在检索值时,字符串键被视为正则表达式。此类允许您定义 TensorLayout 的命名方案,然后检索相应的 TensorLayout 实例。通常,用于查询的键是 variable.path 属性,它是变量的标识符。作为快捷方式,在插入值时也允许使用元组或轴名称列表,并且它将被转换为 TensorLayout

LayoutMap 还可以选择包含一个 DeviceMesh,以在未设置 TensorLayout.device_mesh 时填充它。当使用键检索布局时,如果没有完全匹配,则布局映射中的所有现有键将再次被视为正则表达式并与输入键进行匹配。如果有多个匹配项,则会引发 ValueError。如果找不到任何匹配项,则返回 None

mesh_2d = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh_2d)
# The rule below means that for any weights that match with d1/kernel, it
# will be sharded with model dimensions (4 devices), same for the d1/bias.
# All other weights will be fully replicated.
layout_map["d1/kernel"] = (None, "model")
layout_map["d1/bias"] = ("model",)

# You can also set the layout for the layer output like
layout_map["d2/output"] = ("data", None)

model_parallel = keras.distribution.ModelParallel(layout_map, batch_dim_name="data")

keras.distribution.set_distribution(model_parallel)

inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu", name="d1")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax", name="d2")(y)
model = keras.Model(inputs=inputs, outputs=y)

# The data will be sharded across the "data" dimension of the method, which
# has 2 devices.
model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
Epoch 1/3

/opt/conda/envs/keras-jax/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:761: UserWarning: Some donated buffers were not usable: ShapedArray(float32[784,50]).
See an explanation at https://jax.net.cn/en/latest/faq.html#buffer-donation.
  warnings.warn("Some donated buffers were not usable:"

 8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - loss: 1.0266
Epoch 2/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.9181
Epoch 3/3
 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.8725
 8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.8381  

0.8502610325813293

更改网格结构以调整更多数据并行或模型并行之间的计算也很容易。您可以通过调整网格的形状来做到这一点。并且任何其他代码都不需要更改。

full_data_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(8, 1), axis_names=["data", "model"], devices=devices
)
more_data_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(4, 2), axis_names=["data", "model"], devices=devices
)
more_model_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices
)
full_model_parallel_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8), axis_names=["data", "model"], devices=devices
)

进一步阅读

  1. JAX 分布式数组和自动并行化
  2. JAX 分片模块
  3. 使用 DTensor 进行 TensorFlow 分布式训练
  4. TensorFlow DTensor 概念
  5. 将 DTensor 与 tf.keras 结合使用