ModelParallel
类keras.distribution.ModelParallel(layout_map=None, batch_dim_name=None, **kwargs)
将模型变量分片的分布。
与在所有设备上复制变量的 DataParallel
相比,ModelParallel
允许您除了输入数据外,还可以对变量进行分片。
要构建 ModelParallel
分布,您需要提供 DeviceMesh
和 LayoutMap
。
DeviceMesh
包含物理设备信息。网格中的轴名将用于映射变量和数据布局。LayoutMap
包含变量路径与其对应 TensorLayout
之间的映射。示例
devices = list_devices() # Assume there are 8 devices.
# Create a mesh with 2 devices for data parallelism and 4 devices for
# model parallelism.
device_mesh = DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'),
devices=devices)
# Create a layout map that shard the `Dense` layer and `Conv2D`
# layer variables on the last dimension.
# Based on the `device_mesh`, this means the variables
# will be split across 4 devices. Any other variable that doesn't
# match any key in the layout map will be fully replicated.
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)
distribution = ModelParallel(
layout_map=layout_map,
batch_dim_name='batch',
)
# Set the global distribution, or via `with distribution.scope():`
set_distribution(distribution)
model = model_creation()
model.compile()
model.fit(data)
您可以快速更新设备网格形状以更改变量的分片因子。例如:
# With only the shape change for the device mesh, the variables will be
# sharded across 8 devices instead of 4, which further reduces the memory
# footprint of variables on each of the device.
device_mesh = DeviceMesh(
shape=(1, 8),
axis_names=('batch', 'model'),
devices=devices,
)
要找出所有模型变量的适当布局映射规则,您可以先列出所有模型变量路径,这些路径将用作将变量映射到 TensorLayout
的键。
例如:
model = create_model()
for v in model.variables:
print(v.path)
参数
LayoutMap
实例,将变量路径映射到相应的张量布局。layout_map
对象)中将用于分发数据的轴名。如果未指定,将使用设备网格中的第一个轴。