Perplexity 类keras_hub.metrics.Perplexity(
from_logits=False, mask_token_id=None, dtype="float32", name="perplexity", **kwargs
)
困惑度(Perplexity)指标。
此类实现了困惑度指标。简而言之,此类计算交叉熵损失并取其指数。注意:此实现不适用于固定大小的窗口。
参数
y_pred(传递给 update_state() 的输入)应该是模型返回的 logits。否则,y_pred 是一个概率张量。update_state() 中还提供了 sample_weight 字段,我们将计算最终的 sample_weight 作为掩码和 sample_weight 的元素乘积。"float32"。示例
sample_weight 和 mask_token_id。>>> np.random.seed(42)
>>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> perplexity.update_state(target, logits)
>>> perplexity.result()
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>
1.2. 指定了 sample_weight(用 ID 0 屏蔽 token)。
>>> np.random.seed(42)
>>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> sample_weight = (target != 0).astype("float32")
>>> perplexity.update_state(target, logits, sample_weight)
>>> perplexity.result()
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>
>>> np.random.seed(42)
>>> perplexity = keras_hub.metrics.Perplexity(name="perplexity")
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> perplexity(target, logits)
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>
>>> np.random.seed(42)
>>> perplexity = keras_hub.metrics.Perplexity(mask_token_id=0)
>>> target = np.random.randint(10, size=[2, 5])
>>> logits = np.random.uniform(size=(2, 5, 10))
>>> perplexity(target, logits)
<tf.Tensor: shape=(), dtype=float32, numpy=14.352535>