Authors: Fabien Hertschuh, Abheesht Sharma
Date created: 2025/09/02
Last modified: 2025/09/02
Description: Rank movies using a two-tower model with SparseCore embeddings.
In the basic ranking tutorial, we showed how to build a ranking model for the MovieLens dataset to suggest movies to users.
This tutorial trains the same model on the same dataset, but uses keras_rs.layers.DistributedEmbedding, which leverages the SparseCore chips on TPU. This is the TensorFlow version of the tutorial; it needs to be run on a TPU v5p or v6e.
Let's begin by installing the necessary libraries. Note that we need tensorflow-tpu version 2.19. We'll also install keras-rs.
!pip install -U -q tensorflow-tpu==2.19.1
!pip install -q keras-rs
We're using the PJRT version of the TensorFlow runtime and enabling the MLIR bridge. This requires setting a few flags before importing TensorFlow.
import os
import libtpu
os.environ["PJRT_DEVICE"] = "TPU"
os.environ["NEXT_PLUGGABLE_DEVICE_USE_C_API"] = "true"
os.environ["TF_PLUGGABLE_DEVICE_LIBRARY_PATH"] = libtpu.get_library_path()
os.environ["TF_XLA_FLAGS"] = (
"--tf_mlir_enable_mlir_bridge=true "
"--tf_mlir_enable_convert_control_to_data_outputs_pass=true "
"--tf_mlir_enable_merge_control_flow_pass=true"
)
import tensorflow as tf
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1756859550.283774 25050 pjrt_api.cc:78] PJRT_Api is set for device type tpu
I0000 00:00:1756859550.283806 25050 pjrt_api.cc:145] The PJRT plugin has PJRT API version 0.67. The framework PJRT API version is 0.67.
We now set the Keras backend to TensorFlow and import the necessary libraries.
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import keras_rs
import tensorflow_datasets as tfds
TPUStrategy
To run TensorFlow on TPUs, you need to use a tf.distribute.TPUStrategy to handle the distribution of the model.
The core of the model is replicated across TPU instances via the TPUStrategy. Note that on GPUs you would use tf.distribute.MirroredStrategy instead, but that strategy does not work on TPUs.
Only the embedding tables handled by DistributedEmbedding are sharded across the SparseCore chips of all the available TPUs.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
tpu_metadata = resolver.get_tpu_system_metadata()
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, num_replicas=tpu_metadata.num_cores
)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment
)
INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.
INFO:tensorflow:Initializing the TPU system: local
I0000 00:00:1756859550.669602 25050 next_pluggable_device_factory.cc:128] Created 1 TensorFlow NextPluggableDevices. Physical device type: TPU
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 1
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 1
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 1
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 1
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
While the model is replicated and the embedding tables are sharded across the SparseCores, the dataset is distributed by sharding each batch across the TPUs. We need to make sure the batch size is a multiple of the number of TPUs.
PER_REPLICA_BATCH_SIZE = 256
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync
We're using the same MovieLens data. The ratings are the objective we are trying to predict.
# Ratings data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")
We need to know the number of users, as we're using the user IDs directly as indices into the user embedding table.
users_count = int(
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)
We also need to know the number of movies, as we're using the movie IDs directly as indices into the movie embedding table.
movies_count = int(movies.cardinality().numpy())
The inputs to the model are the user IDs and movie IDs; the labels are the ratings.
def preprocess_rating(x):
    return (
        # Inputs are user IDs and movie IDs.
        {
            "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
        },
        # Labels are ratings between 0 and 1.
        (x["user_rating"] - 1.0) / 4.0,
    )
We'll split the data by putting 80% of the ratings in the train set and 20% in the test set.
shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = (
    shuffled_ratings.take(80_000).batch(BATCH_SIZE, drop_remainder=True).cache()
)
test_ratings = (
    shuffled_ratings.skip(80_000)
    .take(20_000)
    .batch(BATCH_SIZE, drop_remainder=True)
    .cache()
)
keras_rs.layers.DistributedEmbedding handles multiple features and multiple embedding tables. This enables sharing tables between features and allows multiple embedding lookups to be combined into a single call for better performance. In this section, we'll describe how to configure these.
Tables are configured with keras_rs.layers.TableConfig, which has:
- A name.
- A vocabulary size.
- An embedding dimension.
- An optimizer. Note that the embedding tables use their own optimizer; the optimizer passed to model.compile() is not used for them.
- A placement, which tells whether or not to put the table on the SparseCore chips.
Features are configured with keras_rs.layers.FeatureConfig, which has:
- A name.
- A table, the embedding table to use.
- An input shape (the batch size is the global batch size across all TPUs).
- An output shape (the batch size is the global batch size across all TPUs).
We can then organize the features however we want, including in nested structures. A dict is often a good way to name the inputs and outputs (a purely illustrative nested variant is sketched after the configuration code below).
EMBEDDING_DIMENSION = 32
movie_table = keras_rs.layers.TableConfig(
    name="movie_table",
    vocabulary_size=movies_count + 1,  # +1 for movie ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)
user_table = keras_rs.layers.TableConfig(
    name="user_table",
    vocabulary_size=users_count + 1,  # +1 for user ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)

FEATURE_CONFIGS = {
    "movie_id": keras_rs.layers.FeatureConfig(
        name="movie",
        table=movie_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
    "user_id": keras_rs.layers.FeatureConfig(
        name="user",
        table=user_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
}
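The tutorial only needs the flat dictionary above, but since features can also be organized in nested structures, here is a purely illustrative sketch of what such a configuration could look like. The "context" / "item" grouping and the NESTED_FEATURE_CONFIGS name are assumptions for illustration only and are not used anywhere else in this tutorial.
# Illustrative only: features grouped in a nested dict. The embeddings
# returned by DistributedEmbedding would mirror this exact structure.
NESTED_FEATURE_CONFIGS = {
    "context": {
        "user_id": keras_rs.layers.FeatureConfig(
            name="user",
            table=user_table,
            input_shape=(BATCH_SIZE,),
            output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
        ),
    },
    "item": {
        "movie_id": keras_rs.layers.FeatureConfig(
            name="movie",
            table=movie_table,
            input_shape=(BATCH_SIZE,),
            output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
        ),
    },
}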
We can now create a DistributedEmbedding and put it inside a model. Once we have the configuration, we simply pass it to the constructor of DistributedEmbedding. Then, within the model's call method, DistributedEmbedding is the first layer we call.
The outputs have the exact same structure as the inputs. In our example, we concatenate the embeddings we got as outputs and run them through a tower of dense layers.
class EmbeddingModel(keras.Model):
    """Create the model with the embedding configuration.

    Args:
        feature_configs: the configuration for `DistributedEmbedding`.
    """

    def __init__(self, feature_configs):
        super().__init__()
        self.embedding_layer = keras_rs.layers.DistributedEmbedding(
            feature_configs=feature_configs
        )
        self.ratings = keras.Sequential(
            [
                # Learn multiple dense layers.
                keras.layers.Dense(256, activation="relu"),
                keras.layers.Dense(64, activation="relu"),
                # Make rating predictions in the final layer.
                keras.layers.Dense(1),
            ]
        )

    def call(self, features):
        # Embedding lookup. Outputs have the same structure as the inputs.
        embedding = self.embedding_layer(features)
        return self.ratings(
            keras.ops.concatenate(
                [embedding["user_id"], embedding["movie_id"]],
                axis=1,
            )
        )
Let's now instantiate the model. We then use model.compile() to configure the loss, metrics, and optimizer. Again, this Adagrad optimizer is only applied to the dense layers, not to the embedding tables.
with strategy.scope():
    model = EmbeddingModel(FEATURE_CONFIGS)
    model.compile(
        loss=keras.losses.MeanSquaredError(),
        metrics=[keras.metrics.RootMeanSquaredError()],
        optimizer="adagrad",
    )
/home/fhertschuh/venv-tf219/lib/python3.10/site-packages/tensorflow/python/tpu/tpu_embedding_v3.py:406: UserWarning: MessageFactory class is deprecated. Please use GetMessageClass() instead of MessageFactory.GetPrototype. MessageFactory class will be removed after 2024.
for layout in stacker.GetLayouts().tables:
WARNING:absl:Table movie_table_user_table is not found in max_ids_per_table provided by SparseCoreEmbeddingConfig. Using default value 256.
WARNING:absl:Table movie_table_user_table is not found in max_unique_ids_per_table provided by SparseCoreEmbeddingConfig. Using default value 256.
I0000 00:00:1756859553.965810 25050 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
We can train the model with the standard Keras model.fit(). Keras automatically uses the TPUStrategy to distribute the model and the data.
with strategy.scope():
    model.fit(train_ratings, epochs=5)
W0000 00:00:1756859554.064899 25050 auto_shard.cc:558] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
WARNING:absl:Outside compilation attempted outside TPUReplicateContext scope. As no enclosing TPUReplicateContext can be found, returning the result of `computation` as is.
WARNING:absl:Outside compilation attempted outside TPUReplicateContext scope. As no enclosing TPUReplicateContext can be found, returning the result of `computation` as is.
Epoch 1/5
312/312 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/step - loss: 0.1068 - root_mean_squared_error: 0.3267
Epoch 2/5
312/312 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0560 - root_mean_squared_error: 0.2367
Epoch 3/5
312/312 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0523 - root_mean_squared_error: 0.2287
Epoch 4/5
312/312 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0504 - root_mean_squared_error: 0.2245
Epoch 5/5
312/312 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0491 - root_mean_squared_error: 0.2216
The same goes for model.evaluate().
with strategy.scope():
    model.evaluate(test_ratings, return_dict=True)
W0000 00:00:1756859563.290916 25050 auto_shard.cc:558] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
78/78 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0572 - root_mean_squared_error: 0.2391
And that's it.
This example shows that once the TPUStrategy is set up and the DistributedEmbedding is configured, you can use the standard Keras workflows.
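As a quick illustration of that workflow, here is a minimal inference sketch. It assumes the strategy, model, and test_ratings objects defined above are still available; the rescaling simply reverses the preprocessing applied to the ratings earlier.
with strategy.scope():
    # Predict normalized ratings for a single batch of the test set.
    predictions = model.predict(test_ratings.take(1))

# Rescale from the [0, 1] training range back to the original 1-5 star scale.
print(predictions[:5] * 4.0 + 1.0)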