KerasRS / API文档 / Embedding 层 / DistributedEmbedding 层

DistributedEmbedding 层

[源代码]

DistributedEmbedding

keras_rs.layers.DistributedEmbedding(
    feature_configs: Union[
        keras_rs.src.layers.embedding.distributed_embedding_config.FeatureConfig,
        tensorflow.python.tpu.tpu_embedding_v2_utils.FeatureConfig,
        Sequence[
            Union[
                keras_rs.src.layers.embedding.distributed_embedding_config.FeatureConfig,
                tensorflow.python.tpu.tpu_embedding_v2_utils.FeatureConfig,
                ForwardRef("Nested[T]"),
            ]
        ],
        Mapping[
            str,
            Union[
                keras_rs.src.layers.embedding.distributed_embedding_config.FeatureConfig,
                tensorflow.python.tpu.tpu_embedding_v2_utils.FeatureConfig,
                ForwardRef("Nested[T]"),
            ],
        ],
    ],
    table_stacking: Union[str, Sequence[str], Sequence[Sequence[str]]] = "auto",
    **kwargs: Any
)

DistributedEmbedding,一个用于加速大型嵌入查找的层。


注意:DistributedEmbedding 处于预览阶段。


DistributedEmbedding 是一个针对具有 SparseCore 的 TPU 芯片进行优化的层,可以显著提高嵌入查找和嵌入训练的速度。它通过将多个查找合并为一个调用,并将嵌入表分片到可用芯片上来实现。请注意,只有当嵌入表足够大以至于无法装入单个芯片并需要分片时,才能看到性能提升。更多详细信息请参见下面的“放置”部分。

在其他硬件上,如 GPU、CPU 和没有 SparseCore 的 TPU 上,DistributedEmbedding 提供相同的 API,但没有特定的加速。除了通过 keras.distribution.set_distribution 设置的分片方案外,不应用任何特定的分片方案。

DistributedEmbedding 嵌入输入序列,并通过应用可配置的组合器函数将它们缩减为单个嵌入。

配置

特征和表

DistributedEmbedding 嵌入层通过一组 keras_rs.layers.FeatureConfig 对象进行配置,这些对象本身引用 keras_rs.layers.TableConfig 对象。

  • TableConfig 定义了一个嵌入表,其中包含词汇表大小、嵌入维度以及用于组合和训练的优化器等参数。
  • FeatureConfig 定义了 DistributedEmbedding 将处理的输入特征以及要使用的嵌入表。请注意,多个特征可以使用同一个嵌入表。
table1 = keras_rs.layers.TableConfig(
    name="table1",
    vocabulary_size=TABLE1_VOCABULARY_SIZE,
    embedding_dim=TABLE1_EMBEDDING_SIZE,
    placement="auto",
)
table2 = keras_rs.layers.TableConfig(
    name="table2",
    vocabulary_size=TABLE2_VOCABULARY_SIZE,
    embedding_dim=TABLE2_EMBEDDING_SIZE,
    placement="auto",
)

feature1 = keras_rs.layers.FeatureConfig(
    name="feature1",
    table=table1,
    input_shape=(GLOBAL_BATCH_SIZE,),
    output_shape=(GLOBAL_BATCH_SIZE, TABLE1_EMBEDDING_SIZE),
)
feature2 = keras_rs.layers.FeatureConfig(
    name="feature2",
    table=table2,
    input_shape=(GLOBAL_BATCH_SIZE,),
    output_shape=(GLOBAL_BATCH_SIZE, TABLE2_EMBEDDING_SIZE),
)

feature_configs = {
    "feature1": feature1,
    "feature2": feature2,
}

embedding = keras_rs.layers.DistributedEmbedding(feature_configs)

优化器

DistributedEmbedding 中的每个嵌入表都使用自己的训练优化器,这与在模型中通过 model.compile() 设置的优化器是独立的。

请注意,并非所有优化器都受支持。目前,以下优化器在所有后端和加速器上都受支持:

此外,并非所有优化器参数都受支持(例如 SGDnesterov 选项)。当使用不受支持的优化器或不受支持的优化器参数时,会引发错误。

放置

DistributedEmbedding 中的每个嵌入表都可以放置在 SparseCore 芯片上,或者放置在加速器的默认设备上(例如 TPU 上的 Tensor Cores 的 HBM)。这由 keras_rs.layers.TableConfigplacement 属性控制。

  • 放置为 "sparsecore" 表示该表应放置在 SparseCore 芯片上。如果选择此选项但没有 SparseCore 芯片,则会引发错误。
  • 放置为 "default_device" 表示该表不应放置在 SparseCore 上,即使有。相反,该表放置在模型通常放置的设备上,即 TPU 和 GPU 上的 HBM。在这种情况下,如果适用,该表将使用通过 keras.distribution.set_distribution 设置的分片方案进行分片。在 GPU、CPU 和没有 SparseCore 的 TPU 上,这是唯一可用的放置方式,也是 "auto" 选择的方式。
  • 放置为 "auto" 表示如果可用,则使用 "sparsecore",否则使用 "default_device"。这是未指定时的默认值。

