Retrieval
类keras_rs.layers.Retrieval(k: int = 10, return_scores: bool = True, **kwargs: Any)
Retrieval 基础抽象类。
此层为所有检索层提供了一个通用接口。为了实现自定义检索层,应继承此抽象类。
参数
True
时,此层返回一个包含最高得分和最高标识符的元组。当 False
时,此层返回一个包含最高标识符的单一张量。call
方法Retrieval.call(inputs: Any)
返回作为输入的查询的顶部候选者。
参数
返回
如果 returns_scores
为 True,则返回一个包含最高得分和最高标识符的元组,否则返回一个包含最高标识符的张量。
update_candidates
方法Retrieval.update_candidates(
candidate_embeddings: Any, candidate_ids: Optional[Any] = None
)
更新候选者集合并可选地更新其候选者 ID。
参数
None
,则返回候选者的索引。compute_score
方法Retrieval.compute_score(query_embedding: Any, candidate_embedding: Any)
计算来自查询和候选者的标准点积得分。
参数
返回
查询和候选者的点积。