Keras 3 API 文档 / 多设备分布式 / LayoutMap API

LayoutMap API

[来源]

LayoutMap

keras.distribution.LayoutMap(device_mesh)

一个类似字典的对象,将字符串映射到 TensorLayout 实例。

LayoutMap 使用字符串作为键,使用 TensorLayout 作为值。普通 Python 字典和此类之间存在行为差异。字符串键在检索值时将被视为正则表达式。有关更多详细信息,请参阅 get 的文档字符串。

以下是一个使用示例。您可以定义 TensorLayout 的命名方案,然后检索相应的 TensorLayout 实例。

在正常情况下,查询的键通常是 variable.path,它是变量的标识符。

作为快捷方式,在插入作为值时,也允许使用轴名称元组或列表,并将转换为 TensorLayout

layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)

layout_1 = layout_map['dense_1.kernel']             # layout_1 == layout_2d
layout_2 = layout_map['dense_1.bias']               # layout_2 == layout_1d
layout_3 = layout_map['dense_2.kernel']             # layout_3 == layout_2d
layout_4 = layout_map['dense_2.bias']               # layout_4 == layout_1d
layout_5 = layout_map['my_model/conv2d_123/kernel'] # layout_5 == layout_4d
layout_6 = layout_map['my_model/conv2d_123/bias']   # layout_6 == layout_1d
layout_7 = layout_map['my_model/conv3d_1/kernel']   # layout_7 == None
layout_8 = layout_map['my_model/conv3d_1/bias']     # layout_8 == None

参数


[来源]

DeviceMesh

keras.distribution.DeviceMesh(shape, axis_names, devices=None)

用于分布式计算的计算设备集群。

此 API 与 jax.sharding.Meshtf.dtensor.Mesh 对齐,它们表示全局上下文中的计算设备。

jax.sharding.Meshtf.dtensor.Mesh 中查看更多详细信息。

参数

  • shape: 整数元组或列表。整个 DeviceMesh 的形状,例如仅数据并行分布的 (8,) 或模型 + 数据并行分布的 (4, 2)
  • axis_names: 字符串列表。DeviceMesh 中每个轴的逻辑名称。axis_names 的长度应与 shape 的秩匹配。在分布数据和变量时,axis_names 将用于匹配/创建 TensorLayout
  • devices: 可选的设备列表。默认为 keras.distribution.list_devices() 中本地所有可用设备。

[来源]

TensorLayout

keras.distribution.TensorLayout(axes, device_mesh=None)

应用于张量的布局。

此 API 与 jax.sharding.NamedShardingtf.dtensor.Layout 对齐。

jax.sharding.NamedShardingtf.dtensor.Layout 中查看更多详细信息。

参数

  • axes: 字符串元组,应映射到 DeviceMesh 中的 axis_names。对于不需要任何分片的任何维度,可以使用 None 作为占位符。
  • device_mesh: 可选的 DeviceMesh,将用于创建布局。在指定网格之前,张量到物理设备的实际映射是未知的。

[来源]

distribute_tensor 函数

keras.distribution.distribute_tensor(tensor, layout)

更改 jit 函数执行中张量值的布局。

参数

  • tensor: 要更改布局的张量。
  • layout: 要应用于值的 TensorLayout

返回

具有指定张量布局的新值。