BruteForceRetrieval 类keras_rs.layers.BruteForceRetrieval(
candidate_embeddings: Optional[Any] = None,
candidate_ids: Optional[Any] = None,
k: int = 10,
return_scores: bool = True,
**kwargs: Any
)
暴力搜索 Top-K 检索。
此层维护一个候选集,并且能够精确地检索给定查询的 Top-K 候选。它通过计算所有候选与查询的得分,然后提取 Top 元素来实现。返回的 Top-K 候选按得分排序。
默认情况下,此层返回一个元组,包含 Top 分数和 Top 标识符,但也可以配置为返回一个包含 Top 标识符的单个张量。
候选的标识符可以指定为张量。如果未提供,则使用的 ID 仅为候选索引。
请注意,此层的序列化不会保留候选集,只会保存 k 和 return_scores 参数。反序列化层后,必须调用 update_candidates。
参数
None,则必须在使用此层之前通过 update_candidates 提供候选。None,则返回候选的索引。True 时,此层返回一个元组,包含 Top 分数和 Top 标识符。当为 False 时,此层返回一个包含 Top 标识符的单个张量。示例
retrieval = keras_rs.layers.BruteForceRetrieval(k=100)
# At some later point, we update the candidates.
retrieval.update_candidates(candidate_embeddings, candidate_ids)
# We can then retrieve the top candidates for any number of queries.
# Scores are stored highest first. Scores correspond to ids in the same row.
tops_scores, top_ids = retrieval(query_embeddings)
call 方法BruteForceRetrieval.call(inputs: Any)
返回输入查询的 Top 候选。
参数
返回
如果 returns_scores 为 True,则返回一个元组,包含 Top 分数和 Top 标识符;否则,返回一个包含 Top 标识符的张量。
update_candidates 方法BruteForceRetrieval.update_candidates(
candidate_embeddings: Any, candidate_ids: Optional[Any] = None
)
更新候选集以及可选的候选 ID。
参数
None,则返回候选的索引。