Keras 3 API 文档 / 指标 / 基本指标类

基本指标类

[来源]

Metric

keras.metrics.Metric(dtype=None, name=None)

封装指标逻辑和状态。

参数

  • name: 指标实例的可选名称。
  • dtype: 指标计算的数据类型。默认为 None,这意味着使用 keras.backend.floatx()keras.backend.floatx() 是一个 "float32",除非设置为不同的值(通过 keras.backend.set_floatx())。如果提供 keras.DTypePolicy,则将使用 compute_dtype

示例

m = SomeMetric(...)
for input in ...:
    m.update_state(input)
print('Final result: ', m.result())

compile() API 一起使用

model = keras.Sequential()
model.add(keras.layers.Dense(64, activation='relu'))
model.add(keras.layers.Dense(64, activation='relu'))
model.add(keras.layers.Dense(10, activation='softmax'))

model.compile(optimizer=keras.optimizers.RMSprop(0.01),
              loss=keras.losses.CategoricalCrossentropy(),
              metrics=[keras.metrics.CategoricalAccuracy()])

data = np.random.random((1000, 32))
labels = np.random.random((1000, 10))

model.fit(data, labels, epochs=10)

子类需要实现

  • __init__(): 所有状态变量都应该通过调用 self.add_variable() 在此方法中创建,例如:self.var = self.add_variable(...)
  • update_state(): 包含对状态变量的所有更新,例如:self.var.assign(...)
  • result(): 从状态变量计算并返回指标的标量值或标量值字典。

示例子类实现

class BinaryTruePositives(Metric):

    def __init__(self, name='binary_true_positives', **kwargs):
        super().__init__(name=name, **kwargs)
        self.true_positives = self.add_variable(
            shape=(),
            initializer='zeros',
            name='true_positives'
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = ops.cast(y_true, "bool")
        y_pred = ops.cast(y_pred, "bool")

        values = ops.logical_and(
            ops.equal(y_true, True), ops.equal(y_pred, True))
        values = ops.cast(values, self.dtype)
        if sample_weight is not None:
            sample_weight = ops.cast(sample_weight, self.dtype)
            sample_weight = ops.broadcast_to(
                sample_weight, ops.shape(values)
            )
            values = ops.multiply(values, sample_weight)
        self.true_positives.assign(self.true_positives + ops.sum(values))

    def result(self):
        return self.true_positives