优化 TPU 性能

  • 需要分片的大型表应使用 "sparsecore" 放置。
  • 足够小的表应使用 "default_device" 放置,并且通常应通过使用 keras.distribution.DataParallel 分片选项在 TPU 上进行复制。

与 TensorFlow 在 TPU 上配合 SparseCore 使用

输入

除了 tf.Tensor 外,DistributedEmbedding 还接受 tf.RaggedTensortf.SparseTensor 作为嵌入查找的输入。Ragged 텐서 必须在索引为 1 的维度上是 ragged 的。请注意,如果提供了权重,则每个权重张量必须与该特定特征的输入属于同一类,并且对于 ragged 텐서 使用完全相同的 ragged 行长度,对于 sparse 텐서 使用相同的索引。DistributedEmbedding 的所有输出都是密集张量。

设置

要在 TPU 上使用 DistributedEmbedding 和 TensorFlow,必须使用 tf.distribute.TPUStrategyDistributedEmbedding 层必须在 TPUStrategy 下创建。

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, num_replicas=resolver.get_tpu_system_metadata().num_cores
)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment
)

with strategy.scope():
    embedding = keras_rs.layers.DistributedEmbedding(feature_configs)

在 Keras 模型中的用法

要使用 Keras 的 model.fit(),必须在 TPUStrategy 下编译模型。然后,可以直接调用 model.fit()model.evaluate()model.predict()。Keras 模型负责使用策略运行模型,并自动分片数据集。

with strategy.scope():
    embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
    model = create_model(embedding)
    model.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam")

model.fit(dataset, epochs=10)

直接调用

DistributedEmbedding 必须通过嵌套在 tf.function 中的 strategy.run 调用来调用。

@tf.function
def embedding_wrapper(tf_fn_inputs, tf_fn_weights=None):
    def strategy_fn(st_fn_inputs, st_fn_weights):
        return embedding(st_fn_inputs, st_fn_weights)

    return strategy.run(strategy_fn, args=(tf_fn_inputs, tf_fn_weights)))

embedding_wrapper(my_inputs, my_weights)

使用数据集时,必须对数据集进行分片。然后可以将迭代器传递给使用 strategy.runtf.function

dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def run_loop(iterator):
    def step(data):
        (inputs, weights), labels = data
        with tf.GradientTape() as tape:
            result = embedding(inputs, weights)
            loss = keras.losses.mean_squared_error(labels, result)
        tape.gradient(loss, embedding.trainable_variables)
        return result

    for _ in tf.range(4):
        result = strategy.run(step, args=(next(iterator),))

run_loop(iter(dataset))

与 JAX 在 TPU 上配合 SparseCore 使用

设置

要在 TPU 上使用 DistributedEmbedding 和 JAX,必须创建一个 Keras Distribution 并进行设置。

distribution = keras.distribution.DataParallel(devices=jax.device("tpu"))
keras.distribution.set_distribution(distribution)

输入

对于 JAX,输入可以是密集张量,也可以是 ragged(嵌套)的 NumPy 数组。要启用 jit_compile = True,必须在输入上显式调用 layer.preprocess(...),然后将预处理后的输出馈送给模型。有关详细信息,请参阅下一节关于预处理的内容。

Ragged 输入数组必须在索引为 1 的维度上是 ragged 的。请注意,如果提供了权重,则每个权重张量必须与该特定特征的输入属于同一类,并且对于 ragged 数组使用完全相同的 ragged 行长度。DistributedEmbedding 的所有输出都是密集张量。

预处理

在 JAX 中,SparseCore 的使用需要特殊格式的数据,该数据取决于可用硬件的属性。此数据重格式化目前不支持即时编译 (jit-compilation),因此必须在将数据传递给模型之前进行应用。

预处理适用于密集或 ragged 的 NumPy 数组,或者可转换为密集或 ragged NumPy 数组的张量,例如 tf.RaggedTensor

添加预处理的一种简单方法是通过使用 Python 生成器将其附加到输入管道。

# Create the embedding layer.
embedding_layer = DistributedEmbedding(feature_configs)

# Add preprocessing to a data input pipeline.
def preprocessed_dataset_generator(dataset):
    for (inputs, weights), labels in iter(dataset):
        yield embedding_layer.preprocess(
            inputs, weights, training=True
        ), labels

preprocessed_train_dataset = preprocessed_dataset_generator(train_dataset)

此显式的预处理阶段会合并输入和可选的权重,因此新数据可以直接传递到层或模型的 inputs 参数。

