Keras 3 API 文档 / KerasHub / 指标 / 困惑度指标

困惑度指标

[源代码]

Perplexity

keras_hub.metrics.Perplexity(
    from_logits=False, mask_token_id=None, dtype="float32", name="perplexity", **kwargs
)

困惑度指标。

此类实现困惑度指标。简而言之,此类计算交叉熵损失并取其指数。注意:此实现不适用于固定大小的窗口。

参数

  • from_logits: bool。如果为 True,则 y_pred(输入到 update_state())应为模型返回的对数几率。否则,y_pred 是一个概率张量。
  • mask_token_id: int。要屏蔽的标记的 ID。如果提供,则为此类计算掩码。请注意,如果提供了此字段,并且如果在 update_state() 中也提供了 sample_weight 字段,我们将计算最终的 sample_weight 作为掩码和 sample_weight 的逐元素乘积。
  • dtype: 字符串或 tf.dtypes.Dtype。指标计算的精度。如果未指定,则默认为 "float32"
  • name: 字符串。指标实例的名称。
  • **kwargs: 其他关键字参数。

示例

  1. 通过调用 update_state() 和 result() 计算困惑度。1.1. 未提供 sample_weightmask_token_id
>>> np.random.seed(42)
>>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> perplexity.update_state(target, logits)
>>> perplexity.result()
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>

1.2. 指定 sample_weight(使用 ID 为 0 的标记进行屏蔽)。

>>> np.random.seed(42)
>>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> sample_weight = (target != 0).astype("float32")
>>> perplexity.update_state(target, logits, sample_weight)
>>> perplexity.result()
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>
  1. 直接调用困惑度。
>>> np.random.seed(42)
>>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> perplexity(target, logits)
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>
  1. 提供填充标记 ID 并让类自行计算掩码。
>>> np.random.seed(42)
>>> perplexity = keras_hub.metrics.Perplexity(mask_token_id=0)
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> perplexity(target, logits)
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>