KerasRS / API 文档 / 检索层 / 检索层

检索层

[来源]

Retrieval

keras_rs.layers.Retrieval(k: int = 10, return_scores: bool = True, **kwargs: Any)

Retrieval 基础抽象类。

此层为所有检索层提供了一个通用接口。为了实现自定义检索层,应继承此抽象类。

参数

  • k: int. 检索候选的数量。
  • return_scores: bool. 当 True 时,此层返回一个包含最高得分和最高标识符的元组。当 False 时,此层返回一个包含最高标识符的单一张量。

[来源]

call 方法

Retrieval.call(inputs: Any)

返回作为输入的查询的顶部候选者。

参数

  • inputs: 需要返回顶部候选者的查询。

返回

如果 returns_scores 为 True,则返回一个包含最高得分和最高标识符的元组,否则返回一个包含最高标识符的张量。


[来源]

update_candidates 方法

Retrieval.update_candidates(
    candidate_embeddings: Any, candidate_ids: Optional[Any] = None
)

更新候选者集合并可选地更新其候选者 ID。

参数

  • candidate_embeddings: 候选者嵌入。
  • candidate_ids: 候选者的标识符。如果为 None,则返回候选者的索引。

[来源]

compute_score 方法

Retrieval.compute_score(query_embedding: Any, candidate_embedding: Any)

计算来自查询和候选者的标准点积得分。

参数

  • query_embedding: 查询嵌入张量,对应于需要检索顶部候选者的查询。
  • candidate_embedding: 候选者嵌入张量。

返回

查询和候选者的点积。