LayoutMap
类keras.distribution.LayoutMap(device_mesh)
一个类似字典的对象,将字符串映射到 TensorLayout
实例。
LayoutMap
使用字符串作为键,使用 TensorLayout
作为值。普通 Python 字典和此类之间存在行为差异。字符串键在检索值时将被视为正则表达式。有关更多详细信息,请参阅 get
的文档字符串。
以下是一个使用示例。您可以定义 TensorLayout
的命名方案,然后检索相应的 TensorLayout
实例。
在正常情况下,查询的键通常是 variable.path
,它是变量的标识符。
作为快捷方式,在插入作为值时,也允许使用轴名称元组或列表,并将转换为 TensorLayout
。
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)
layout_1 = layout_map['dense_1.kernel'] # layout_1 == layout_2d
layout_2 = layout_map['dense_1.bias'] # layout_2 == layout_1d
layout_3 = layout_map['dense_2.kernel'] # layout_3 == layout_2d
layout_4 = layout_map['dense_2.bias'] # layout_4 == layout_1d
layout_5 = layout_map['my_model/conv2d_123/kernel'] # layout_5 == layout_4d
layout_6 = layout_map['my_model/conv2d_123/bias'] # layout_6 == layout_1d
layout_7 = layout_map['my_model/conv3d_1/kernel'] # layout_7 == None
layout_8 = layout_map['my_model/conv3d_1/bias'] # layout_8 == None
参数
keras.distribution.DeviceMesh
实例。DeviceMesh
类keras.distribution.DeviceMesh(shape, axis_names, devices=None)
用于分布式计算的计算设备集群。
此 API 与 jax.sharding.Mesh
和 tf.dtensor.Mesh
对齐,它们表示全局上下文中的计算设备。
在 jax.sharding.Mesh 和 tf.dtensor.Mesh 中查看更多详细信息。
参数
DeviceMesh
的形状,例如仅数据并行分布的 (8,)
或模型 + 数据并行分布的 (4, 2)
。DeviceMesh
中每个轴的逻辑名称。axis_names
的长度应与 shape
的秩匹配。在分布数据和变量时,axis_names
将用于匹配/创建 TensorLayout
。keras.distribution.list_devices()
中本地所有可用设备。TensorLayout
类keras.distribution.TensorLayout(axes, device_mesh=None)
应用于张量的布局。
此 API 与 jax.sharding.NamedSharding
和 tf.dtensor.Layout
对齐。
在 jax.sharding.NamedSharding 和 tf.dtensor.Layout 中查看更多详细信息。
参数
DeviceMesh
中的 axis_names
。对于不需要任何分片的任何维度,可以使用 None
作为占位符。DeviceMesh
,将用于创建布局。在指定网格之前,张量到物理设备的实际映射是未知的。distribute_tensor
函数keras.distribution.distribute_tensor(tensor, layout)
更改 jit 函数执行中张量值的布局。
参数
TensorLayout
。返回
具有指定张量布局的新值。