注意:在多主机环境中进行数据并行处理时,需要将数据正确地分片到各个主机。如果原始数据集是 tf.data.Dataset 类型,则需要在应用预处理生成器之前手动对其进行分片。

# Manually shard the dataset across hosts.
train_dataset = distribution.distribute_dataset(train_dataset)
distribution.auto_shard_dataset = False  # Dataset is already sharded.

# Add a preprocessing stage to the distributed data input pipeline.
train_dataset = preprocessed_dataset_generator(train_dataset)

如果原始数据集不是 tf.data.Dataset,则它必须已经跨主机进行了预分片。

在 Keras 模型中的用法

一旦设置了全局分片并定义了输入预处理管道,模型训练就可以正常进行。例如:

# Construct, compile, and fit the model using the preprocessed data.
model = keras.Sequential(
  [
    embedding_layer,
    keras.layers.Dense(2),
    keras.layers.Dense(3),
    keras.layers.Dense(4),
  ]
)
model.compile(optimizer="adam", loss="mse", jit_compile=True)
model.fit(preprocessed_train_dataset, epochs=10)

直接调用

DistributedEmbedding 层也可以直接调用。使用即时编译时需要显式预处理。

# Call the layer directly.
activations = embedding_layer(my_inputs, my_weights)

# Call the layer with JIT compilation and explicitly preprocessed inputs.
embedding_layer_jit = jax.jit(embedding_layer)
preprocessed_inputs = embedding_layer.preprocess(my_inputs, my_weights)
activations = embedding_layer_jit(preprocessed_inputs)

同样,对于自定义训练循环,必须在将数据传递到即时编译的训练步骤之前应用预处理。

# Create an optimizer and loss function.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)

def loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss = keras.losses.mean_squared_error(y, y_pred)
    return loss, non_trainable_variables

grad_fn = jax.value_and_grad(loss_and_updates, has_aux=True)

# Create a JIT-compiled training step.
@jax.jit
def train_step(state, x, y):
    (
      trainable_variables,
      non_trainable_variables,
      optimizer_variables,
    ) = state
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

# Build optimizer variables.
optimizer.build(model.trainable_variables)

# Assemble the training state.
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables

# Training loop.
for (inputs, weights), labels in train_dataset:
    # Explicitly preprocess the data.
    preprocessed_inputs = embedding_layer.preprocess(inputs, weights)
    loss, state = train_step(state, preprocessed_inputs, labels)

参数

  • feature_configs:一个嵌套结构的 keras_rs.layers.FeatureConfig
  • table_stacking:要使用的表堆叠。None 表示不进行表堆叠。"auto" 表示自动堆叠表。表名列表或表名列表的列表表示将内部列表中的表堆叠在一起。请注意,较旧的 TPU 不支持表堆叠,在这种情况下,默认值 "auto" 将被解释为不进行表堆叠。
  • **kwargs:传递给层基类的其他参数。

[源代码]

call 方法

DistributedEmbedding.call(
    inputs: Union[
        Any,
        Sequence[Union[Any, ForwardRef("Nested[T]")]],
        Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
    ],
    weights: Union[
        Any,
        Sequence[Union[Any, ForwardRef("Nested[T]")]],
        Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
        NoneType,
    ] = None,
    training: bool = False,
)

在嵌入表中查找特征并应用缩减。

参数

  • inputs:一个嵌套结构的 2D 张量,用于嵌入和缩减。结构必须与构造期间传递的 feature_configs 相同。或者,它可能包含已预处理的输入(请参阅 preprocess)。
  • weights:一个可选的嵌套结构的 2D 张量,用于在缩减之前应用权重。存在时,结构必须与 inputs 相同,并且形状必须匹配。
  • training:我们是正在训练还是评估模型。

返回

一个嵌套结构的密集 2D 张量,这些张量是来自传入特征的缩减后的嵌入。结构与 inputs 相同。


[源代码]

preprocess 方法

DistributedEmbedding.preprocess(
    inputs: Union[
        Any,
        Sequence[Union[Any, ForwardRef("Nested[T]")]],
        Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
    ],
    weights: Union[
        Any,
        Sequence[Union[Any, ForwardRef("Nested[T]")]],
        Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
        NoneType,
    ] = None,
    training: bool = False,
)

预处理和重构数据以供模型使用。

对于 JAX 后端,将输入数据转换为与硬件相关的格式,这是与 SparseCores 一起使用所必需的。仅当要启用 jit_compile = True 时,才需要显式调用 preprocess

对于非 JAX 后端,预处理会将输入和权重捆绑在一起,并按设备放置分离输入。此步骤是完全可选的。

参数

  • inputs:Ragged 或密集样本 ID 集合。
  • weights:可选的 ragged 或密集样本权重集合。
  • training:如果为 true,将更新内部参数,例如预处理数据所需的缓冲区大小。

返回

一组预处理后的输入,可以直接馈送到层的 inputs 参数。