
Simple custom layer example: Antirectifier

Author: fchollet
Date created: 2016/01/06
Last modified: 2023/11/20
Description: Demonstration of custom layer creation.

ⓘ This example uses Keras 3



Introduction

This example shows how to create custom layers, using the Antirectifier layer (originally proposed as a Keras example script in January 2016) as an alternative to ReLU. Instead of zeroing out the negative part of the input, it splits the negative and positive parts and returns the concatenation of the absolute value of both. This avoids loss of information, at the cost of an increase in dimensionality. To fix the dimensionality increase, we linearly combine the features back to a space of the original size.


Setup

import keras
from keras import layers
from keras import ops
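
As a quick illustration of the operation described in the introduction, here is the core antirectifier transform applied to a toy input (a minimal sketch; the values are made up for illustration):

x = ops.convert_to_tensor([[1.0, -2.0, 3.0, -4.0]])  # toy batch of one sample
x = x - ops.mean(x, axis=-1, keepdims=True)  # center each sample
pos = ops.relu(x)  # positive part
neg = ops.relu(-x)  # absolute value of the negative part
doubled = ops.concatenate([pos, neg], axis=-1)
print(doubled.shape)  # (1, 8): the feature dimension is doubled

The layer below performs exactly this transform, then multiplies by a trainable kernel of shape (2 * dim, dim) to project the features back to their original size.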

The Antirectifier layer

To implement a custom layer:

  • Create state variables via add_weight() in __init__() or build(). Similarly, you can also create sublayers.
  • Implement the call() method, taking the layer's input tensor(s) and returning the output tensor(s).
  • Optionally, you can also enable serialization by implementing get_config(), which returns a configuration dictionary.

See also the guide Making new layers and models via subclassing.

class Antirectifier(layers.Layer):
    def __init__(self, initializer="he_normal", **kwargs):
        super().__init__(**kwargs)
        self.initializer = keras.initializers.get(initializer)

    def build(self, input_shape):
        output_dim = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(output_dim * 2, output_dim),
            initializer=self.initializer,
            name="kernel",
            trainable=True,
        )

    def call(self, inputs):
        inputs -= ops.mean(inputs, axis=-1, keepdims=True)
        pos = ops.relu(inputs)
        neg = ops.relu(-inputs)
        concatenated = ops.concatenate([pos, neg], axis=-1)
        mixed = ops.matmul(concatenated, self.kernel)
        return mixed

    def get_config(self):
        # Implement get_config to enable serialization. This is optional.
        base_config = super().get_config()
        config = {"initializer": keras.initializers.serialize(self.initializer)}
        return dict(list(base_config.items()) + list(config.items()))
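
A quick sanity check on random data (a minimal sketch; the shapes are illustrative): the layer doubles the feature dimension internally, then mixes it back down, so the output shape matches the input shape.

layer = Antirectifier()
x = keras.random.normal((2, 8))  # illustrative batch of 2 samples with 8 features
y = layer(x)
print(y.shape)  # (2, 8)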

Let's test-drive it on MNIST

# Training parameters
batch_size = 128
num_classes = 10
epochs = 20

# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

x_train = x_train.reshape(-1, 784)
x_test = x_test.reshape(-1, 784)
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")
x_train /= 255
x_test /= 255
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# Build the model
model = keras.Sequential(
    [
        keras.Input(shape=(784,)),
        layers.Dense(256),
        Antirectifier(),
        layers.Dense(256),
        Antirectifier(),
        layers.Dropout(0.5),
        layers.Dense(10),
    ]
)

# Compile the model
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.RMSprop(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train the model
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.15)

# Test the model
model.evaluate(x_test, y_test)
60000 train samples
10000 test samples
Epoch 1/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - loss: 0.6226 - sparse_categorical_accuracy: 0.8146 - val_loss: 0.4256 - val_sparse_categorical_accuracy: 0.8808
Epoch 2/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.1887 - sparse_categorical_accuracy: 0.9455 - val_loss: 0.1556 - val_sparse_categorical_accuracy: 0.9588
Epoch 3/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.1406 - sparse_categorical_accuracy: 0.9608 - val_loss: 0.1531 - val_sparse_categorical_accuracy: 0.9611
Epoch 4/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.1084 - sparse_categorical_accuracy: 0.9691 - val_loss: 0.1178 - val_sparse_categorical_accuracy: 0.9731
Epoch 5/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0995 - sparse_categorical_accuracy: 0.9738 - val_loss: 0.2207 - val_sparse_categorical_accuracy: 0.9526
Epoch 6/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0831 - sparse_categorical_accuracy: 0.9769 - val_loss: 0.2092 - val_sparse_categorical_accuracy: 0.9533
Epoch 7/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0736 - sparse_categorical_accuracy: 0.9807 - val_loss: 0.1129 - val_sparse_categorical_accuracy: 0.9749
Epoch 8/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0653 - sparse_categorical_accuracy: 0.9827 - val_loss: 0.1000 - val_sparse_categorical_accuracy: 0.9791
Epoch 9/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0590 - sparse_categorical_accuracy: 0.9833 - val_loss: 0.1320 - val_sparse_categorical_accuracy: 0.9750
Epoch 10/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0587 - sparse_categorical_accuracy: 0.9854 - val_loss: 0.1439 - val_sparse_categorical_accuracy: 0.9747
Epoch 11/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0622 - sparse_categorical_accuracy: 0.9853 - val_loss: 0.1473 - val_sparse_categorical_accuracy: 0.9753
Epoch 12/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0554 - sparse_categorical_accuracy: 0.9869 - val_loss: 0.1529 - val_sparse_categorical_accuracy: 0.9757
Epoch 13/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0507 - sparse_categorical_accuracy: 0.9884 - val_loss: 0.1452 - val_sparse_categorical_accuracy: 0.9783
Epoch 14/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0468 - sparse_categorical_accuracy: 0.9889 - val_loss: 0.1435 - val_sparse_categorical_accuracy: 0.9796
Epoch 15/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0478 - sparse_categorical_accuracy: 0.9892 - val_loss: 0.1580 - val_sparse_categorical_accuracy: 0.9770
Epoch 16/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0492 - sparse_categorical_accuracy: 0.9888 - val_loss: 0.1957 - val_sparse_categorical_accuracy: 0.9753
Epoch 17/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0478 - sparse_categorical_accuracy: 0.9896 - val_loss: 0.1865 - val_sparse_categorical_accuracy: 0.9779
Epoch 18/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0478 - sparse_categorical_accuracy: 0.9893 - val_loss: 0.2107 - val_sparse_categorical_accuracy: 0.9747
Epoch 19/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0494 - sparse_categorical_accuracy: 0.9894 - val_loss: 0.2306 - val_sparse_categorical_accuracy: 0.9734
Epoch 20/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.0473 - sparse_categorical_accuracy: 0.9910 - val_loss: 0.2201 - val_sparse_categorical_accuracy: 0.9731
 313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 802us/step - loss: 0.2086 - sparse_categorical_accuracy: 0.9710

[0.19070196151733398, 0.9740999937057495]
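
Because Antirectifier implements get_config(), a model that uses it can be saved and reloaded. A minimal sketch, assuming the model trained above; the file name is illustrative, and the custom class is passed via custom_objects at load time:

model.save("antirectifier_mnist.keras")  # illustrative file name
restored = keras.models.load_model(
    "antirectifier_mnist.keras", custom_objects={"Antirectifier": Antirectifier}
)
restored.evaluate(x_test, y_test)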