Retrieval with data parallel training

Author: Abheesht Sharma, Fabien Hertschuh
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Retrieve movies using a two-tower model (data parallel training).


Introduction

In this tutorial, we train the same retrieval model as in the basic retrieval tutorial, but in a distributed way.

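Distributed training trains a model on multiple devices or machines simultaneously, which reduces training time. Here, we focus on synchronous data parallel training. Each accelerator (GPU/TPU) holds a complete replica of the model and sees a different mini-batch of the input data. Local gradients are computed on each device and then aggregated into a single global gradient update that is applied to every replica.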
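The aggregation step described above can be illustrated on a single device. The following is a minimal sketch, not the mechanism Keras uses internally: it assumes a toy linear model, a hypothetical `num_replicas` of 4, and equal-sized shards, and checks that averaging the per-shard gradients reproduces the gradient of the full batch.

```python
import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    # Mean squared error of a toy linear model (illustrative only).
    return jnp.mean((x @ w - y) ** 2)


num_replicas = 4  # Hypothetical number of devices, simulated on one device.
kx, ky, kw = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (32, 3))
y = jax.random.normal(ky, (32,))
w = jax.random.normal(kw, (3,))

# Each simulated replica sees a different, equal-sized slice of the batch.
x_shards = jnp.split(x, num_replicas)
y_shards = jnp.split(y, num_replicas)

# Local gradients, one per replica, all computed with the same weights.
local_grads = [jax.grad(loss_fn)(w, xs, ys) for xs, ys in zip(x_shards, y_shards)]

# Aggregation: average the local gradients into one global gradient.
global_grad = jnp.mean(jnp.stack(local_grads), axis=0)

# Reference: gradient computed on the whole batch at once.
full_batch_grad = jax.grad(loss_fn)(w, x, y)
print(jnp.allclose(global_grad, full_batch_grad, atol=1e-5))
```

The equality holds here because the loss is a mean over examples and the shards have the same size; with unequal shards a weighted average would be needed.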

Before we begin, note two things:

  1. The number of accelerators should be greater than 1.
  2. The keras.distribution API works only with JAX, so make sure you select JAX as your backend!
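Both prerequisites can be checked up front. The helper below is a hypothetical convenience, not part of Keras or KerasRS. It assumes the backend is selected through the `KERAS_BACKEND` environment variable, so it will report a false alarm if the backend is configured in the Keras config file instead.

```python
import os

import jax


def find_setup_problems():
    """Return a list of human-readable problems with the current setup."""
    problems = []
    # Only the environment variable is inspected here (an assumption of
    # this sketch); Keras can also read the backend from its config file.
    if os.environ.get("KERAS_BACKEND") != "jax":
        problems.append("KERAS_BACKEND is not set to 'jax'.")
    # Data parallelism only pays off with more than one device.
    if len(jax.devices()) < 2:
        problems.append(f"Only {len(jax.devices())} device(s) visible to JAX.")
    return problems


for problem in find_setup_problems():
    print("WARNING:", problem)
```

For CPU-only experiments, JAX can expose several virtual host devices if `XLA_FLAGS=--xla_force_host_platform_device_count=8` is set before `jax` is first imported.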
!pip install -q keras-rs
import os

os.environ["KERAS_BACKEND"] = "jax"

import random

import jax
import keras
import tensorflow as tf  # Needed only for the dataset
import tensorflow_datasets as tfds

import keras_rs

Data Parallel

For the synchronous data parallel strategy, we use the DataParallel class from the keras.distribution API.

devices = jax.devices()  # Assume it has >1 local devices.
data_parallel = keras.distribution.DataParallel(devices=devices)

Alternatively, you can create the DataParallel object from a one-dimensional DeviceMesh object, as follows:

mesh_1d = keras.distribution.DeviceMesh(
    shape=(len(devices),), axis_names=["data"], devices=devices
)
data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)
# Set the global distribution strategy.
keras.distribution.set_distribution(data_parallel)
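Conceptually, a data parallel layout splits each batch along its first axis and keeps a full copy of every weight on every device. The sketch below expresses that layout directly with `jax.sharding`. It is an illustration of the idea under that assumption, not the code path Keras runs, and the array shapes are made up.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = jax.devices()
# A one-dimensional mesh with a single "data" axis, like mesh_1d above.
mesh = Mesh(np.array(devices), ("data",))

# Batches are split along axis 0; weights are replicated on every device.
batch_sharding = NamedSharding(mesh, PartitionSpec("data"))
replicated_sharding = NamedSharding(mesh, PartitionSpec())

per_device = 4  # Illustrative per-device batch size.
global_batch = per_device * len(devices)
batch = jnp.arange(global_batch * 2).reshape(global_batch, 2)
weights = jnp.ones((2, 3))

sharded_batch = jax.device_put(batch, batch_sharding)
replicated_weights = jax.device_put(weights, replicated_sharding)

# Every device holds a (per_device, 2) slice of the batch ...
for shard in sharded_batch.addressable_shards:
    print(shard.device, shard.data.shape)
# ... and a complete (2, 3) copy of the weights.
for shard in replicated_weights.addressable_shards:
    print(shard.device, shard.data.shape)
```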

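Preparing the dataset

Now that we have defined the global distribution strategy, the rest of the guide is exactly the same as the basic retrieval guide.

Let's load and prepare the dataset. Here too, we use the MovieLens dataset.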

# Ratings data with user and movie data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

# User, movie counts for defining vocabularies.
users_count = (
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)
movies_count = movies.cardinality().numpy()


# Preprocess dataset, and split it into train-test datasets.
def preprocess_rating(x):
    return (
        # Input is the user IDs
        tf.strings.to_number(x["user_id"], out_type=tf.int32),
        # Labels are movie IDs + ratings between 0 and 1.
        {
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
            "rating": (x["user_rating"] - 1.0) / 4.0,
        },
    )


shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()
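The batch size of 1000 used above is the global batch size, which the data parallel strategy divides among the devices. A small, hypothetical helper (not part of Keras) can make the per-device share explicit. It assumes the batch is split evenly; whether an uneven batch is rejected or handled some other way depends on the library version, so the check is a precaution.

```python
def per_device_batch_size(global_batch_size, num_devices):
    """Return how many examples each device receives from one global batch."""
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1.")
    if global_batch_size % num_devices != 0:
        raise ValueError(
            f"Global batch size {global_batch_size} is not divisible "
            f"by {num_devices} devices."
        )
    return global_batch_size // num_devices


# Example with a hypothetical 8-device host and the batch size used above.
print(per_device_batch_size(1000, 8))  # 125
```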

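Implementing the Model

We build a two-tower retrieval model, so we need to combine a query tower for users with a candidate tower for movies. Note that nothing has to change here; the model is the same as in the basic retrieval tutorial.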

class RetrievalModel(keras.Model):
    """Create the retrieval model with the provided parameters.

    Args:
      num_users: Number of entries in the user embedding table.
      num_candidates: Number of entries in the candidate embedding table.
      embedding_dimension: Output dimension for user and movie embedding tables.
    """

    def __init__(
        self,
        num_users,
        num_candidates,
        embedding_dimension=32,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Our query tower, simply an embedding table.
        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
        # Our candidate tower, simply an embedding table.
        self.candidate_embedding = keras.layers.Embedding(
            num_candidates, embedding_dimension
        )
        # The layer that performs the retrieval.
        self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False)
        self.loss_fn = keras.losses.MeanSquaredError()

    def build(self, input_shape):
        self.user_embedding.build(input_shape)
        self.candidate_embedding.build(input_shape)
        # In this case, the candidates are directly the movie embeddings.
        # We take a shortcut and directly reuse the variable.
        self.retrieval.candidate_embeddings = self.candidate_embedding.embeddings
        self.retrieval.build(input_shape)
        super().build(input_shape)

    def call(self, inputs, training=False):
        user_embeddings = self.user_embedding(inputs)
        result = {
            "user_embeddings": user_embeddings,
        }
        if not training:
            # Skip the retrieval of top movies during training as the
            # predictions are not used.
            result["predictions"] = self.retrieval(user_embeddings)
        return result

    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
        candidate_id, rating = y["movie_id"], y["rating"]
        user_embeddings = y_pred["user_embeddings"]
        candidate_embeddings = self.candidate_embedding(candidate_id)

        labels = keras.ops.expand_dims(rating, -1)
        # Compute the affinity score by multiplying the two embeddings.
        scores = keras.ops.sum(
            keras.ops.multiply(user_embeddings, candidate_embeddings),
            axis=1,
            keepdims=True,
        )
        return self.loss_fn(labels, scores, sample_weight)
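The two ways the model uses its embeddings can be sketched in plain JAX. This is an illustration with made-up table sizes and random embeddings, not the KerasRS implementation: during training a user is scored against a single labeled movie, while at inference a user is scored against every candidate and the `k` highest scores are kept, which is the idea behind brute-force retrieval.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, chosen only for this illustration.
num_users, num_movies, dim, k = 5, 20, 8, 3
ku, kc = jax.random.split(jax.random.PRNGKey(0))
user_table = jax.random.normal(ku, (num_users, dim))
candidate_table = jax.random.normal(kc, (num_movies, dim))

user_ids = jnp.array([1, 4])
movie_ids = jnp.array([7, 2])

# Embedding lookup is row indexing into each table.
user_emb = user_table[user_ids]  # (2, dim)
movie_emb = candidate_table[movie_ids]  # (2, dim)

# Training: one affinity score per (user, movie) pair, as in compute_loss.
pair_scores = jnp.sum(user_emb * movie_emb, axis=1, keepdims=True)  # (2, 1)

# Inference: score each user against all candidates and keep the top k.
all_scores = user_emb @ candidate_table.T  # (2, num_movies)
top_scores, top_ids = jax.lax.top_k(all_scores, k)  # each (2, k)
print(top_ids)
```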

Fitting and evaluating

After defining the model, we can use the standard Keras model.fit() to train and evaluate it.

model = RetrievalModel(users_count + 1, movies_count + 1)
model.compile(optimizer=keras.optimizers.Adagrad(learning_rate=0.2))

Let's train the model. Evaluation takes a bit of time, so we only evaluate the model every 5 epochs.

history = model.fit(
    train_ratings, validation_data=test_ratings, validation_freq=5, epochs=50
)
Epoch 1/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 3s 8ms/step - loss: 0.4772

Epoch 2/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4771

Epoch 3/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4770

Epoch 4/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4769

Epoch 5/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 3s 37ms/step - loss: 0.4769 - val_loss: 0.4836

Epoch 6/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: 0.4768

Epoch 7/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4767

Epoch 8/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4766

Epoch 9/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: 0.4764

Epoch 10/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4763 - val_loss: 0.4833

Epoch 11/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4761

Epoch 12/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4759

Epoch 13/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4757

Epoch 14/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4754

Epoch 15/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4750 - val_loss: 0.4821

Epoch 16/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: 0.4746

Epoch 17/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - loss: 0.4740

Epoch 18/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4734

Epoch 19/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4725

Epoch 20/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4715 - val_loss: 0.4784

Epoch 21/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4702

Epoch 22/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4686

Epoch 23/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4666

Epoch 24/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4641

Epoch 25/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4609 - val_loss: 0.4664

Epoch 26/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4571

Epoch 27/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4524

Epoch 28/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4466

Epoch 29/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4395

Epoch 30/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4311 - val_loss: 0.4326

Epoch 31/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4210

Epoch 32/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4093

Epoch 33/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.3957

Epoch 34/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.3805

Epoch 35/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.3636 - val_loss: 0.3597

Epoch 36/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.3455

Epoch 37/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.3265

Epoch 38/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.3072

Epoch 39/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2880

Epoch 40/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.2696 - val_loss: 0.2664

Epoch 41/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.2523

Epoch 42/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2363

Epoch 43/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.2218

Epoch 44/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.2087

Epoch 45/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.1970 - val_loss: 0.1986

Epoch 46/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.1866

Epoch 47/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.1773

Epoch 48/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1689

Epoch 49/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.1613

Epoch 50/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.1544 - val_loss: 0.1586
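The shape of the log above follows from the dataset sizes and the arguments passed to fit(). As a quick arithmetic check:

```python
train_examples = 80_000
batch_size = 1000
epochs = 50
validation_freq = 5

# Number of steps shown per epoch in the progress bar.
steps_per_epoch = train_examples // batch_size

# Epochs at the end of which validation runs and val_loss is reported.
validated_epochs = [e for e in range(1, epochs + 1) if e % validation_freq == 0]

print(steps_per_epoch)  # 80
print(validated_epochs)  # [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
```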

Making predictions

Now that we have a model, let's run inference and make predictions.

movie_id_to_movie_title = {
    int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
}
movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.

We then simply use the Keras model.predict() method. Under the hood, it calls the BruteForceRetrieval layer to perform the actual retrieval.

# Sample only user IDs that exist in the embedding table (1..users_count).
user_ids = random.sample(range(1, int(users_count) + 1), len(devices))
predictions = model.predict(keras.ops.convert_to_tensor(user_ids))
predictions = keras.ops.convert_to_numpy(predictions["predictions"])

for i, user_id in enumerate(user_ids):
    print(f"\n==Recommended movies for user {user_id}==")
    for movie_id in predictions[i]:
        print(movie_id_to_movie_title[movie_id])
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 211ms/step

==Recommended movies for user 793==
b'Star Wars (1977)'
b'Godfather, The (1972)'
b'Raiders of the Lost Ark (1981)'
b'Fargo (1996)'
b'Silence of the Lambs, The (1991)'
b"Schindler's List (1993)"
b'Shawshank Redemption, The (1994)'
b'Titanic (1997)'
b'Braveheart (1995)'
b'Pulp Fiction (1994)'

==Recommended movies for user 188==
b'Star Wars (1977)'
b'Fargo (1996)'
b'Godfather, The (1972)'
b'Silence of the Lambs, The (1991)'
b"Schindler's List (1993)"
b'Return of the Jedi (1983)'
b'Raiders of the Lost Ark (1981)'
b'Pulp Fiction (1994)'
b'Toy Story (1995)'
b'Empire Strikes Back, The (1980)'

==Recommended movies for user 865==
b'Star Wars (1977)'
b'Fargo (1996)'
b'Godfather, The (1972)'
b'Silence of the Lambs, The (1991)'
b'Raiders of the Lost Ark (1981)'
b"Schindler's List (1993)"
b'Return of the Jedi (1983)'
b'Shawshank Redemption, The (1994)'
b'Pulp Fiction (1994)'
b'Empire Strikes Back, The (1980)'

==Recommended movies for user 710==
b'Star Wars (1977)'
b'Fargo (1996)'
b'Godfather, The (1972)'
b'Silence of the Lambs, The (1991)'
b'Raiders of the Lost Ark (1981)'
b"Schindler's List (1993)"
b'Pulp Fiction (1994)'
b'Return of the Jedi (1983)'
b'Empire Strikes Back, The (1980)'
b'Toy Story (1995)'

==Recommended movies for user 721==
b'Star Wars (1977)'
b'Fargo (1996)'
b'Godfather, The (1972)'
b'Raiders of the Lost Ark (1981)'
b'Silence of the Lambs, The (1991)'
b"Schindler's List (1993)"
b'Return of the Jedi (1983)'
b'Empire Strikes Back, The (1980)'
b'Pulp Fiction (1994)'
b'Casablanca (1942)'

==Recommended movies for user 451==
b'Star Wars (1977)'
b'Raiders of the Lost Ark (1981)'
b'Godfather, The (1972)'
b'Fargo (1996)'
b'Silence of the Lambs, The (1991)'
b'Return of the Jedi (1983)'
b'Contact (1997)'
b'Casablanca (1942)'
b'Empire Strikes Back, The (1980)'
b'Pulp Fiction (1994)'

==Recommended movies for user 228==
b'Star Wars (1977)'
b'Fargo (1996)'
b'Godfather, The (1972)'
b'Raiders of the Lost Ark (1981)'
b'Silence of the Lambs, The (1991)'
b"Schindler's List (1993)"
b'Return of the Jedi (1983)'
b'Pulp Fiction (1994)'
b'Empire Strikes Back, The (1980)'
b'Shawshank Redemption, The (1994)'

==Recommended movies for user 175==
b'Star Wars (1977)'
b'Fargo (1996)'
b'Silence of the Lambs, The (1991)'
b'Raiders of the Lost Ark (1981)'
b'Return of the Jedi (1983)'
b'Casablanca (1942)'
b"Schindler's List (1993)"
b'Empire Strikes Back, The (1980)'
b'Godfather, The (1972)'
b"One Flew Over the Cuckoo's Nest (1975)"
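The titles above are printed with a `b'...'` prefix because the dataset yields them as Python bytes objects. If plain strings are preferred, they can be decoded before printing. The helper below is a hypothetical addition, and the small dictionary only stands in for movie_id_to_movie_title.

```python
def decode_title(title):
    """Return the title as str, decoding it if it is a bytes object."""
    return title.decode("utf-8") if isinstance(title, bytes) else title


# Stand-in for movie_id_to_movie_title; the IDs here are illustrative.
sample_titles = {0: "", 1: b"Toy Story (1995)", 2: b"Star Wars (1977)"}

for movie_id in [2, 1]:
    print(decode_title(sample_titles[movie_id]))
```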

And we're done! For data parallel training, all we had to do was add about 3-5 lines of code. The rest is exactly the same.