Author: Qianli Zhu
Date created: 2023/11/07
Last modified: 2023/11/07
Description: Complete guide to the distribution API for multi-backend Keras.
The Keras distribution API is a new interface designed to facilitate distributed deep learning across a variety of backends such as JAX, TensorFlow, and PyTorch. This powerful API introduces a suite of tools enabling data parallelism and model parallelism, allowing deep learning models to scale efficiently on multiple accelerators and hosts. Whether leveraging the power of GPUs or TPUs, the API provides a streamlined approach to initializing distributed environments, defining device meshes, and orchestrating the layout of tensors across computational resources. Through classes like DataParallel and ModelParallel, it abstracts away the complexity involved in parallel computation, making it easier for developers to accelerate their machine learning workflows.
How it works
The Keras distribution API provides a global programming model that allows developers to operate on tensors in a global context (as if working with a single device), while it automatically manages distribution across many devices. The API leverages the underlying framework (e.g. JAX) to distribute the program and tensors according to the sharding directives, through a procedure called single program, multiple data (SPMD) expansion.
By decoupling the application from the sharding directives, the API enables running the same application on a single device, multiple devices, or even multiple clients, while preserving its global semantics.
import os
# The distribution API is only implemented for the JAX backend for now.
os.environ["KERAS_BACKEND"] = "jax"
import keras
from keras import layers
import jax
import numpy as np
from tensorflow import data as tf_data # For dataset input.
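Before going further, it can help to sanity-check the environment. The snippet below is purely illustrative; the examples in this guide assume 8 local accelerators.
# Illustrative sanity check: confirm the active backend and list the devices
# available to the distribution API (the examples below assume 8 of them).
print(keras.backend.backend())  # "jax"
print(keras.distribution.list_devices())  # all locally available devices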
DeviceMesh and TensorLayout
The keras.distribution.DeviceMesh class in the Keras distribution API represents a cluster of computational devices configured for distributed computation. It aligns with similar concepts in jax.sharding.Mesh and tf.dtensor.Mesh, where it is used to map physical devices to a logical mesh structure.
The TensorLayout class then specifies how tensors are distributed across the DeviceMesh, detailing the sharding of tensors along the specified axes that correspond to the names of the axes in the DeviceMesh.
You can find more detailed concept explainers in the TensorFlow DTensor guide.
# Retrieve the local available gpu devices.
devices = jax.devices("gpu") # Assume it has 8 local GPUs.
# Define a 2x4 device mesh with data and model parallel axes
mesh = keras.distribution.DeviceMesh(
shape=(2, 4), axis_names=["data", "model"], devices=devices
)
# A 2D layout, which describes how a tensor is distributed across the
# mesh. The layout can be visualized as a 2D grid with "model" as rows and
# "data" as columns, and it is a [4, 2] grid when it mapped to the physical
# devices on the mesh.
layout_2d = keras.distribution.TensorLayout(axes=("model", "data"), device_mesh=mesh)
# A 4D layout which could be used for data parallelism of an image input.
replicated_layout_4d = keras.distribution.TensorLayout(
axes=("data", None, None, None), device_mesh=mesh
)
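Note that a TensorLayout is only a description: it pairs axis names (or None for replication) with a device mesh, and the distribution machinery uses it to place tensor shards. The illustrative lines below simply inspect the layouts defined above.
# Layouts are plain descriptions that can be inspected directly. Each entry
# in `axes` either names a mesh axis to shard over, or is None to replicate.
print(layout_2d.axes)  # ("model", "data")
print(replicated_layout_4d.axes)  # ("data", None, None, None)
print(layout_2d.device_mesh.shape)  # (2, 4)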
Distribution
The Distribution class in Keras serves as a foundational abstract class designed for developing custom distribution strategies. It encapsulates the core logic needed to distribute a model's variables, input data, and intermediate computations across a device mesh. As an end user, you won't have to interact directly with this class, but with its subclasses like DataParallel or ModelParallel.
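You typically install one of those subclasses globally with keras.distribution.set_distribution, but a distribution can also be applied as a scope. A minimal sketch (using the DataParallel subclass introduced next):
# Minimal sketch: a distribution can be set process-wide, or used as a
# context manager so that only models built inside the scope pick it up.
data_parallel = keras.distribution.DataParallel()
with data_parallel.scope():
    pass  # models/variables created here follow this distribution's layouts
keras.distribution.set_distribution(data_parallel)
print(keras.distribution.distribution())  # retrieve the active distribution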
DataParallel
The DataParallel class in the Keras distribution API is designed for the data parallelism strategy in distributed training, where the model weights are replicated across all the devices in the DeviceMesh, and each device processes a portion of the input data.
Here is a sample usage of this class.
# Create DataParallel with list of devices.
# As a shortcut, the devices can be skipped,
# and Keras will detect all local available devices.
# E.g. data_parallel = DataParallel()
data_parallel = keras.distribution.DataParallel(devices=devices)
# Or you can choose to create DataParallel with a 1D `DeviceMesh`.
mesh_1d = keras.distribution.DeviceMesh(
shape=(8,), axis_names=["data"], devices=devices
)
data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)
inputs = np.random.normal(size=(128, 28, 28, 1))
labels = np.random.normal(size=(128, 10))
dataset = tf_data.Dataset.from_tensor_slices((inputs, labels)).batch(16)
# Set the global distribution.
keras.distribution.set_distribution(data_parallel)
# Note that all the model weights from here on are replicated to
# all the devices of the `DeviceMesh`. This includes the RNG
# state, optimizer states, metrics, etc. The dataset fed into `model.fit` or
# `model.evaluate` will be split evenly on the batch dimension, and sent to
# all the devices. You don't have to do any manual aggregation of losses,
# since all the computation happens in a global context.
inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax")(y)
model = keras.Model(inputs=inputs, outputs=y)
model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
Epoch 1/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - loss: 1.0116
Epoch 2/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.9237
Epoch 3/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.8736
8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - loss: 0.8349
0.842325747013092
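You can check the replication behavior described in the comments above by asking the distribution for its layouts. An illustrative sketch, assuming the get_data_layout and get_variable_layout methods on the distribution object:
# Illustrative: inspect how DataParallel lays out data and variables.
# The batch dimension of the data is sharded over the "data" mesh axis...
print(data_parallel.get_data_layout((16, 28, 28, 1)).axes)
# ...while every weight is fully replicated (no axis is sharded).
print(data_parallel.get_variable_layout(model.weights[0]).axes)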
ModelParallel and LayoutMap
ModelParallel will be mostly useful when the model weights are too large to fit on a single accelerator. This setting allows you to spread your model weights or activation tensors across all the devices on the DeviceMesh, and enables horizontal scaling for large models.
Unlike the DataParallel model where all weights are fully replicated, the weights layout under ModelParallel usually needs some customization for best performance. We introduce LayoutMap to let you specify the TensorLayout for any weights and intermediate tensors from a global perspective.
LayoutMap is a dict-like object that maps strings to TensorLayout instances. It behaves differently from a normal Python dict in that a string key is treated as a regex when retrieving the value. The class allows you to define the naming schema of TensorLayout and then retrieve the corresponding TensorLayout instance. Typically, the key used to query is the variable.path attribute, which is the identifier of the variable. As a shortcut, a tuple or list of axis names is also allowed when inserting a value, and it will be converted to a TensorLayout.
The LayoutMap can also optionally contain a DeviceMesh to populate the TensorLayout.device_mesh when it is not set. When retrieving a layout with a key, if there is no exact match, all the existing keys in the layout map will be treated as regexes and matched against the input key again. If there are multiple matches, a ValueError is raised. If no match is found, None is returned.
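The illustrative sketch below shows that retrieval behavior, reusing the 2x4 mesh defined earlier:
# Illustrative: keys are treated as regexes at retrieval time.
layout_map = keras.distribution.LayoutMap(mesh)
layout_map["d1/kernel"] = (None, "model")  # tuple converted to TensorLayout
# There is no exact key "mlp/d1/kernel", so the stored keys are matched as
# regexes against the query; "d1/kernel" matches and its layout is returned.
print(layout_map["mlp/d1/kernel"].axes)  # (None, "model")
print(layout_map["d9/kernel"])  # no match at all -> None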
mesh_2d = keras.distribution.DeviceMesh(
shape=(2, 4), axis_names=["data", "model"], devices=devices
)
layout_map = keras.distribution.LayoutMap(mesh_2d)
# The rule below means that any weight that matches d1/kernel will be
# sharded along the "model" dimension (4 devices); same for d1/bias.
# All other weights will be fully replicated.
layout_map["d1/kernel"] = (None, "model")
layout_map["d1/bias"] = ("model",)
# You can also set the layout for the layer output like
layout_map["d2/output"] = ("data", None)
model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="data"
)
keras.distribution.set_distribution(model_parallel)
inputs = layers.Input(shape=(28, 28, 1))
y = layers.Flatten()(inputs)
y = layers.Dense(units=200, use_bias=False, activation="relu", name="d1")(y)
y = layers.Dropout(0.4)(y)
y = layers.Dense(units=10, activation="softmax", name="d2")(y)
model = keras.Model(inputs=inputs, outputs=y)
# The data will be sharded across the "data" dimension of the mesh, which
# has 2 devices.
model.compile(loss="mse")
model.fit(dataset, epochs=3)
model.evaluate(dataset)
Epoch 1/3
/opt/conda/envs/keras-jax/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:761: UserWarning: Some donated buffers were not usable: ShapedArray(float32[784,50]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
warnings.warn("Some donated buffers were not usable:"
8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - loss: 1.0266
Epoch 2/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.9181
Epoch 3/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.8725
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.8381
0.8502610325813293
It is also quite easy to change the mesh structure to tune the computation between more data parallel or more model parallel by adjusting the shape of the mesh, and you can do this without making any changes to the rest of the code.
full_data_parallel_mesh = keras.distribution.DeviceMesh(
shape=(8, 1), axis_names=["data", "model"], devices=devices
)
more_data_parallel_mesh = keras.distribution.DeviceMesh(
shape=(4, 2), axis_names=["data", "model"], devices=devices
)
more_model_parallel_mesh = keras.distribution.DeviceMesh(
shape=(2, 4), axis_names=["data", "model"], devices=devices
)
full_model_parallel_mesh = keras.distribution.DeviceMesh(
shape=(1, 8), axis_names=["data", "model"], devices=devices
)
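For example, switching to the 4x2 mesh only requires rebuilding the LayoutMap around the new mesh. The sketch below (illustrative) reuses the same sharding rules from above; the model definition, compile(), and fit() calls stay exactly the same.
# Illustrative: the same layout rules on a mesh with more data parallelism
# (4-way data sharding x 2-way model sharding).
layout_map = keras.distribution.LayoutMap(more_data_parallel_mesh)
layout_map["d1/kernel"] = (None, "model")
layout_map["d1/bias"] = ("model",)
distribution = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="data"
)
keras.distribution.set_distribution(distribution)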