KerasRS / API 文档 / 检索层 / BruteForceRetrieval 层

BruteForceRetrieval 层

[源代码]

BruteForceRetrieval

keras_rs.layers.BruteForceRetrieval(
    candidate_embeddings: Optional[Any] = None,
    candidate_ids: Optional[Any] = None,
    k: int = 10,
    return_scores: bool = True,
    **kwargs: Any
)

暴力法 top-k 检索。

该层维护一组候选,并能够为给定查询精确检索 top-k 候选。它通过计算查询的所有候选的分数并提取得分最高的候选项来实现此功能。返回的 top-k 候选按分数排序。

默认情况下,此层返回一个包含 top 分数和 top 标识符的元组,但可以配置为返回一个包含 top 标识符的单个张量。

候选的标识符可以指定为张量。如果未提供,则使用的 ID 仅是候选索引。

请注意,此层的序列化不会保留候选,仅保存 kreturn_scores 参数。在反序列化层后,必须调用 update_candidates

参数

  • candidate_embeddings: 候选嵌入。如果为 None,则在使用此层之前必须使用 update_candidates 提供候选。
  • candidate_ids: 候选的标识符。如果为 None,则返回候选的索引。
  • k: 要检索的候选数量。
  • return_scores: 当为 True 时,此层返回一个包含 top 分数和 top 标识符的元组。当为 False 时,此层返回一个包含 top 标识符的单个张量。
  • **kwargs: 传递给基类的参数。

示例

retrieval = keras_rs.layers.BruteForceRetrieval(k=100)

# At some later point, we update the candidates.
retrieval.update_candidates(candidate_embeddings, candidate_ids)

# We can then retrieve the top candidates for any number of queries.
# Scores are stored highest first. Scores correspond to ids in the same row.
tops_scores, top_ids = retrieval(query_embeddings)

[源代码]

call 方法

BruteForceRetrieval.call(inputs: Any)

返回作为输入传入的查询的 top 候选。

参数

  • inputs: 要返回 top 候选的查询。

返回值

如果 returns_scores 为 True,则返回一个包含 top 分数和 top 标识符的元组,否则返回一个包含 top 标识符的张量。


[源代码]

update_candidates 方法

BruteForceRetrieval.update_candidates(
    candidate_embeddings: Any, candidate_ids: Optional[Any] = None
)

更新候选集并可选地更新其候选 ID。

参数

  • candidate_embeddings: 候选嵌入。
  • candidate_ids: 候选的标识符。如果为 None,则返回候选的索引。