Author: Sayak Paul
Date created: 2021/10/12
Last modified: 2021/10/12
Description: An all-convolutional network applied to patches of images.
Vision Transformers (ViT; Dosovitskiy et al.) extract small patches from the input images, linearly project them, and then apply the Transformer (Vaswani et al.) blocks. The application of ViTs to image recognition tasks is quickly becoming a promising area of research, because ViTs eliminate the need for strong inductive biases (such as convolutions) to model locality. This presents them as a general computation primitive capable of learning just from the training data with as few inductive priors as possible. ViTs yield great downstream performance when trained with proper regularization, data augmentation, and relatively large-scale datasets.
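To make the patch idea concrete, here is a minimal sketch (not part of the original example) of the ViT-style "split into patches, then linearly project" step. The `patch_size` and `projection_dim` values are hypothetical; a non-overlapping Conv2D whose kernel size and stride both equal the patch size is equivalent to extracting patches and projecting each one with a shared linear layer, which is also how the ConvMixer stem below is built.

import keras
from keras import layers

patch_size, projection_dim = 4, 64  # hypothetical values, for illustration only
images = keras.random.normal((1, 32, 32, 3))  # a dummy batch with one RGB image
# A non-overlapping convolution == patch extraction + shared linear projection.
patch_embeddings = layers.Conv2D(
    filters=projection_dim, kernel_size=patch_size, strides=patch_size
)(images)
print(patch_embeddings.shape)  # (1, 8, 8, 64): an 8 x 8 grid of projected patches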
In the Patches Are All You Need paper (note: at the time of writing, it was a submission to the ICLR 2022 conference), the authors extend the idea of using patches to train an all-convolutional network and demonstrate competitive results. Their architecture, namely **ConvMixer**, uses recipes from recent isotropic architectures like ViT and MLP-Mixer (Tolstikhin et al.), such as using the same depth and resolution across different layers in the network, residual connections, and so on.
In this example, we will implement the ConvMixer model and demonstrate its performance on the CIFAR-10 dataset.
import keras
from keras import layers
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
To keep run time short, we will train the model for only 10 epochs. To focus on the core ideas of ConvMixer, we will not use other training-specific elements like RandAugment (Cubuk et al.). If you are interested in learning more about those details, please refer to the original paper.
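If you do want stronger augmentation, the sketch below shows where it would plug in. These are standard Keras preprocessing layers used purely as a stand-in; they are not the RandAugment policy from the paper.

# Hypothetical extra augmentation layers (NOT RandAugment); they could be
# appended to the `augmentation_layers` list defined further below.
extra_augmentation_layers = [
    keras.layers.RandomRotation(factor=0.05),
    keras.layers.RandomContrast(factor=0.2),
]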
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 128
num_epochs = 10
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
val_split = 0.1
val_indices = int(len(x_train) * val_split)
new_x_train, new_y_train = x_train[val_indices:], y_train[val_indices:]
x_val, y_val = x_train[:val_indices], y_train[:val_indices]
print(f"Training data samples: {len(new_x_train)}")
print(f"Validation data samples: {len(x_val)}")
print(f"Test data samples: {len(x_test)}")
Training data samples: 45000
Validation data samples: 5000
Test data samples: 10000
`tf.data.Dataset` objects

Our data augmentation pipeline is different from what the authors used for the CIFAR-10 dataset, which is fine for the purpose of this example. Note that it is possible to use the **TF APIs for data I/O and preprocessing** with other backends (jax, torch), since TF is a feature-complete framework when it comes to data preprocessing.
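For instance, a minimal sketch of selecting a different backend is shown below; note that the `KERAS_BACKEND` environment variable must be set before `keras` is imported, and the `tf.data` pipeline below stays unchanged either way.

import os

# Select the backend before importing Keras; "tensorflow" is the default.
os.environ["KERAS_BACKEND"] = "jax"  # or "torch" / "tensorflow"

import keras  # noqa: E402 (imported only after the backend is selected)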
image_size = 32
auto = tf.data.AUTOTUNE
augmentation_layers = [
keras.layers.RandomCrop(image_size, image_size),
keras.layers.RandomFlip("horizontal"),
]
def augment_images(images):
for layer in augmentation_layers:
images = layer(images, training=True)
return images
def make_datasets(images, labels, is_train=False):
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
if is_train:
dataset = dataset.shuffle(batch_size * 10)
dataset = dataset.batch(batch_size)
if is_train:
dataset = dataset.map(
lambda x, y: (augment_images(x), y), num_parallel_calls=auto
)
return dataset.prefetch(auto)
train_dataset = make_datasets(new_x_train, new_y_train, is_train=True)
val_dataset = make_datasets(x_val, y_val)
test_dataset = make_datasets(x_test, y_test)
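As an optional sanity check (not part of the original example), we can pull one batch from the training pipeline and confirm the expected shapes:

# Inspect a single batch coming out of the tf.data pipeline.
sample_images, sample_labels = next(iter(train_dataset))
print(sample_images.shape, sample_labels.shape)  # (128, 32, 32, 3) (128, 1)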
The figure below (taken from the original paper) depicts the ConvMixer model:
ConvMixer is very similar to the MLP-Mixer model, with the following key differences: instead of fully-connected layers, it uses standard convolution layers, and instead of LayerNorm (which is typical for ViTs and MLP-Mixers), it uses BatchNorm.

Two types of convolution layers are used in ConvMixer. **(1)**: Depthwise convolutions, for mixing the spatial locations of the images; **(2)**: Pointwise convolutions (which follow the depthwise convolutions), for mixing channel-wise information across the patches. Another key point is the use of *larger kernel sizes* to allow a larger receptive field.
def activation_block(x):
x = layers.Activation("gelu")(x)
return layers.BatchNormalization()(x)
def conv_stem(x, filters: int, patch_size: int):
x = layers.Conv2D(filters, kernel_size=patch_size, strides=patch_size)(x)
return activation_block(x)
def conv_mixer_block(x, filters: int, kernel_size: int):
# Depthwise convolution.
x0 = x
x = layers.DepthwiseConv2D(kernel_size=kernel_size, padding="same")(x)
x = layers.Add()([activation_block(x), x0]) # Residual.
# Pointwise convolution.
x = layers.Conv2D(filters, kernel_size=1)(x)
x = activation_block(x)
return x
def get_conv_mixer_256_8(
image_size=32, filters=256, depth=8, kernel_size=5, patch_size=2, num_classes=10
):
"""ConvMixer-256/8: https://openreview.net/pdf?id=TVHS5Y4dNvM.
The hyperparameter values are taken from the paper.
"""
inputs = keras.Input((image_size, image_size, 3))
x = layers.Rescaling(scale=1.0 / 255)(inputs)
# Extract patch embeddings.
x = conv_stem(x, filters, patch_size)
# ConvMixer blocks.
for _ in range(depth):
x = conv_mixer_block(x, filters, kernel_size)
# Classification block.
x = layers.GlobalAvgPool2D()(x)
outputs = layers.Dense(num_classes, activation="softmax")(x)
return keras.Model(inputs, outputs)
The model used in this experiment is termed **ConvMixer-256/8**, where 256 denotes the number of channels and 8 denotes the depth. The resulting model has only 0.8 million parameters.
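As a quick check (not part of the original example), the ~0.8M figure can be verified directly from the model:

# Confirm the parameter count of ConvMixer-256/8 quoted above.
print(f"Total parameters: {get_conv_mixer_256_8().count_params():,}")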
# Code reference:
# https://keras.org.cn/examples/vision/image_classification_with_vision_transformer/.
def run_experiment(model):
optimizer = keras.optimizers.AdamW(
learning_rate=learning_rate, weight_decay=weight_decay
)
model.compile(
optimizer=optimizer,
loss="sparse_categorical_crossentropy",
metrics=["accuracy"],
)
checkpoint_filepath = "/tmp/checkpoint.keras"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
checkpoint_filepath,
monitor="val_accuracy",
save_best_only=True,
save_weights_only=False,
)
history = model.fit(
train_dataset,
validation_data=val_dataset,
epochs=num_epochs,
callbacks=[checkpoint_callback],
)
model.load_weights(checkpoint_filepath)
_, accuracy = model.evaluate(test_dataset)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
return history, model
conv_mixer_model = get_conv_mixer_256_8()
history, conv_mixer_model = run_experiment(conv_mixer_model)
Epoch 1/10
352/352 ━━━━━━━━━━━━━━━━━━━━ 46s 103ms/step - accuracy: 0.4594 - loss: 1.4780 - val_accuracy: 0.1536 - val_loss: 4.0766
Epoch 2/10
352/352 ━━━━━━━━━━━━━━━━━━━━ 14s 39ms/step - accuracy: 0.6996 - loss: 0.8479 - val_accuracy: 0.7240 - val_loss: 0.7926
Epoch 3/10
352/352 ━━━━━━━━━━━━━━━━━━━━ 14s 39ms/step - accuracy: 0.7823 - loss: 0.6287 - val_accuracy: 0.7800 - val_loss: 0.6532
Epoch 4/10
352/352 ━━━━━━━━━━━━━━━━━━━━ 14s 39ms/step - accuracy: 0.8264 - loss: 0.5003 - val_accuracy: 0.8074 - val_loss: 0.5895
Epoch 5/10
352/352 ━━━━━━━━━━━━━━━━━━━━ 21s 60ms/step - accuracy: 0.8605 - loss: 0.4092 - val_accuracy: 0.7996 - val_loss: 0.6037
Epoch 6/10
352/352 ━━━━━━━━━━━━━━━━━━━━ 13s 38ms/step - accuracy: 0.8788 - loss: 0.3527 - val_accuracy: 0.8072 - val_loss: 0.6162
Epoch 7/10
352/352 ━━━━━━━━━━━━━━━━━━━━ 21s 61ms/step - accuracy: 0.8972 - loss: 0.2984 - val_accuracy: 0.8226 - val_loss: 0.5604
Epoch 8/10
352/352 ━━━━━━━━━━━━━━━━━━━━ 21s 61ms/step - accuracy: 0.9087 - loss: 0.2608 - val_accuracy: 0.8310 - val_loss: 0.5303
Epoch 9/10
352/352 ━━━━━━━━━━━━━━━━━━━━ 14s 39ms/step - accuracy: 0.9176 - loss: 0.2302 - val_accuracy: 0.8458 - val_loss: 0.5051
Epoch 10/10
352/352 ━━━━━━━━━━━━━━━━━━━━ 14s 38ms/step - accuracy: 0.9336 - loss: 0.1918 - val_accuracy: 0.8316 - val_loss: 0.5848
79/79 ━━━━━━━━━━━━━━━━━━━━ 3s 32ms/step - accuracy: 0.8371 - loss: 0.5501
Test accuracy: 83.69%
The gap in training and validation performance can be mitigated by using additional regularization techniques. Nevertheless, being able to reach ~83% accuracy within 10 epochs with 0.8 million parameters is a strong result.
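As one hedged example (not from the paper), dropout could be added before the classification head, as sketched below; stronger augmentation or a higher weight decay would be other options.

# A sketch of a lightly regularized variant: same backbone as above, plus
# dropout before the classification head. The dropout_rate value is hypothetical.
def get_conv_mixer_256_8_with_dropout(
    image_size=32,
    filters=256,
    depth=8,
    kernel_size=5,
    patch_size=2,
    num_classes=10,
    dropout_rate=0.2,
):
    inputs = keras.Input((image_size, image_size, 3))
    x = layers.Rescaling(scale=1.0 / 255)(inputs)
    x = conv_stem(x, filters, patch_size)
    for _ in range(depth):
        x = conv_mixer_block(x, filters, kernel_size)
    x = layers.GlobalAvgPool2D()(x)
    x = layers.Dropout(dropout_rate)(x)  # extra regularization
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return keras.Model(inputs, outputs)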
We can visualize the patch embeddings and the learned convolution filters. Recall that each patch embedding and intermediate feature map have the same number of channels (256 in this case). This makes our visualization utility easier to implement.
# Code reference: https://bit.ly/3awIRbP.
def visualization_plot(weights, idx=1):
    # First, apply min-max normalization to the
    # given weights to avoid isotropic scaling.
p_min, p_max = weights.min(), weights.max()
weights = (weights - p_min) / (p_max - p_min)
# Visualize all the filters.
num_filters = 256
plt.figure(figsize=(8, 8))
for i in range(num_filters):
current_weight = weights[:, :, :, i]
if current_weight.shape[-1] == 1:
current_weight = current_weight.squeeze()
ax = plt.subplot(16, 16, idx)
ax.set_xticks([])
ax.set_yticks([])
plt.imshow(current_weight)
idx += 1
# We first visualize the learned patch embeddings.
patch_embeddings = conv_mixer_model.layers[2].get_weights()[0]
visualization_plot(patch_embeddings)
Even though we did not train the network to convergence, we can notice that different patches show different patterns. Some share similarity with others, while some are very different. These visualizations are more salient with larger image sizes.
Similarly, we can visualize the raw convolution kernels. This can help us understand the patterns to which a given kernel is receptive.
# First, print the indices of the convolution layers that are not
# pointwise convolutions.
for i, layer in enumerate(conv_mixer_model.layers):
if isinstance(layer, layers.DepthwiseConv2D):
if layer.get_config()["kernel_size"] == (5, 5):
print(i, layer)
idx = 26 # Taking a kernel from the middle of the network.
kernel = conv_mixer_model.layers[idx].get_weights()[0]
kernel = np.expand_dims(kernel.squeeze(), axis=2)
visualization_plot(kernel)
5 <DepthwiseConv2D name=depthwise_conv2d, built=True>
12 <DepthwiseConv2D name=depthwise_conv2d_1, built=True>
19 <DepthwiseConv2D name=depthwise_conv2d_2, built=True>
26 <DepthwiseConv2D name=depthwise_conv2d_3, built=True>
33 <DepthwiseConv2D name=depthwise_conv2d_4, built=True>
40 <DepthwiseConv2D name=depthwise_conv2d_5, built=True>
47 <DepthwiseConv2D name=depthwise_conv2d_6, built=True>
54 <DepthwiseConv2D name=depthwise_conv2d_7, built=True>
We see that different filters in the kernel have different locality spans, and this pattern is likely to evolve with more training.
There has been a recent trend of fusing convolutions with other data-agnostic operations like self-attention. The following works are along this line of research.