KerasRS / 示例 / 使用 TPU SparseCore 和 TensorFlow 进行分布式嵌入

使用 TPU SparseCore 和 TensorFlow 进行分布式嵌入

作者: Fabien Hertschuh, Abheesht Sharma
创建日期 2025/09/02
最后修改日期 2025/09/02
描述:使用带有 SparseCore 嵌入的二维模型对电影进行排序。

在 Colab 中查看 GitHub 源代码


简介

基础排序 教程中,我们展示了如何为 MovieLens 数据集构建一个排序模型,以便向用户推荐电影。

本教程使用相同的模型和数据集进行训练,但使用了 keras_rs.layers.DistributedEmbedding,它利用了 TPU 上的 SparseCore。这是本教程的 TensorFlow 版本。它需要在 TPU v5p 或 v6e 上运行。

让我们开始安装必要的库。请注意,我们需要 tensorflow-tpu 版本 2.19。我们还将安装 keras-rs

!pip install -U -q tensorflow-tpu==2.19.1
!pip install -q keras-rs

我们正在使用 TensorFlow 的 PJRT 版本运行时。我们还启用了 MLIR 桥接。这需要在导入 TensorFlow 之前设置一些标志。

import os
import libtpu

os.environ["PJRT_DEVICE"] = "TPU"
os.environ["NEXT_PLUGGABLE_DEVICE_USE_C_API"] = "true"
os.environ["TF_PLUGGABLE_DEVICE_LIBRARY_PATH"] = libtpu.get_library_path()
os.environ["TF_XLA_FLAGS"] = (
    "--tf_mlir_enable_mlir_bridge=true "
    "--tf_mlir_enable_convert_control_to_data_outputs_pass=true "
    "--tf_mlir_enable_merge_control_flow_pass=true"
)

import tensorflow as tf
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1756859550.283774   25050 pjrt_api.cc:78] PJRT_Api is set for device type tpu
I0000 00:00:1756859550.283806   25050 pjrt_api.cc:145] The PJRT plugin has PJRT API version 0.67. The framework PJRT API version is 0.67.

现在我们将 Keras 后端设置为 TensorFlow 并导入必要的库。

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import keras_rs
import tensorflow_datasets as tfds

创建 TPUStrategy

要在 TPU 上运行 TensorFlow,您需要使用 tf.distribute.TPUStrategy 来处理模型的分布式。

模型的核心通过 TPUStrategy 复制到 TPU 实例。请注意,在 GPU 上,您会使用 tf.distribute.MirroredStrategy,但此策略不适用于 TPU。

只有由 DistributedEmbedding 处理的嵌入表才会在所有可用 TPU 的 SparseCore 芯片之间分片。

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
tpu_metadata = resolver.get_tpu_system_metadata()

device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, num_replicas=tpu_metadata.num_cores
)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment
)
INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.

INFO:tensorflow:Initializing the TPU system: local

I0000 00:00:1756859550.669602   25050 next_pluggable_device_factory.cc:128] Created 1 TensorFlow NextPluggableDevices. Physical device type: TPU

INFO:tensorflow:Finished initializing TPU system.

INFO:tensorflow:Found TPU system:

INFO:tensorflow:*** Num TPU Cores: 1

INFO:tensorflow:*** Num TPU Workers: 1

INFO:tensorflow:*** Num TPU Cores Per Worker: 1

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:TPU:0, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)

INFO:tensorflow:Found TPU system:

INFO:tensorflow:*** Num TPU Cores: 1

INFO:tensorflow:*** Num TPU Workers: 1

INFO:tensorflow:*** Num TPU Cores Per Worker: 1

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:TPU:0, TPU, 0, 0)

INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)

数据集分布

虽然模型被复制并且嵌入表在 SparseCore 之间分片,但数据集是通过将每个批次分片到 TPU 来分发的。我们需要确保批次大小是 TPU 数量的倍数。

PER_REPLICA_BATCH_SIZE = 256
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync

准备数据集

我们将使用相同的 MovieLens 数据。评分是我们试图预测的目标。

# Ratings data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

我们需要知道用户的数量,因为我们将用户 ID 直接用作用户嵌入表的索引。

users_count = int(
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)

我们还需要知道电影的数量,因为我们将电影 ID 直接用作电影嵌入表的索引。

movies_count = int(movies.cardinality().numpy())

模型的输入是用户 ID 和电影 ID,标签是评分。

def preprocess_rating(x):
    return (
        # Inputs are user IDs and movie IDs
        {
            "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
        },
        # Labels are ratings between 0 and 1.
        (x["user_rating"] - 1.0) / 4.0,
    )

我们将通过将 80% 的评分放入训练集,20% 放入测试集来划分数据。

shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = (
    shuffled_ratings.take(80_000).batch(BATCH_SIZE, drop_remainder=True).cache()
)
test_ratings = (
    shuffled_ratings.skip(80_000)
    .take(20_000)
    .batch(BATCH_SIZE, drop_remainder=True)
    .cache()
)

配置分布式嵌入

keras_rs.layers.DistributedEmbedding 处理多个特征和多个嵌入表。这是为了实现特征之间的表共享,并允许将多个嵌入查找合并到一次调用中以获得一些优化。在本节中,我们将介绍如何配置这些。

配置表

