作者: fchollet
创建日期: 2016/01/06
上次修改日期: 2023/11/20
描述: 自定义层创建的演示。
此示例以 Antirectifier 层为例,展示如何创建自定义层。Antirectifier 层最初于 2016 年 1 月作为 Keras 示例脚本提出,是 ReLU 的一种替代方案。它不是将输入的负数部分置零,而是将负数部分和正数部分分开,并返回两者绝对值的拼接结果。这避免了信息丢失,但代价是增加了维度。为了解决维度增加的问题,我们将特征线性组合,映射回与原始大小相同的空间。
import keras
from keras import layers
from keras import ops
要实现自定义层:
- 通过 `__init__` 或 `build()` 中的 `add_weight()` 创建状态变量。同样,你也可以创建子层。
- 实现 `call()` 方法,接受层的输入张量并返回输出张量。
- (可选)你还可以通过实现 `get_config()` 来启用序列化,该方法返回一个配置字典。

另请参阅指南《通过子类化创建新层和模型》。
class Antirectifier(layers.Layer):
    """Custom layer that splits inputs into positive and negative parts.

    The layer centers its input along the last axis, applies ReLU to the
    centered values and to their negation, concatenates both results along
    the last axis (doubling its size), and multiplies by a trainable kernel
    of shape `(2 * dim, dim)` so the output has the same last-axis size as
    the input.

    Args:
        initializer: Identifier or instance passed to
            `keras.initializers.get()`; used to initialize the kernel.
            Defaults to `"he_normal"`.
        **kwargs: Forwarded unchanged to the base `Layer` constructor.
    """

    def __init__(self, initializer="he_normal", **kwargs):
        super().__init__(**kwargs)
        self.initializer = keras.initializers.get(initializer)

    def build(self, input_shape):
        """Create the mixing kernel once the input's last dimension is known."""
        output_dim = input_shape[-1]
        # The concatenated features have twice the input width; the kernel
        # projects them back to `output_dim`.
        self.kernel = self.add_weight(
            shape=(output_dim * 2, output_dim),
            initializer=self.initializer,
            name="kernel",
            trainable=True,
        )

    def call(self, inputs):
        """Return the centered, split, and re-mixed version of `inputs`."""
        # Bind the centered values to a new name instead of using `-=`:
        # augmented assignment may modify the caller's tensor in place on
        # backends whose tensor type implements in-place subtraction.
        centered = inputs - ops.mean(inputs, axis=-1, keepdims=True)
        pos = ops.relu(centered)
        neg = ops.relu(-centered)
        concatenated = ops.concatenate([pos, neg], axis=-1)
        return ops.matmul(concatenated, self.kernel)

    def get_config(self):
        """Return the base config extended with the serialized initializer."""
        # Implement get_config to enable serialization. This is optional.
        base_config = super().get_config()
        config = {"initializer": keras.initializers.serialize(self.initializer)}
        return {**base_config, **config}
# Training parameters
batch_size = 128
num_classes = 10
epochs = 20

# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Flatten every sample into a 784-element vector and rescale by 1/255.
x_train = x_train.reshape(-1, 784)
x_test = x_test.reshape(-1, 784)
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")
x_train /= 255
x_test /= 255
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# Build the model
model = keras.Sequential(
    [
        keras.Input(shape=(784,)),
        layers.Dense(256),
        Antirectifier(),
        layers.Dense(256),
        Antirectifier(),
        layers.Dropout(0.5),
        # One output unit per class; driven by `num_classes` rather than a
        # repeated literal so the two cannot drift apart.
        layers.Dense(num_classes),
    ]
)

# Compile the model
# The last layer has no activation, so the loss is told to expect logits.
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.RMSprop(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train the model
model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.15,
)

# Test the model
model.evaluate(x_test, y_test)
60000 train samples
10000 test samples
Epoch 1/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: 0.6226 - sparse_categorical_accuracy: 0.8146 - val_loss: 0.4256 - val_sparse_categorical_accuracy: 0.8808
Epoch 2/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.1887 - sparse_categorical_accuracy: 0.9455 - val_loss: 0.1556 - val_sparse_categorical_accuracy: 0.9588
Epoch 3/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.1406 - sparse_categorical_accuracy: 0.9608 - val_loss: 0.1531 - val_sparse_categorical_accuracy: 0.9611
Epoch 4/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.1084 - sparse_categorical_accuracy: 0.9691 - val_loss: 0.1178 - val_sparse_categorical_accuracy: 0.9731
Epoch 5/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0995 - sparse_categorical_accuracy: 0.9738 - val_loss: 0.2207 - val_sparse_categorical_accuracy: 0.9526
Epoch 6/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0831 - sparse_categorical_accuracy: 0.9769 - val_loss: 0.2092 - val_sparse_categorical_accuracy: 0.9533
Epoch 7/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0736 - sparse_categorical_accuracy: 0.9807 - val_loss: 0.1129 - val_sparse_categorical_accuracy: 0.9749
Epoch 8/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0653 - sparse_categorical_accuracy: 0.9827 - val_loss: 0.1000 - val_sparse_categorical_accuracy: 0.9791
Epoch 9/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0590 - sparse_categorical_accuracy: 0.9833 - val_loss: 0.1320 - val_sparse_categorical_accuracy: 0.9750
Epoch 10/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0587 - sparse_categorical_accuracy: 0.9854 - val_loss: 0.1439 - val_sparse_categorical_accuracy: 0.9747
Epoch 11/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0622 - sparse_categorical_accuracy: 0.9853 - val_loss: 0.1473 - val_sparse_categorical_accuracy: 0.9753
Epoch 12/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0554 - sparse_categorical_accuracy: 0.9869 - val_loss: 0.1529 - val_sparse_categorical_accuracy: 0.9757
Epoch 13/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0507 - sparse_categorical_accuracy: 0.9884 - val_loss: 0.1452 - val_sparse_categorical_accuracy: 0.9783
Epoch 14/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0468 - sparse_categorical_accuracy: 0.9889 - val_loss: 0.1435 - val_sparse_categorical_accuracy: 0.9796
Epoch 15/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0478 - sparse_categorical_accuracy: 0.9892 - val_loss: 0.1580 - val_sparse_categorical_accuracy: 0.9770
Epoch 16/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0492 - sparse_categorical_accuracy: 0.9888 - val_loss: 0.1957 - val_sparse_categorical_accuracy: 0.9753
Epoch 17/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0478 - sparse_categorical_accuracy: 0.9896 - val_loss: 0.1865 - val_sparse_categorical_accuracy: 0.9779
Epoch 18/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0478 - sparse_categorical_accuracy: 0.9893 - val_loss: 0.2107 - val_sparse_categorical_accuracy: 0.9747
Epoch 19/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0494 - sparse_categorical_accuracy: 0.9894 - val_loss: 0.2306 - val_sparse_categorical_accuracy: 0.9734
Epoch 20/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0473 - sparse_categorical_accuracy: 0.9910 - val_loss: 0.2201 - val_sparse_categorical_accuracy: 0.9731
313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 802us/step - loss: 0.2086 - sparse_categorical_accuracy: 0.9710
[0.19070196151733398, 0.9740999937057495]