作者: fchollet
创建日期: 2019/05/10
上次修改日期: 2023/11/22
描述: 演示“端点层”模式(自行管理损失的层)。
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import keras
import numpy as np
“端点层”可以访问模型的目标,并在 `call()` 中使用 `self.add_loss()` 和 `Metric.update_state()` 创建任意损失和指标。这使您可以定义不符合常规签名 `fn(y_true, y_pred, sample_weight=None)` 的损失和指标。
请注意,使用此模式可以为训练和评估设置单独的指标。
class LogisticEndpoint(keras.layers.Layer):
    """Binary logistic-regression endpoint layer that manages its own loss.

    When ``targets`` are given, the layer computes a binary cross-entropy
    loss on the raw logits, registers it with ``self.add_loss()`` and updates
    a binary-accuracy metric. It always returns per-sample probabilities, so
    the same layer works in both training and inference graphs.
    """

    def __init__(self, name=None):
        super().__init__(name=name)
        # `from_logits=True`: the loss is fed raw logits, not probabilities.
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
        self.accuracy_metric = keras.metrics.BinaryAccuracy(name="accuracy")

    def call(self, logits, targets=None, sample_weight=None):
        """Return sigmoid probabilities; add loss/metric when targets exist.

        Args:
            logits: raw (pre-activation) scores; presumably shape
                (batch, 1) as produced by a ``Dense(1)`` — confirm for
                other callers.
            targets: optional binary targets. When ``None`` (inference),
                no loss or metric is computed.
            sample_weight: optional per-sample weights forwarded to both the
                loss and the metric.

        Returns:
            Probabilities in (0, 1), same shape as ``logits``.
        """
        # Sigmoid, not softmax: with a single logit per sample, softmax over
        # the last axis always returns exactly 1.0.
        probs = tf.nn.sigmoid(logits)
        if targets is not None:
            # Compute the training-time loss value and add it
            # to the layer using `self.add_loss()`.
            loss = self.loss_fn(targets, logits, sample_weight)
            self.add_loss(loss)
            # Log the accuracy as a metric (we could log arbitrary metrics,
            # including different metrics for training and inference.)
            # BinaryAccuracy thresholds `y_pred` at 0.5, so it must receive
            # probabilities rather than raw logits.
            self.accuracy_metric.update_state(targets, probs, sample_weight)
        # Return the inference-time prediction tensor (for `.predict()`).
        return probs
# Functional model: targets and sample weights enter as extra model inputs
# so the endpoint layer can compute the loss inside the model itself.
inputs = keras.Input(shape=(764,), name="inputs")
logits = keras.layers.Dense(units=1)(inputs)
targets = keras.Input(shape=(1,), name="targets")
sample_weight = keras.Input(shape=(1,), name="sample_weight")
preds = LogisticEndpoint()(logits, targets, sample_weight)
model = keras.Model(inputs=[inputs, targets, sample_weight], outputs=preds)

# Random demo data; each key matches the `name=` of one Input above.
data = dict(
    inputs=np.random.random((1000, 764)),
    targets=np.random.random((1000, 1)),
    sample_weight=np.random.random((1000, 1)),
)

# No `loss=` here: the only loss is the one the endpoint layer adds.
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3))
model.fit(data, epochs=2)
Epoch 1/2
27/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.3664
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1700705222.380735 3351467 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
32/32 ━━━━━━━━━━━━━━━━━━━━ 2s 31ms/step - loss: 0.3663
Epoch 2/2
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.3627
<keras.src.callbacks.history.History at 0x7f13401b1e10>
推理时,只需不要在模型中包含 `targets`。权重保持不变。
# Inference model: same layer stack, but without targets/sample_weight inputs,
# so the endpoint layer only produces predictions.
inputs = keras.Input(shape=(764,), name="inputs")
logits = keras.layers.Dense(units=1)(inputs)
preds = LogisticEndpoint()(logits)
inference_model = keras.Model(inputs=inputs, outputs=preds)
# Copy the trained weights from the training-time `model` built above.
inference_model.set_weights(model.get_weights())
preds = inference_model.predict(np.random.random(size=(1000, 764)))
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
class LogReg(keras.Model):
    """Subclassed logistic-regression model built around `LogisticEndpoint`.

    `call()` takes a single dict so that `model.fit(data)` can pass features,
    targets and sample weights together.
    """

    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(1)
        self.logistic_endpoint = LogisticEndpoint()

    def call(self, inputs):
        """Return predictions, adding the training loss when targets exist.

        Args:
            inputs: dict with the key "inputs" (features) and, optionally,
                "targets" and "sample_weight". If "targets" is absent
                (e.g. for `predict()`), the endpoint layer skips the loss.

        Returns:
            The endpoint layer's prediction tensor.
        """
        # Note that all inputs should be in the first argument
        # since we want to be able to call `model.fit(inputs)`.
        logits = self.dense(inputs["inputs"])
        # `.get()` instead of `[...]`: lets the same model run inference on a
        # dict without "targets"/"sample_weight" rather than raising KeyError.
        return self.logistic_endpoint(
            logits=logits,
            targets=inputs.get("targets"),
            sample_weight=inputs.get("sample_weight"),
        )
# Subclassed variant: features, targets and sample weights all travel in one
# dict that `LogReg.call()` unpacks by key.
model = LogReg()
data = dict(
    inputs=np.random.random((1000, 764)),
    targets=np.random.random((1000, 1)),
    sample_weight=np.random.random((1000, 1)),
)
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3))
model.fit(data, epochs=2)
Epoch 1/2
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 9ms/step - loss: 0.3529
Epoch 2/2
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.3509
<keras.src.callbacks.history.History at 0x7f132c1d1450>