表使用 keras_rs.layers.TableConfig 进行配置,该类包含:

  • 一个名称。
  • 一个词汇表大小(输入大小)。
  • 一个嵌入维度(输出大小)。
  • 一个组合器,用于指定在嵌入序列时如何将多个嵌入合并为一个。请注意,这不适用于我们的示例,因为我们为每个用户和每部电影获取单个嵌入。
  • 一个放置指示符,用于指定是否将表放在 SparseCore 芯片上。在这种情况下,我们想要“sparsecore”放置。
  • 一个优化器,用于指定训练时如何应用梯度。每个表都有自己的优化器,并且传递给 model.compile() 的优化器不用于嵌入表。

配置特征

特征使用 keras_rs.layers.FeatureConfig 进行配置,该类包含:

  • 一个名称。
  • 一个表,即要使用的嵌入表。
  • 一个输入形状(批次大小适用于所有 TPU)。
  • 一个输出形状(批次大小适用于所有 TPU)。

我们可以根据需要组织特征,它们可以是嵌套的。字典通常是为输入和输出命名的一种好方法。

EMBEDDING_DIMENSION = 32

movie_table = keras_rs.layers.TableConfig(
    name="movie_table",
    vocabulary_size=movies_count + 1,  # +1 for movie ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)
user_table = keras_rs.layers.TableConfig(
    name="user_table",
    vocabulary_size=users_count + 1,  # +1 for user ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)

FEATURE_CONFIGS = {
    "movie_id": keras_rs.layers.FeatureConfig(
        name="movie",
        table=movie_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
    "user_id": keras_rs.layers.FeatureConfig(
        name="user",
        table=user_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
}

定义模型

现在我们可以创建一个 DistributedEmbedding 并将其放入模型中。一旦有了配置,我们只需将其传递给 DistributedEmbedding 的构造函数。然后在模型的 call 方法中,DistributedEmbedding 是我们调用的第一个层。

输出的结构与输入完全相同。在我们的示例中,我们连接了获得的嵌入作为输出,并通过一连串的密集层运行它们。

class EmbeddingModel(keras.Model):
    """Create the model with the embedding configuration.

    Args:
        feature_configs: the configuration for `DistributedEmbedding`.
    """

    def __init__(self, feature_configs):
        super().__init__()

        self.embedding_layer = keras_rs.layers.DistributedEmbedding(
            feature_configs=feature_configs
        )
        self.ratings = keras.Sequential(
            [
                # Learn multiple dense layers.
                keras.layers.Dense(256, activation="relu"),
                keras.layers.Dense(64, activation="relu"),
                # Make rating predictions in the final layer.
                keras.layers.Dense(1),
            ]
        )

    def call(self, features):
        # Embedding lookup. Outputs have the same structure as the inputs.
        embedding = self.embedding_layer(features)
        return self.ratings(
            keras.ops.concatenate(
                [embedding["user_id"], embedding["movie_id"]],
                axis=1,
            )
        )

现在让我们实例化模型。然后我们使用 model.compile() 来配置损失、指标和优化器。同样,这个 Adagrad 优化器将仅应用于密集层,而不应用于嵌入表。

with strategy.scope():
    model = EmbeddingModel(FEATURE_CONFIGS)

    model.compile(
        loss=keras.losses.MeanSquaredError(),
        metrics=[keras.metrics.RootMeanSquaredError()],
        optimizer="adagrad",
    )
/home/fhertschuh/venv-tf219/lib/python3.10/site-packages/tensorflow/python/tpu/tpu_embedding_v3.py:406: UserWarning: MessageFactory class is deprecated. Please use GetMessageClass() instead of MessageFactory.GetPrototype. MessageFactory class will be removed after 2024.
  for layout in stacker.GetLayouts().tables:
WARNING:absl:Table movie_table_user_table is not found in max_ids_per_table provided by SparseCoreEmbeddingConfig. Using default value 256.

WARNING:absl:Table movie_table_user_table is not found in max_unique_ids_per_table provided by SparseCoreEmbeddingConfig. Using default value 256.

I0000 00:00:1756859553.965810   25050 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.

拟合并评估

我们可以使用标准的 Keras model.fit() 来训练模型。Keras 将自动使用 TPUStrategy 来分发模型和数据。

with strategy.scope():
    model.fit(train_ratings, epochs=5)
W0000 00:00:1756859554.064899   25050 auto_shard.cc:558] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.

WARNING:absl:Outside compilation attempted outside TPUReplicateContext scope. As no enclosing TPUReplicateContext can be found, returning the result of `computation` as is.

WARNING:absl:Outside compilation attempted outside TPUReplicateContext scope. As no enclosing TPUReplicateContext can be found, returning the result of `computation` as is.

Epoch 1/5

312/312 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/step - loss: 0.1068 - root_mean_squared_error: 0.3267

Epoch 2/5

312/312 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0560 - root_mean_squared_error: 0.2367

Epoch 3/5

312/312 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0523 - root_mean_squared_error: 0.2287

Epoch 4/5

312/312 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0504 - root_mean_squared_error: 0.2245

Epoch 5/5

312/312 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0491 - root_mean_squared_error: 0.2216

model.evaluate() 也一样。

with strategy.scope():
    model.evaluate(test_ratings, return_dict=True)
W0000 00:00:1756859563.290916   25050 auto_shard.cc:558] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.

78/78 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0572 - root_mean_squared_error: 0.2391

就是这样。

本示例表明,在设置好 TPUStrategy 并配置好 DistributedEmbedding 后,您就可以使用标准的 Keras 工作流程了。