
Sequential retrieval using SASRec

Authors: Abheesht Sharma, Fabien Hertschuh
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Recommend movies using a Transformer-based retrieval model (SASRec).



Introduction

Sequential recommendation is a popular model that looks at a sequence of items users have interacted with previously and then predicts the next item. Here, the order of the items within each sequence matters. Previously, in the Recommending movies: retrieval using a sequential model example, we built a GRU-based sequential retrieval model. In this example, we build a popular Transformer decoder-based model, Self-Attentive Sequential Recommendation (SASRec), for the same sequential recommendation task.

Let's begin by importing all the necessary libraries.

!pip install -q keras-rs
import os

os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`

import collections
import os

import keras
import keras_hub
import numpy as np
import pandas as pd
import tensorflow as tf  # Needed only for the dataset
from keras import ops

import keras_rs

Next, let's define all the important variables/hyperparameters below.

DATA_DIR = "./raw/data/"

# MovieLens-specific variables
MOVIELENS_1M_URL = "https://files.grouplens.org/datasets/movielens/ml-1m.zip"
MOVIELENS_ZIP_HASH = "a6898adb50b9ca05aa231689da44c217cb524e7ebd39d264c56e2832f2c54e20"

RATINGS_FILE_NAME = "ratings.dat"
MOVIES_FILE_NAME = "movies.dat"

# Data processing args
MAX_CONTEXT_LENGTH = 200
MIN_SEQUENCE_LENGTH = 3
PAD_ITEM_ID = 0

RATINGS_DATA_COLUMNS = ["UserID", "MovieID", "Rating", "Timestamp"]
MOVIES_DATA_COLUMNS = ["MovieID", "Title", "Genres"]
MIN_RATING = 2

# Training/model args picked from SASRec paper
BATCH_SIZE = 128
NUM_EPOCHS = 10
LEARNING_RATE = 0.001

NUM_LAYERS = 2
NUM_HEADS = 1
HIDDEN_DIM = 50
DROPOUT = 0.2

Dataset

Next, we need to prepare our dataset. As in the sequential retrieval example, we are going to use the MovieLens dataset.

The dataset preparation step is fairly involved. The original ratings dataset contains (user, movie ID, rating, timestamp) tuples (along with other columns that are not important here). Since we are dealing with sequential retrieval, we need to create a sequence of movies for every user, ordered by timestamp.

Let's download and read the dataset first.

# Download the MovieLens dataset.
if not os.path.exists(DATA_DIR):
    os.makedirs(DATA_DIR)

path_to_zip = keras.utils.get_file(
    fname="ml-1m.zip",
    origin=MOVIELENS_1M_URL,
    file_hash=MOVIELENS_ZIP_HASH,
    hash_algorithm="sha256",
    extract=True,
    cache_dir=DATA_DIR,
)
movielens_extracted_dir = os.path.join(
    os.path.dirname(path_to_zip),
    "ml-1m_extracted",
    "ml-1m",
)


# Read the dataset.
def read_data(data_directory, min_rating=None):
    """Read movielens ratings.dat and movies.dat file
    into dataframe.
    """

    ratings_df = pd.read_csv(
        os.path.join(data_directory, RATINGS_FILE_NAME),
        sep="::",
        names=RATINGS_DATA_COLUMNS,
        encoding="unicode_escape",
    )
    ratings_df["Timestamp"] = ratings_df["Timestamp"].apply(int)

    # Remove movies with `rating < min_rating`.
    if min_rating is not None:
        ratings_df = ratings_df[ratings_df["Rating"] >= min_rating]

    movies_df = pd.read_csv(
        os.path.join(data_directory, MOVIES_FILE_NAME),
        sep="::",
        names=MOVIES_DATA_COLUMNS,
        encoding="unicode_escape",
    )
    return ratings_df, movies_df


ratings_df, movies_df = read_data(
    data_directory=movielens_extracted_dir, min_rating=MIN_RATING
)

# Need to know #movies so as to define embedding layers.
movies_count = movies_df["MovieID"].max()
Downloading data from https://files.grouplens.org/datasets/movielens/ml-1m.zip
5917549/5917549 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step

/var/tmp/ipykernel_686076/1372663084.py:26: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.
  ratings_df = pd.read_csv(
/var/tmp/ipykernel_686076/1372663084.py:38: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.
  movies_df = pd.read_csv(

Having read the dataset, let's now create a sequence of movies for every user. Here is the function that does just that.

def get_movie_sequence_per_user(ratings_df):
    """Get movieID sequences for every user."""
    sequences = collections.defaultdict(list)

    for user_id, movie_id, rating, timestamp in ratings_df.values:
        sequences[user_id].append(
            {
                "movie_id": movie_id,
                "timestamp": timestamp,
                "rating": rating,
            }
        )

    # Sort movie sequences by timestamp for every user.
    for user_id, context in sequences.items():
        context.sort(key=lambda x: x["timestamp"])
        sequences[user_id] = context

    return sequences


sequences = get_movie_sequence_per_user(ratings_df)

So far, we have essentially replicated what we did in the sequential retrieval example: we have a movie sequence for every user.

SASRec is trained contrastively, which means the model learns to distinguish between sequences of movies a user has actually interacted with (positive examples) and sequences they have not interacted with (negative examples).

The following function, format_data, prepares the data in this specific format. For each user's movie sequence, it generates a corresponding "negative sequence": randomly selected movies the user has not interacted with, of the same length as the original sequence.

def format_data(sequences):
    examples = {
        "sequence": [],
        "negative_sequence": [],
    }

    for user_id in sequences:
        sequence = [int(d["movie_id"]) for d in sequences[user_id]]

        # Get negative sequence.
        def random_negative_item_id(low, high, positive_lst):
            sampled = np.random.randint(low=low, high=high)
            while sampled in positive_lst:
                sampled = np.random.randint(low=low, high=high)
            return sampled

        negative_sequence = [
            random_negative_item_id(1, movies_count + 1, sequence)
            for _ in range(len(sequence))
        ]

        examples["sequence"].append(np.array(sequence))
        examples["negative_sequence"].append(np.array(negative_sequence))

    examples["sequence"] = tf.ragged.constant(examples["sequence"])
    examples["negative_sequence"] = tf.ragged.constant(examples["negative_sequence"])

    return examples


examples = format_data(sequences)
ds = tf.data.Dataset.from_tensor_slices(examples).batch(BATCH_SIZE)

Now that we have the original movie interaction sequences for every user (from format_data, stored in examples["sequence"]) and their corresponding random negative sequences (in examples["negative_sequence"]), the next step is to prepare this data as input to the model. The primary goals of this preprocessing are:

  1. Creating input features and target labels: for sequential recommendation, the model learns to predict the next item in a sequence given the previous items. This is achieved by:
     - taking the original example["sequence"] and creating the model's input features (item_ids) from all items except the last one (example["sequence"][..., :-1]);
     - creating the target "positive sequence" (the actual next items the model tries to predict) from all items except the first one (example["sequence"][..., 1:]);
     - shifting example["negative_sequence"] (from format_data) in the same way to create the target "negative sequence" for the contrastive loss (example["negative_sequence"][..., 1:]).

  2. Handling variable-length sequences: neural networks typically require fixed-size inputs, so both the input feature sequences and the target sequences are padded (with a special PAD_ITEM_ID) or truncated to a predefined MAX_CONTEXT_LENGTH. A padding_mask is also generated from the input features so that the model ignores the padded tokens during the attention computations, i.e., these tokens are masked out.

  3. Differentiating between training and validation/testing:
     - During training:
       - the input features (item_ids) and the negative-sequence context are prepared as described above (all items of the original sequence except the last one);
       - the target positive and negative sequences are the shifted versions of the original sequences;
       - sample_weight is derived from the input features so that the loss is computed only on actual items, not on padded tokens in the targets.
     - During validation/testing:
       - the input features are prepared in the same way;
       - the model is typically evaluated on its ability to predict the actual last item of the original sequence, so sample_weight is configured to restrict the loss computation to this final prediction in the target sequences.

Note: SASRec does something similar to the above, except that it uses item_ids[:-2] for the validation set and item_ids[:-1] for the test set. We skip that step in this example for brevity.
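For illustration only, here is a minimal sketch of that SASRec-style split on a single user's chronologically sorted list of movie IDs. The helper name is hypothetical and is not used anywhere else in this example.

# Hypothetical sketch of the SASRec-style evaluation split (not used below).
# The validation input is `item_ids[:-2]` (target: second-to-last item) and
# the test input is `item_ids[:-1]` (target: last item).
def sasrec_eval_split(item_ids):
    val_input, val_target = item_ids[:-2], item_ids[-2]
    test_input, test_target = item_ids[:-1], item_ids[-1]
    return (val_input, val_target), (test_input, test_target)


print(sasrec_eval_split([10, 20, 30, 40, 50]))
# (([10, 20, 30], 40), ([10, 20, 30, 40], 50))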

def _preprocess(example, train=False):
    sequence = example["sequence"]
    negative_sequence = example["negative_sequence"]

    if train:
        sequence = example["sequence"][..., :-1]
        negative_sequence = example["negative_sequence"][..., :-1]

    batch_size = tf.shape(sequence)[0]

    if not train:
        # Loss computed only on last token.
        sample_weight = tf.zeros_like(sequence, dtype="float32")[..., :-1]
        sample_weight = tf.concat(
            [sample_weight, tf.ones((batch_size, 1), dtype="float32")], axis=1
        )

    # Truncate/pad sequence. +1 to account for truncation later.
    sequence = sequence.to_tensor(
        shape=[batch_size, MAX_CONTEXT_LENGTH + 1], default_value=PAD_ITEM_ID
    )
    negative_sequence = negative_sequence.to_tensor(
        shape=[batch_size, MAX_CONTEXT_LENGTH + 1], default_value=PAD_ITEM_ID
    )
    if train:
        sample_weight = tf.cast(sequence != PAD_ITEM_ID, dtype="float32")
    else:
        sample_weight = sample_weight.to_tensor(
            shape=[batch_size, MAX_CONTEXT_LENGTH + 1], default_value=0
        )

    example = (
        {
            # last token does not have a next token
            "item_ids": sequence[..., :-1],
            # padding mask for controlling attention mask
            "padding_mask": (sequence != PAD_ITEM_ID)[..., :-1],
        },
        {
            "positive_sequence": sequence[
                ..., 1:
            ],  # 0th token's label will be 1st token, and so on
            "negative_sequence": negative_sequence[..., 1:],
        },
        sample_weight[..., 1:],  # loss will not be computed on pad tokens
    )
    return example


def preprocess_train(examples):
    return _preprocess(examples, train=True)


def preprocess_val(examples):
    return _preprocess(examples, train=False)


train_ds = ds.map(preprocess_train)
val_ds = ds.map(preprocess_val)

We can take a look at one batch from each dataset.

for batch in train_ds.take(1):
    print(batch)

for batch in val_ds.take(1):
    print(batch)
({'item_ids': <tf.Tensor: shape=(128, 200), dtype=int32, numpy=
array([[3186, 1270, 1721, ...,    0,    0,    0],
       [1198, 1210, 1217, ...,    0,    0,    0],
       [ 593, 2858, 3534, ...,    0,    0,    0],
       ...,
       [ 902, 1179, 1210, ...,    0,    0,    0],
       [1270, 3252, 1476, ...,    0,    0,    0],
       [2253, 3073, 1968, ...,    0,    0,    0]], dtype=int32)>, 'padding_mask': <tf.Tensor: shape=(128, 200), dtype=bool, numpy=
array([[ True,  True,  True, ..., False, False, False],
       [ True,  True,  True, ..., False, False, False],
       [ True,  True,  True, ..., False, False, False],
       ...,
       [ True,  True,  True, ..., False, False, False],
       [ True,  True,  True, ..., False, False, False],
       [ True,  True,  True, ..., False, False, False]])>}, {'positive_sequence': <tf.Tensor: shape=(128, 200), dtype=int32, numpy=
array([[1270, 1721, 1022, ...,    0,    0,    0],
       [1210, 1217, 2717, ...,    0,    0,    0],
       [2858, 3534, 1968, ...,    0,    0,    0],
       ...,
       [1179, 1210, 3868, ...,    0,    0,    0],
       [3252, 1476,  260, ...,    0,    0,    0],
       [3073, 1968,  852, ...,    0,    0,    0]], dtype=int32)>, 'negative_sequence': <tf.Tensor: shape=(128, 200), dtype=int32, numpy=
array([[2500, 2682, 3621, ...,    0,    0,    0],
       [ 204,  450, 3339, ...,    0,    0,    0],
       [2452,  133, 2363, ...,    0,    0,    0],
       ...,
       [1935, 2507, 2009, ...,    0,    0,    0],
       [1663, 2644, 2326, ...,    0,    0,    0],
       [1273, 3577,  441, ...,    0,    0,    0]], dtype=int32)>}, <tf.Tensor: shape=(128, 200), dtype=float32, numpy=
array([[1., 1., 1., ..., 0., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.],
       ...,
       [1., 1., 1., ..., 0., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.]], dtype=float32)>)
({'item_ids': <tf.Tensor: shape=(128, 200), dtype=int32, numpy=
array([[3186, 1270, 1721, ...,    0,    0,    0],
       [1198, 1210, 1217, ...,    0,    0,    0],
       [ 593, 2858, 3534, ...,    0,    0,    0],
       ...,
       [ 902, 1179, 1210, ...,    0,    0,    0],
       [1270, 3252, 1476, ...,    0,    0,    0],
       [2253, 3073, 1968, ...,    0,    0,    0]], dtype=int32)>, 'padding_mask': <tf.Tensor: shape=(128, 200), dtype=bool, numpy=
array([[ True,  True,  True, ..., False, False, False],
       [ True,  True,  True, ..., False, False, False],
       [ True,  True,  True, ..., False, False, False],
       ...,
       [ True,  True,  True, ..., False, False, False],
       [ True,  True,  True, ..., False, False, False],
       [ True,  True,  True, ..., False, False, False]])>}, {'positive_sequence': <tf.Tensor: shape=(128, 200), dtype=int32, numpy=
array([[1270, 1721, 1022, ...,    0,    0,    0],
       [1210, 1217, 2717, ...,    0,    0,    0],
       [2858, 3534, 1968, ...,    0,    0,    0],
       ...,
       [1179, 1210, 3868, ...,    0,    0,    0],
       [3252, 1476,  260, ...,    0,    0,    0],
       [3073, 1968,  852, ...,    0,    0,    0]], dtype=int32)>, 'negative_sequence': <tf.Tensor: shape=(128, 200), dtype=int32, numpy=
array([[2500, 2682, 3621, ...,    0,    0,    0],
       [ 204,  450, 3339, ...,    0,    0,    0],
       [2452,  133, 2363, ...,    0,    0,    0],
       ...,
       [1935, 2507, 2009, ...,    0,    0,    0],
       [1663, 2644, 2326, ...,    0,    0,    0],
       [1273, 3577,  441, ...,    0,    0,    0]], dtype=int32)>}, <tf.Tensor: shape=(128, 200), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>)

Model

To encode the input sequence, we use a Transformer decoder-based model. This part of the model is very similar to the GPT-2 architecture. Refer to the GPT text generation from scratch with KerasHub guide for more details on this part.

One thing to note is that when we are "predicting", i.e., when training is False, we take the embedding corresponding to the last movie in the sequence. This makes sense, because at inference time we want to predict the movie the user is likely to watch after watching the last one.

The compute_loss method is also worth discussing. We embed the positive and negative sequences using the input embedding matrix, and compute the similarity of the (positive sequence, input sequence) and (negative sequence, input sequence) embedding pairs via dot products. The goal is then to maximize the similarity for the positive pairs and minimize it for the negative pairs. Let's look at this mathematically. The binary cross-entropy loss is written as follows:

 loss = - (y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

Here, we assign the label 1 to the positive pairs and 0 to the negative pairs. So, for a positive pair, the loss reduces to:

loss = -np.log(positive_logits)

Minimizing the loss means we want to maximize the log term, which in turn means maximizing positive_logits. Similarly, we want to minimize negative_logits.
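As a quick, self-contained check of this reasoning (purely illustrative, independent of the model class below), here is the binary cross-entropy computed by hand from the formula above on a pair of toy logits:

# Illustrative only: binary cross-entropy on toy (positive, negative) logits.
# Label 1 for the positive pair, label 0 for the negative pair.
toy_logits = np.array([3.0, -3.0])  # [positive pair logit, negative pair logit]
toy_labels = np.array([1.0, 0.0])

probs = 1.0 / (1.0 + np.exp(-toy_logits))  # sigmoid, since we start from logits
toy_loss = -(toy_labels * np.log(probs) + (1.0 - toy_labels) * np.log(1.0 - probs))
print(toy_loss)  # both terms are small (~0.049) because the logits point the right way

A large positive logit for the positive pair and a large negative logit for the negative pair both drive the loss towards zero, which is exactly the behaviour described above.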

class SasRec(keras.Model):
    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_heads,
        hidden_dim,
        dropout=0.0,
        max_sequence_length=100,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)

        # ======== Layers ========

        # === Embeddings ===
        self.item_embedding = keras_hub.layers.ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            embeddings_initializer="glorot_uniform",
            embeddings_regularizer=keras.regularizers.l2(0.001),
            dtype=dtype,
            name="item_embedding",
        )
        self.position_embedding = keras_hub.layers.PositionEmbedding(
            initializer="glorot_uniform",
            sequence_length=max_sequence_length,
            dtype=dtype,
            name="position_embedding",
        )
        self.embeddings_add = keras.layers.Add(
            dtype=dtype,
            name="embeddings_add",
        )
        self.embeddings_dropout = keras.layers.Dropout(
            dropout,
            dtype=dtype,
            name="embeddings_dropout",
        )

        # === Decoder layers ===
        self.transformer_layers = []
        for i in range(num_layers):
            self.transformer_layers.append(
                keras_hub.layers.TransformerDecoder(
                    intermediate_dim=hidden_dim,
                    num_heads=num_heads,
                    dropout=dropout,
                    layer_norm_epsilon=1e-05,
                    # SASRec uses ReLU, although GeLU might be a better option
                    activation="relu",
                    kernel_initializer="glorot_uniform",
                    normalize_first=True,
                    dtype=dtype,
                    name=f"transformer_layer_{i}",
                )
            )

        # === Final layer norm ===
        self.layer_norm = keras.layers.LayerNormalization(
            axis=-1,
            epsilon=1e-8,
            dtype=dtype,
            name="layer_norm",
        )

        # === Retrieval ===
        # The layer that performs the retrieval.
        self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False)

        # === Loss ===
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True, reduction=None)

        # === Attributes ===
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.max_sequence_length = max_sequence_length

    def _get_last_non_padding_token(self, tensor, padding_mask):
        valid_token_mask = ops.logical_not(padding_mask)
        seq_lengths = ops.sum(ops.cast(valid_token_mask, "int32"), axis=1)
        last_token_indices = ops.maximum(seq_lengths - 1, 0)

        indices = ops.expand_dims(last_token_indices, axis=(-2, -1))
        gathered_tokens = ops.take_along_axis(tensor, indices, axis=1)
        last_token_embedding = ops.squeeze(gathered_tokens, axis=1)

        return last_token_embedding

    def build(self, input_shape):
        embedding_shape = list(input_shape) + [self.hidden_dim]

        # Model
        self.item_embedding.build(input_shape)
        self.position_embedding.build(embedding_shape)

        self.embeddings_add.build((embedding_shape, embedding_shape))
        self.embeddings_dropout.build(embedding_shape)

        for transformer_layer in self.transformer_layers:
            transformer_layer.build(decoder_sequence_shape=embedding_shape)

        self.layer_norm.build(embedding_shape)

        # Retrieval
        self.retrieval.candidate_embeddings = self.item_embedding.embeddings
        self.retrieval.build(input_shape)

        # Chain to super
        super().build(input_shape)

    def call(self, inputs, training=False):
        item_ids, padding_mask = inputs["item_ids"], inputs["padding_mask"]

        x = self.item_embedding(item_ids)
        position_embedding = self.position_embedding(x)
        x = self.embeddings_add((x, position_embedding))
        x = self.embeddings_dropout(x)

        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x, decoder_padding_mask=padding_mask)

        item_sequence_embedding = self.layer_norm(x)
        result = {"item_sequence_embedding": item_sequence_embedding}

        # At inference, perform top-k retrieval.
        if not training:
            # need to extract last non-padding token.
            last_item_embedding = self._get_last_non_padding_token(
                item_sequence_embedding, padding_mask
            )
            result["predictions"] = self.retrieval(last_item_embedding)

        return result

    def compute_loss(self, x, y, y_pred, sample_weight, training=False):
        item_sequence_embedding = y_pred["item_sequence_embedding"]
        y_positive_sequence = y["positive_sequence"]
        y_negative_sequence = y["negative_sequence"]

        # Embed positive, negative sequences.
        positive_sequence_embedding = self.item_embedding(y_positive_sequence)
        negative_sequence_embedding = self.item_embedding(y_negative_sequence)

        # Logits
        positive_logits = ops.sum(
            ops.multiply(positive_sequence_embedding, item_sequence_embedding),
            axis=-1,
        )
        negative_logits = ops.sum(
            ops.multiply(negative_sequence_embedding, item_sequence_embedding),
            axis=-1,
        )
        logits = ops.concatenate([positive_logits, negative_logits], axis=1)

        # Labels
        labels = ops.concatenate(
            [
                ops.ones_like(positive_logits),
                ops.zeros_like(negative_logits),
            ],
            axis=1,
        )

        # sample weights
        sample_weight = ops.concatenate(
            [sample_weight, sample_weight],
            axis=1,
        )

        loss = self.loss_fn(
            y_true=ops.expand_dims(labels, axis=-1),
            y_pred=ops.expand_dims(logits, axis=-1),
            sample_weight=sample_weight,
        )
        loss = ops.divide_no_nan(ops.sum(loss), ops.sum(sample_weight))

        return loss

    def compute_output_shape(self, inputs_shape):
        return list(inputs_shape) + [self.hidden_dim]

Let's instantiate our model and do a few sanity checks.

model = SasRec(
    vocabulary_size=movies_count + 1,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    hidden_dim=HIDDEN_DIM,
    dropout=DROPOUT,
    max_sequence_length=MAX_CONTEXT_LENGTH,
)

# Training
output = model(
    inputs={
        "item_ids": ops.ones((2, MAX_CONTEXT_LENGTH), dtype="int32"),
        "padding_mask": ops.ones((2, MAX_CONTEXT_LENGTH), dtype="bool"),
    },
    training=True,
)
print(output["item_sequence_embedding"].shape)

# Inference
output = model(
    inputs={
        "item_ids": ops.ones((2, MAX_CONTEXT_LENGTH), dtype="int32"),
        "padding_mask": ops.ones((2, MAX_CONTEXT_LENGTH), dtype="bool"),
    },
    training=False,
)
print(output["predictions"].shape)
(2, 200, 50)
(2, 10)

Now, let's compile and train our model.

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE, beta_2=0.98),
)
model.fit(
    x=train_ds,
    validation_data=val_ds,
    epochs=NUM_EPOCHS,
)
Epoch 1/10
48/48 ━━━━━━━━━━━━━━━━━━━━ 13s 191ms/step - loss: 0.6054 - val_loss: 0.5092
Epoch 2/10
48/48 ━━━━━━━━━━━━━━━━━━━━ 3s 9ms/step - loss: 0.4463 - val_loss: 0.5017
Epoch 3/10
48/48 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.4340 - val_loss: 0.4836
Epoch 4/10
48/48 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.4210 - val_loss: 0.4703
Epoch 5/10
48/48 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.4030 - val_loss: 0.4510
Epoch 6/10
48/48 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.3758 - val_loss: 0.4285
Epoch 7/10
48/48 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.3515 - val_loss: 0.4096
Epoch 8/10
48/48 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.3311 - val_loss: 0.3948
Epoch 9/10
48/48 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.3148 - val_loss: 0.3850
Epoch 10/10
48/48 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.3024 - val_loss: 0.3778

<keras.src.callbacks.history.History at 0x7f75255fe9e0>

Making predictions

Now that we have a trained model, we would like to be able to make predictions.

So far, we have only handled movies by ID. Now is the time to create a mapping keyed by movie IDs so that we can display the movie titles.

movie_id_to_movie_title = dict(zip(movies_df["MovieID"], movies_df["Title"]))
movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.

We then simply use the Keras model.predict() method. Under the hood, it calls the BruteForceRetrieval layer to perform the actual retrieval.

Note that this model can retrieve movies the user has already watched. We could easily add logic to remove them if we wanted to.

for ele in val_ds.unbatch().take(1):
    test_sample = ele[0]
    test_sample["item_ids"] = tf.expand_dims(test_sample["item_ids"], axis=0)
    test_sample["padding_mask"] = tf.expand_dims(test_sample["padding_mask"], axis=0)

movie_sequence = np.array(test_sample["item_ids"])[0]
for movie_id in movie_sequence:
    if movie_id == 0:
        continue
    print(movie_id_to_movie_title[movie_id], end="; ")
print()

predictions = model.predict(test_sample)["predictions"]
predictions = keras.ops.convert_to_numpy(predictions)

for movie_id in predictions[0]:
    print(movie_id_to_movie_title[movie_id])
Girl, Interrupted (1999); Back to the Future (1985); Titanic (1997); Cinderella (1950); Meet Joe Black (1998); Last Days of Disco, The (1998); Erin Brockovich (2000); Christmas Story, A (1983); To Kill a Mockingbird (1962); One Flew Over the Cuckoo's Nest (1975); Wallace & Gromit: The Best of Aardman Animation (1996); Star Wars: Episode IV - A New Hope (1977); Wizard of Oz, The (1939); Fargo (1996); Run Lola Run (Lola rennt) (1998); Rain Man (1988); Saving Private Ryan (1998); Awakenings (1990); Gigi (1958); Sound of Music, The (1965); Driving Miss Daisy (1989); Bambi (1942); Apollo 13 (1995); Mary Poppins (1964); E.T. the Extra-Terrestrial (1982); My Fair Lady (1964); Ben-Hur (1959); Big (1988); Sixth Sense, The (1999); Dead Poets Society (1989); James and the Giant Peach (1996); Ferris Bueller's Day Off (1986); Secret Garden, The (1993); Toy Story 2 (1999); Airplane! (1980); Pleasantville (1998); Dumbo (1941); Princess Bride, The (1987); Snow White and the Seven Dwarfs (1937); Miracle on 34th Street (1947); Ponette (1996); Schindler's List (1993); Beauty and the Beast (1991); Tarzan (1999); Close Shave, A (1995); Aladdin (1992); Toy Story (1995); Bug's Life, A (1998); Antz (1998); Hunchback of Notre Dame, The (1996); Hercules (1997); Mulan (1998); Pocahontas (1995); 
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 653ms/step
Forrest Gump (1994)
Aladdin (1992)
Bug's Life, A (1998)
As Good As It Gets (1997)
Clueless (1995)
Ghostbusters (1984)
American Beauty (1999)
Groundhog Day (1993)
Toy Story (1995)
Four Weddings and a Funeral (1994)
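
Following up on the note above about already-watched movies, here is a minimal sketch (reusing the movie_sequence and predictions variables from this section) of how one might filter them out of the retrieved candidates:

# Minimal sketch: drop already-watched movies from the retrieved candidates.
watched = set(int(movie_id) for movie_id in movie_sequence if movie_id != 0)
unwatched_predictions = [
    int(movie_id) for movie_id in predictions[0] if int(movie_id) not in watched
]
for movie_id in unwatched_predictions:
    print(movie_id_to_movie_title[movie_id])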

And we're all done!