BruteForceRetrieval
类keras_rs.layers.BruteForceRetrieval(
candidate_embeddings: Optional[Any] = None,
candidate_ids: Optional[Any] = None,
k: int = 10,
return_scores: bool = True,
**kwargs: Any
)
暴力法 top-k 检索。
该层维护一组候选,并能够为给定查询精确检索 top-k 候选。它通过计算查询的所有候选的分数并提取得分最高的候选项来实现此功能。返回的 top-k 候选按分数排序。
默认情况下,此层返回一个包含 top 分数和 top 标识符的元组,但可以配置为返回一个包含 top 标识符的单个张量。
候选的标识符可以指定为张量。如果未提供,则使用的 ID 仅是候选索引。
请注意,此层的序列化不会保留候选,仅保存 k
和 return_scores
参数。在反序列化层后,必须调用 update_candidates
。
参数
None
,则在使用此层之前必须使用 update_candidates
提供候选。None
,则返回候选的索引。True
时,此层返回一个包含 top 分数和 top 标识符的元组。当为 False
时,此层返回一个包含 top 标识符的单个张量。示例
retrieval = keras_rs.layers.BruteForceRetrieval(k=100)
# At some later point, we update the candidates.
retrieval.update_candidates(candidate_embeddings, candidate_ids)
# We can then retrieve the top candidates for any number of queries.
# Scores are stored highest first. Scores correspond to ids in the same row.
tops_scores, top_ids = retrieval(query_embeddings)
call
方法BruteForceRetrieval.call(inputs: Any)
返回作为输入传入的查询的 top 候选。
参数
返回值
如果 returns_scores
为 True,则返回一个包含 top 分数和 top 标识符的元组,否则返回一个包含 top 标识符的张量。
update_candidates
方法BruteForceRetrieval.update_candidates(
candidate_embeddings: Any, candidate_ids: Optional[Any] = None
)
更新候选集并可选地更新其候选 ID。
参数
None
,则返回候选的索引。