Author: Aritra Roy Gosthipaty, Ritwik Raha
Date created: 2023/01/25
Last modified: 2023/02/15
Description: Image classification with Focal Modulation Networks.
This tutorial aims to provide a comprehensive guide to the implementation of Focal Modulation Networks, as presented in Yang et al.
This tutorial will provide a formal, minimalistic approach to implementing Focal Modulation Networks and explore its potential applications in the field of deep learning.
The problem statement
The Transformer architecture (Vaswani et al.), which has become the de facto standard in most Natural Language Processing tasks, has also been applied to the field of Computer Vision, e.g. Vision Transformers (Dosovitskiy et al.).
In Transformers, self-attention (SA) is arguably the key to its success, enabling input-dependent global interactions, in contrast with convolution operations which constrain interactions to a local region with a shared kernel.
The Attention module is mathematically written as shown in Equation 1.
Equation 1: The mathematical equation of attention (Source: Aritra and Ritwik)
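For reference, Equation 1 is the standard scaled dot-product attention introduced by Vaswani et al., which in plain notation reads:
Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V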
where Q is the query, K is the key, V is the value, and d_k is the dimension of the key. With self-attention, the query, key, and value all come from the input sequence. Let us rewrite the attention equation for self-attention as shown in Equation 2.
Equation 2: The mathematical equation of self-attention (Source: Aritra and Ritwik)
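As a plain-notation sketch of Equation 2 (the projection matrices W_Q, W_K, W_V are named here only for the sketch), self-attention derives queries, keys, and values from the same input X:
SA(X) = softmax((X W_Q)(X W_K)^T / sqrt(d_k)) (X W_V)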
Looking at the equation of self-attention, we can see that its cost is quadratic in the number of tokens: as the number of tokens grows, so does the computation time (and cost). To mitigate this problem and make Transformers more interpretable, Yang et al. have tried to replace the Self-Attention module with better components.
The solution
Yang et al. introduce the Focal Modulation layer to serve as a seamless replacement for the Self-Attention layer. The layer boasts high interpretability, making it a valuable tool for Deep Learning practitioners.
In this tutorial, we will delve into the practical application of this layer by training the entire model on the CIFAR-10 dataset and visually interpreting the layer's performance.
Note: We try to align our implementation with the official implementation.
In this tutorial, we use tensorflow version 2.11.0.
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.optimizers.experimental import AdamW
from typing import Optional, Tuple, List
from matplotlib import pyplot as plt
from random import randint
# Set seed for reproducibility.
tf.keras.utils.set_random_seed(42)
We do not have a strong rationale behind choosing these hyperparameters. Please feel free to change the configuration and train the model.
# DATA
TRAIN_SLICE = 40000
BUFFER_SIZE = 2048
BATCH_SIZE = 1024
AUTO = tf.data.AUTOTUNE
INPUT_SHAPE = (32, 32, 3)
IMAGE_SIZE = 48
NUM_CLASSES = 10
# OPTIMIZER
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
# TRAINING
EPOCHS = 25
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
(x_train[:TRAIN_SLICE], y_train[:TRAIN_SLICE]),
(x_train[TRAIN_SLICE:], y_train[TRAIN_SLICE:]),
)
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 30s 0us/step
We use the keras.Sequential API to compose all the individual augmentation steps into one pipeline.
# Build the `train` augmentation pipeline.
train_aug = keras.Sequential(
[
layers.Rescaling(1 / 255.0),
layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20),
layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),
layers.RandomFlip("horizontal"),
],
name="train_data_augmentation",
)
# Build the `val` and `test` data pipeline.
test_aug = keras.Sequential(
[
layers.Rescaling(1 / 255.0),
layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
],
name="test_data_augmentation",
)
The tf.data pipeline
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = (
train_ds.map(
lambda image, label: (train_aug(image), label), num_parallel_calls=AUTO
)
.shuffle(BUFFER_SIZE)
.batch(BATCH_SIZE)
.prefetch(AUTO)
)
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = (
val_ds.map(lambda image, label: (test_aug(image), label), num_parallel_calls=AUTO)
.batch(BATCH_SIZE)
.prefetch(AUTO)
)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = (
test_ds.map(lambda image, label: (test_aug(image), label), num_parallel_calls=AUTO)
.batch(BATCH_SIZE)
.prefetch(AUTO)
)
WARNING:tensorflow:From /usr/local/lib/python3.8/dist-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
We pause here to take a quick look at the architecture of the Focal Modulation Network. Figure 1 shows how every individual layer is compiled into a single model. This gives us a bird's eye view of the entire architecture.
Figure 1: A diagram of the Focal Modulation model (Source: Aritra and Ritwik)
In the upcoming sections, we will dive into the details of each layer, following the order in which they appear in the architecture.
To better understand the architecture in a format we are well versed with, let us see how the Focal Modulation Network would look when drawn like a Transformer architecture.
Figure 2 shows the encoder layer of a traditional Transformer architecture where Self-Attention is replaced with the Focal Modulation layer.
The blue blocks represent the Focal Modulation block. A stack of these blocks builds a single Basic Layer. The green blocks represent the Focal Modulation layer.
Figure 2: The Entire Architecture (Source: Aritra and Ritwik)
The Patch Embedding layer is used to patchify the input images and project them into a latent space. This layer is also used as the down-sampling layer in the architecture.
class PatchEmbed(layers.Layer):
"""Image patch embedding layer, also acts as the down-sampling layer.
Args:
image_size (Tuple[int]): Input image resolution.
patch_size (Tuple[int]): Patch spatial resolution.
embed_dim (int): Embedding dimension.
"""
def __init__(
self,
image_size: Tuple[int] = (224, 224),
patch_size: Tuple[int] = (4, 4),
embed_dim: int = 96,
**kwargs,
):
super().__init__(**kwargs)
patch_resolution = [
image_size[0] // patch_size[0],
image_size[1] // patch_size[1],
]
self.image_size = image_size
self.patch_size = patch_size
self.embed_dim = embed_dim
self.patch_resolution = patch_resolution
self.num_patches = patch_resolution[0] * patch_resolution[1]
self.proj = layers.Conv2D(
filters=embed_dim, kernel_size=patch_size, strides=patch_size
)
self.flatten = layers.Reshape(target_shape=(-1, embed_dim))
self.norm = keras.layers.LayerNormalization(epsilon=1e-7)
def call(self, x: tf.Tensor) -> Tuple[tf.Tensor, int, int, int]:
"""Patchifies the image and converts into tokens.
Args:
x: Tensor of shape (B, H, W, C)
Returns:
A tuple of the processed tensor, height of the projected
feature map, width of the projected feature map, number
of channels of the projected feature map.
"""
# Project the inputs.
x = self.proj(x)
# Obtain the shape from the projected tensor.
height = tf.shape(x)[1]
width = tf.shape(x)[2]
channels = tf.shape(x)[3]
# B, H, W, C -> B, H*W, C
x = self.norm(self.flatten(x))
return x, height, width, channels
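Purely as an illustrative sanity check (not part of the original tutorial; the variable names below are arbitrary), we can pass a random 48 x 48 image through the layer and confirm the token shape:
# Illustrative shape check for the PatchEmbed layer defined above.
dummy_images = tf.random.normal((1, 48, 48, 3))
patch_embed = PatchEmbed(image_size=(48, 48), patch_size=(4, 4), embed_dim=96)
tokens, height, width, channels = patch_embed(dummy_images)
print(tokens.shape)  # (1, 144, 96): 12 x 12 patches, each projected to 96 dimensions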
The Focal Modulation block can be considered as a single Transformer block where the Self-Attention (SA) module is replaced with the Focal Modulation module, as we saw in Figure 2.
Let us recall how a Focal Modulation block is supposed to look like with the help of Figure 3.
Figure 3: The isolated view of the Focal Modulation Block (Source: Aritra and Ritwik)
The Focal Modulation Block consists of:
- a Multilayer Perceptron (MLP)
- a Focal Modulation layer
def MLP(
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
mlp_drop_rate: float = 0.0,
):
hidden_features = hidden_features or in_features
out_features = out_features or in_features
return keras.Sequential(
[
layers.Dense(units=hidden_features, activation=keras.activations.gelu),
layers.Dense(units=out_features),
layers.Dropout(rate=mlp_drop_rate),
]
)
In a typical Transformer architecture, for each visual token (query) x_i in R^C in the input feature map X in R^{HxWxC}, a generic encoding process produces a feature representation y_i in R^C.
The encoding process consists of interaction (for example, a dot product with its surroundings) and aggregation (for example, a weighted mean over the contexts).
We will talk about two types of encoding here:
- interaction and then aggregation in Self-Attention
- aggregation and then interaction in Focal Modulation
Self-Attention
Figure 4: Self-Attention module. (Source: Aritra and Ritwik)
Equation 3: Aggregation and interaction in Self-Attention (Source: Aritra and Ritwik)
As shown in Figure 4, the query and the key interact (in the interaction step) with each other to output the attention scores. This is followed by a weighted aggregation of the values, known as the aggregation step.
Focal Modulation
Figure 5: Focal Modulation module. (Source: Aritra and Ritwik)
Equation 4: Aggregation and interaction in Focal Modulation (Source: Aritra and Ritwik)
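In the notation of the FocalNet paper, the encoding summarized by Equation 4 can be sketched as:
y_i = q(x_i) * m(i, X)
where * denotes element-wise multiplication.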
Figure 5 depicts the Focal Modulation layer. q() is the query projection function: a linear layer that projects the query into a latent space. m() is the context aggregation function. Unlike in self-attention, in Focal Modulation the aggregation step takes place before the interaction step.
While q() is pretty straightforward to understand, the context aggregation function m() is more complex. Therefore, this section focuses on m().
Figure 6: The context aggregation function m(). (Source: Aritra and Ritwik)
The context aggregation function m() consists of two parts, as shown in Figure 6:
- Hierarchical Contextualization
- Gated Aggregation
Figure 7: Hierarchical Contextualization (Source: Aritra and Ritwik)
In Figure 7, we see that the input is first projected linearly. This linear projection produces Z^0, which can be expressed as follows:
Equation 5: Linear projection of Z^0 (Source: Aritra and Ritwik)
Z^0 is then passed through a series of Depth-Wise Convolution (DWConv) and GeLU layers. The authors term each block of DWConv and GeLU as a level, denoted by l. In Figure 6 we have two levels. Mathematically, this is represented as:
Equation 6: Levels of the modulation layer (Source: Aritra and Ritwik)
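In plain notation (a sketch following the FocalNet paper), each level applies a depth-wise convolution followed by GeLU:
Z^l = GeLU(DWConv(Z^{l-1}))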
where l in {1, ..., L}.
The final feature map goes through a Global Average Pooling layer. This can be expressed as follows:
Equation 7: Average Pooling of the final feature (Source: Aritra and Ritwik)
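In plain notation, the global (L+1)-th level is obtained by average pooling the deepest feature map (again a sketch):
Z^{L+1} = AvgPool(Z^L)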
Figure 8: Gated Aggregation (Source: Aritra and Ritwik)
Now that we have L+1 intermediate feature maps, thanks to the Hierarchical Contextualization step, we need a gating mechanism that lets some features pass and blocks others. This can be implemented with an attention module. Later in the tutorial, we will visualize these gates to better understand their usefulness.
First, we build the weights for aggregation. Here we apply a linear layer on the input feature map that projects it into L+1 dimensions.
Equation 8: Gates (Source: Aritra and Ritwik)
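Sketched in plain notation, the gates are a linear projection of the input feature map:
G = Linear(X), with G in R^{HxWx(L+1)}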
Next, we perform the weighted aggregation over the contexts.
Equation 9: Final feature map (Source: Aritra and Ritwik)
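A sketch of this weighted aggregation (Equation 9), where G^l is the l-th gate map and * is element-wise multiplication:
Z^out = sum over l = 1..L+1 of (G^l * Z^l)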
To enable communication across different channels, we use another linear layer h() to obtain the modulator.
Equation 10: Modulator (Source: Aritra and Ritwik)
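Sketched in plain notation, the modulator of Equation 10 is simply:
modulator = h(Z^out)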
To sum up the Focal Modulation layer, we have:
Equation 11: The Focal Modulation layer (Source: Aritra and Ritwik)
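Putting the pieces together, Equation 11 can be sketched as:
y_i = q(x_i) * h(sum over l = 1..L+1 of (g_i^l * z_i^l))
which is what the layer below implements: a query projection, hierarchically contextualized features, gated aggregation, and the modulator projection h (a 1x1 convolution in the code).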
class FocalModulationLayer(layers.Layer):
"""The Focal Modulation layer includes query projection & context aggregation.
Args:
dim (int): Projection dimension.
focal_window (int): Window size for focal modulation.
focal_level (int): The current focal level.
focal_factor (int): Factor of focal modulation.
proj_drop_rate (float): Rate of dropout.
"""
def __init__(
self,
dim: int,
focal_window: int,
focal_level: int,
focal_factor: int = 2,
proj_drop_rate: float = 0.0,
**kwargs,
):
super().__init__(**kwargs)
self.dim = dim
self.focal_window = focal_window
self.focal_level = focal_level
self.focal_factor = focal_factor
self.proj_drop_rate = proj_drop_rate
# Project the input feature into a new feature space using a
# linear layer. Note the `units` used. We will be projecting the input
# feature all at once and split the projection into query, context,
# and gates.
self.initial_proj = layers.Dense(
units=(2 * self.dim) + (self.focal_level + 1),
use_bias=True,
)
self.focal_layers = list()
self.kernel_sizes = list()
for idx in range(self.focal_level):
kernel_size = (self.focal_factor * idx) + self.focal_window
depth_gelu_block = keras.Sequential(
[
layers.ZeroPadding2D(padding=(kernel_size // 2, kernel_size // 2)),
layers.Conv2D(
filters=self.dim,
kernel_size=kernel_size,
activation=keras.activations.gelu,
groups=self.dim,
use_bias=False,
),
]
)
self.focal_layers.append(depth_gelu_block)
self.kernel_sizes.append(kernel_size)
self.activation = keras.activations.gelu
self.gap = layers.GlobalAveragePooling2D(keepdims=True)
self.modulator_proj = layers.Conv2D(
filters=self.dim,
kernel_size=(1, 1),
use_bias=True,
)
self.proj = layers.Dense(units=self.dim)
self.proj_drop = layers.Dropout(self.proj_drop_rate)
def call(self, x: tf.Tensor, training: Optional[bool] = None) -> tf.Tensor:
"""Forward pass of the layer.
Args:
x: Tensor of shape (B, H, W, C)
"""
# Apply the linear projection to the input feature map
x_proj = self.initial_proj(x)
# Split the projected x into query, context and gates
query, context, self.gates = tf.split(
value=x_proj,
num_or_size_splits=[self.dim, self.dim, self.focal_level + 1],
axis=-1,
)
# Context aggregation
context = self.focal_layers[0](context)
context_all = context * self.gates[..., 0:1]
for idx in range(1, self.focal_level):
context = self.focal_layers[idx](context)
context_all += context * self.gates[..., idx : idx + 1]
# Build the global context
context_global = self.activation(self.gap(context))
context_all += context_global * self.gates[..., self.focal_level :]
# Focal Modulation
self.modulator = self.modulator_proj(context_all)
x_output = query * self.modulator
# Project the output and apply dropout
x_output = self.proj(x_output)
x_output = self.proj_drop(x_output)
return x_output
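The Focal Modulation layer keeps the spatial resolution and channel count of its input. As an illustrative check (arbitrary sizes, not values from the paper):
# Illustrative only: feed a random feature map through the layer defined above.
focal_layer = FocalModulationLayer(dim=96, focal_window=3, focal_level=2)
features = tf.random.normal((1, 12, 12, 96))
outputs = focal_layer(features)
print(outputs.shape)  # (1, 12, 12, 96): same spatial size and channels as the input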
Finally, we have all the components we need to build the Focal Modulation block. Here we bring the MLP and the Focal Modulation layer together to build the Focal Modulation block.
class FocalModulationBlock(layers.Layer):
"""Combine FFN and Focal Modulation Layer.
Args:
dim (int): Number of input channels.
input_resolution (Tuple[int]): Input resolution.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float): Dropout rate.
drop_path (float): Stochastic depth rate.
focal_level (int): Number of focal levels.
focal_window (int): Focal window size at first focal level
"""
def __init__(
self,
dim: int,
input_resolution: Tuple[int],
mlp_ratio: float = 4.0,
drop: float = 0.0,
drop_path: float = 0.0,
focal_level: int = 1,
focal_window: int = 3,
**kwargs,
):
super().__init__(**kwargs)
self.dim = dim
self.input_resolution = input_resolution
self.mlp_ratio = mlp_ratio
self.focal_level = focal_level
self.focal_window = focal_window
self.norm = layers.LayerNormalization(epsilon=1e-5)
self.modulation = FocalModulationLayer(
dim=self.dim,
focal_window=self.focal_window,
focal_level=self.focal_level,
proj_drop_rate=drop,
)
mlp_hidden_dim = int(self.dim * self.mlp_ratio)
self.mlp = MLP(
in_features=self.dim,
hidden_features=mlp_hidden_dim,
mlp_drop_rate=drop,
)
def call(self, x: tf.Tensor, height: int, width: int, channels: int) -> tf.Tensor:
"""Processes the input tensor through the focal modulation block.
Args:
x (tf.Tensor): Inputs of the shape (B, L, C)
height (int): The height of the feature map
width (int): The width of the feature map
channels (int): The number of channels of the feature map
Returns:
The processed tensor.
"""
shortcut = x
# Focal Modulation
x = tf.reshape(x, shape=(-1, height, width, channels))
x = self.modulation(x)
x = tf.reshape(x, shape=(-1, height * width, channels))
# FFN
x = shortcut + x
x = x + self.mlp(self.norm(x))
return x
The Basic Layer consists of a collection of Focal Modulation blocks. This is illustrated in Figure 9.
Figure 9: Basic Layer, a collection of Focal Modulation blocks. (Source: Aritra and Ritwik)
Notice how in Figure 9 more than one Focal Modulation block is denoted by Nx. This illustrates how the Basic Layer is a collection of Focal Modulation blocks.
class BasicLayer(layers.Layer):
"""Collection of Focal Modulation Blocks.
Args:
dim (int): Dimensions of the model.
out_dim (int): Dimension used by the Patch Embedding Layer.
input_resolution (Tuple[int]): Input image resolution.
depth (int): The number of Focal Modulation Blocks.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float): Dropout rate.
downsample (tf.keras.layers.Layer): Downsampling layer at the end of the layer.
focal_level (int): The current focal level.
focal_window (int): Focal window used.
"""
def __init__(
self,
dim: int,
out_dim: int,
input_resolution: Tuple[int],
depth: int,
mlp_ratio: float = 4.0,
drop: float = 0.0,
downsample=None,
focal_level: int = 1,
focal_window: int = 1,
**kwargs,
):
super().__init__(**kwargs)
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.blocks = [
FocalModulationBlock(
dim=dim,
input_resolution=input_resolution,
mlp_ratio=mlp_ratio,
drop=drop,
focal_level=focal_level,
focal_window=focal_window,
)
for i in range(self.depth)
]
# Downsample layer at the end of the layer
if downsample is not None:
self.downsample = downsample(
image_size=input_resolution,
patch_size=(2, 2),
embed_dim=out_dim,
)
else:
self.downsample = None
def call(
self, x: tf.Tensor, height: int, width: int, channels: int
) -> Tuple[tf.Tensor, int, int, int]:
"""Forward pass of the layer.
Args:
x (tf.Tensor): Tensor of shape (B, L, C)
height (int): Height of feature map
width (int): Width of feature map
channels (int): Embed Dim of feature map
Returns:
A tuple of the processed tensor, changed height, width, and
dim of the tensor.
"""
# Apply Focal Modulation Blocks
for block in self.blocks:
x = block(x, height, width, channels)
# Except the last Basic Layer, all the layers have
# downsample at the end of it.
if self.downsample is not None:
x = tf.reshape(x, shape=(-1, height, width, channels))
x, height_o, width_o, channels_o = self.downsample(x)
else:
height_o, width_o, channels_o = height, width, channels
return x, height_o, width_o, channels_o
This is the model that ties everything together. It consists of a collection of Basic Layers followed by a classification head. For a recap of how this is structured, refer to Figure 1.
class FocalModulationNetwork(keras.Model):
"""The Focal Modulation Network.
Parameters:
image_size (Tuple[int]): Spatial size of images used.
patch_size (Tuple[int]): Patch size of each patch.
num_classes (int): Number of classes used for classification.
embed_dim (int): Patch embedding dimension.
depths (List[int]): Depth of each Focal Transformer block.
mlp_ratio (float): Ratio of expansion for the intermediate layer of MLP.
drop_rate (float): The dropout rate for FM and MLP layers.
focal_levels (list): How many focal levels at all stages.
Note that this excludes the finest-grain level.
focal_windows (list): The focal window size at all stages.
"""
def __init__(
self,
image_size: Tuple[int] = (48, 48),
patch_size: Tuple[int] = (4, 4),
num_classes: int = 10,
embed_dim: int = 256,
depths: List[int] = [2, 3, 2],
mlp_ratio: float = 4.0,
drop_rate: float = 0.1,
focal_levels=[2, 2, 2],
focal_windows=[3, 3, 3],
**kwargs,
):
super().__init__(**kwargs)
self.num_layers = len(depths)
embed_dim = [embed_dim * (2**i) for i in range(self.num_layers)]
self.num_classes = num_classes
self.embed_dim = embed_dim
self.num_features = embed_dim[-1]
self.mlp_ratio = mlp_ratio
self.patch_embed = PatchEmbed(
image_size=image_size,
patch_size=patch_size,
embed_dim=embed_dim[0],
)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patch_resolution
self.patches_resolution = patches_resolution
self.pos_drop = layers.Dropout(drop_rate)
self.basic_layers = list()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=embed_dim[i_layer],
out_dim=embed_dim[i_layer + 1]
if (i_layer < self.num_layers - 1)
else None,
input_resolution=(
patches_resolution[0] // (2**i_layer),
patches_resolution[1] // (2**i_layer),
),
depth=depths[i_layer],
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
focal_level=focal_levels[i_layer],
focal_window=focal_windows[i_layer],
)
self.basic_layers.append(layer)
self.norm = keras.layers.LayerNormalization(epsilon=1e-7)
self.avgpool = layers.GlobalAveragePooling1D()
self.flatten = layers.Flatten()
self.head = layers.Dense(self.num_classes, activation="softmax")
def call(self, x: tf.Tensor) -> tf.Tensor:
"""Forward pass of the layer.
Args:
x: Tensor of shape (B, H, W, C)
Returns:
The class probabilities (output of the softmax head).
"""
# Patch Embed the input images.
x, height, width, channels = self.patch_embed(x)
x = self.pos_drop(x)
for idx, layer in enumerate(self.basic_layers):
x, height, width, channels = layer(x, height, width, channels)
x = self.norm(x)
x = self.avgpool(x)
x = self.flatten(x)
x = self.head(x)
return x
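Before training, it can be useful to run a quick smoke test (purely illustrative, and deliberately smaller than the default configuration to keep it light) to confirm that a dummy batch flows through the network and yields one probability vector per image:
# Illustrative smoke test: forward a dummy batch through a small model.
smoke_model = FocalModulationNetwork(embed_dim=64, depths=[1, 1])
dummy_probs = smoke_model(tf.random.normal((2, 48, 48, 3)))
print(dummy_probs.shape)  # (2, 10): softmax scores over the 10 CIFAR-10 classes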
With all the components in place and the architecture actually built, we are now ready to put it to good use.
In this section, we train our Focal Modulation model on the CIFAR-10 dataset.
A key feature of the Focal Modulation Network is explicit input-dependency. This means the modulator is calculated by looking at the local features around the target location, so it depends on the input. In very simple terms, this makes interpretation easy: we can simply lay down the gating values and the original image side by side to see how the gating mechanism works.
The authors of the paper visualize the gates and the modulator in order to focus on the interpretability of the Focal Modulation layer. Below is a visualization callback that shows the gates and the modulator of a specific layer in the model while the model trains.
We will notice later that as the model trains, the visualizations get better.
The gates appear to selectively permit certain aspects of the input image to pass through, while gently disregarding others, ultimately leading to improved classification accuracy.
def display_grid(
test_images: tf.Tensor,
gates: tf.Tensor,
modulator: tf.Tensor,
):
"""Displays the image with the gates and modulator overlayed.
Args:
test_images (tf.Tensor): A batch of test images.
gates (tf.Tensor): The gates of the Focal Modulation Layer.
modulator (tf.Tensor): The modulator of the Focal Modulation Layer.
"""
fig, ax = plt.subplots(nrows=1, ncols=5, figsize=(25, 5))
# Randomly sample an image from the batch.
index = randint(0, BATCH_SIZE - 1)
orig_image = test_images[index]
gate_image = gates[index]
modulator_image = modulator[index]
# Original Image
ax[0].imshow(orig_image)
ax[0].set_title("Original:")
ax[0].axis("off")
for index in range(1, 5):
img = ax[index].imshow(orig_image)
if index != 4:
overlay_image = gate_image[..., index - 1]
title = f"G {index}:"
else:
overlay_image = tf.norm(modulator_image, ord=2, axis=-1)
title = f"MOD:"
ax[index].imshow(
overlay_image, cmap="inferno", alpha=0.6, extent=img.get_extent()
)
ax[index].set_title(title)
ax[index].axis("off")
plt.axis("off")
plt.show()
plt.close()
# Taking a batch of test inputs to measure the model's progress.
test_images, test_labels = next(iter(test_ds))
upsampler = tf.keras.layers.UpSampling2D(
size=(4, 4),
interpolation="bilinear",
)
class TrainMonitor(keras.callbacks.Callback):
def __init__(self, epoch_interval=None):
self.epoch_interval = epoch_interval
def on_epoch_end(self, epoch, logs=None):
if self.epoch_interval and epoch % self.epoch_interval == 0:
_ = self.model(test_images)
# Take the mid layer for visualization
gates = self.model.basic_layers[1].blocks[-1].modulation.gates
gates = upsampler(gates)
modulator = self.model.basic_layers[1].blocks[-1].modulation.modulator
modulator = upsampler(modulator)
# Display the grid of gates and modulator.
display_grid(test_images=test_images, gates=gates, modulator=modulator)
# Some code is taken from:
# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
def __init__(
self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps
):
super().__init__()
self.learning_rate_base = learning_rate_base
self.total_steps = total_steps
self.warmup_learning_rate = warmup_learning_rate
self.warmup_steps = warmup_steps
self.pi = tf.constant(np.pi)
def __call__(self, step):
if self.total_steps < self.warmup_steps:
raise ValueError("Total_steps must be larger or equal to warmup_steps.")
cos_annealed_lr = tf.cos(
self.pi
* (tf.cast(step, tf.float32) - self.warmup_steps)
/ float(self.total_steps - self.warmup_steps)
)
learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)
if self.warmup_steps > 0:
if self.learning_rate_base < self.warmup_learning_rate:
raise ValueError(
"Learning_rate_base must be larger or equal to "
"warmup_learning_rate."
)
slope = (
self.learning_rate_base - self.warmup_learning_rate
) / self.warmup_steps
warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
learning_rate = tf.where(
step < self.warmup_steps, warmup_rate, learning_rate
)
return tf.where(
step > self.total_steps, 0.0, learning_rate, name="learning_rate"
)
total_steps = int((len(x_train) / BATCH_SIZE) * EPOCHS)
warmup_epoch_percentage = 0.15
warmup_steps = int(total_steps * warmup_epoch_percentage)
scheduled_lrs = WarmUpCosine(
learning_rate_base=LEARNING_RATE,
total_steps=total_steps,
warmup_learning_rate=0.0,
warmup_steps=warmup_steps,
)
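Optionally (this plot is not part of the original tutorial), we can visualize the schedule to confirm the linear warmup followed by the cosine decay:
# Optional: plot the learning rate for every training step.
lrs = [scheduled_lrs(step).numpy() for step in range(total_steps)]
plt.plot(lrs)
plt.xlabel("Step")
plt.ylabel("Learning rate")
plt.show()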
focal_mod_net = FocalModulationNetwork()
optimizer = AdamW(learning_rate=scheduled_lrs, weight_decay=WEIGHT_DECAY)
# Compile and train the model.
focal_mod_net.compile(
optimizer=optimizer,
loss="sparse_categorical_crossentropy",
metrics=["accuracy"],
)
history = focal_mod_net.fit(
train_ds,
epochs=EPOCHS,
validation_data=val_ds,
callbacks=[TrainMonitor(epoch_interval=10)],
)
Epoch 1/25
40/40 [==============================] - ETA: 0s - loss: 2.3925 - accuracy: 0.1401
40/40 [==============================] - 57s 724ms/step - loss: 2.3925 - accuracy: 0.1401 - val_loss: 2.2182 - val_accuracy: 0.1768
Epoch 2/25
40/40 [==============================] - 20s 483ms/step - loss: 2.0790 - accuracy: 0.2261 - val_loss: 2.2933 - val_accuracy: 0.1795
Epoch 3/25
40/40 [==============================] - 19s 479ms/step - loss: 2.0130 - accuracy: 0.2585 - val_loss: 2.6833 - val_accuracy: 0.2022
Epoch 4/25
40/40 [==============================] - 21s 507ms/step - loss: 1.8270 - accuracy: 0.3315 - val_loss: 1.9127 - val_accuracy: 0.3215
Epoch 5/25
40/40 [==============================] - 19s 475ms/step - loss: 1.6037 - accuracy: 0.4173 - val_loss: 1.7226 - val_accuracy: 0.3938
Epoch 6/25
40/40 [==============================] - 19s 476ms/step - loss: 1.4758 - accuracy: 0.4658 - val_loss: 1.5097 - val_accuracy: 0.4733
Epoch 7/25
40/40 [==============================] - 19s 476ms/step - loss: 1.3677 - accuracy: 0.5075 - val_loss: 1.4630 - val_accuracy: 0.4986
Epoch 8/25
40/40 [==============================] - 21s 508ms/step - loss: 1.2599 - accuracy: 0.5490 - val_loss: 1.2908 - val_accuracy: 0.5492
Epoch 9/25
40/40 [==============================] - 19s 478ms/step - loss: 1.1689 - accuracy: 0.5818 - val_loss: 1.2750 - val_accuracy: 0.5518
Epoch 10/25
40/40 [==============================] - 19s 476ms/step - loss: 1.0843 - accuracy: 0.6140 - val_loss: 1.1444 - val_accuracy: 0.6002
Epoch 11/25
39/40 [============================>.] - ETA: 0s - loss: 1.0040 - accuracy: 0.6453
40/40 [==============================] - 20s 489ms/step - loss: 1.0041 - accuracy: 0.6452 - val_loss: 1.1765 - val_accuracy: 0.5939
Epoch 12/25
40/40 [==============================] - 20s 480ms/step - loss: 0.9401 - accuracy: 0.6701 - val_loss: 1.1276 - val_accuracy: 0.6181
Epoch 13/25
40/40 [==============================] - 19s 480ms/step - loss: 0.8787 - accuracy: 0.6910 - val_loss: 0.9990 - val_accuracy: 0.6547
Epoch 14/25
40/40 [==============================] - 19s 479ms/step - loss: 0.8198 - accuracy: 0.7122 - val_loss: 1.0074 - val_accuracy: 0.6562
Epoch 15/25
40/40 [==============================] - 19s 480ms/step - loss: 0.7831 - accuracy: 0.7275 - val_loss: 0.9739 - val_accuracy: 0.6686
Epoch 16/25
40/40 [==============================] - 19s 478ms/step - loss: 0.7358 - accuracy: 0.7428 - val_loss: 0.9578 - val_accuracy: 0.6753
Epoch 17/25
40/40 [==============================] - 19s 478ms/step - loss: 0.7018 - accuracy: 0.7557 - val_loss: 0.9414 - val_accuracy: 0.6789
Epoch 18/25
40/40 [==============================] - 20s 480ms/step - loss: 0.6678 - accuracy: 0.7678 - val_loss: 0.9492 - val_accuracy: 0.6771
Epoch 19/25
40/40 [==============================] - 19s 476ms/step - loss: 0.6423 - accuracy: 0.7783 - val_loss: 0.9422 - val_accuracy: 0.6832
Epoch 20/25
40/40 [==============================] - 19s 479ms/step - loss: 0.6202 - accuracy: 0.7868 - val_loss: 0.9324 - val_accuracy: 0.6860
Epoch 21/25
40/40 [==============================] - ETA: 0s - loss: 0.6005 - accuracy: 0.7938
40/40 [==============================] - 20s 488ms/step - loss: 0.6005 - accuracy: 0.7938 - val_loss: 0.9326 - val_accuracy: 0.6880
Epoch 22/25
40/40 [==============================] - 19s 478ms/step - loss: 0.5937 - accuracy: 0.7970 - val_loss: 0.9339 - val_accuracy: 0.6875
Epoch 23/25
40/40 [==============================] - 19s 478ms/step - loss: 0.5899 - accuracy: 0.7984 - val_loss: 0.9294 - val_accuracy: 0.6894
Epoch 24/25
40/40 [==============================] - 19s 478ms/step - loss: 0.5840 - accuracy: 0.8012 - val_loss: 0.9315 - val_accuracy: 0.6881
Epoch 25/25
40/40 [==============================] - 19s 478ms/step - loss: 0.5853 - accuracy: 0.7997 - val_loss: 0.9315 - val_accuracy: 0.6880
plt.plot(history.history["loss"], label="loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.legend()
plt.show()
plt.plot(history.history["accuracy"], label="accuracy")
plt.plot(history.history["val_accuracy"], label="val_accuracy")
plt.legend()
plt.show()
Let's test our model on some test images and see what the gates look like.
test_images, test_labels = next(iter(test_ds))
_ = focal_mod_net(test_images)
# Take the mid layer for visualization
gates = focal_mod_net.basic_layers[1].blocks[-1].modulation.gates
gates = upsampler(gates)
modulator = focal_mod_net.basic_layers[1].blocks[-1].modulation.modulator
modulator = upsampler(modulator)
# Plot the test images with the gates and modulator overlayed.
for row in range(5):
display_grid(
test_images=test_images,
gates=gates,
modulator=modulator,
)
The proposed architecture, the Focal Modulation Network, is a mechanism that allows different parts of an image to interact with each other in a way that depends on the image itself. It works by first gathering different levels of context information around each part of the image (the "query token"), then using a gate to decide which context information is most relevant, and finally combining the chosen information in a simple but effective way.
This is meant to serve as a replacement for the Self-Attention mechanism of the Transformer architecture. The key feature that makes this research notable is not the conception of attention-free networks, but rather the introduction of an equally powerful architecture that is interpretable.
The authors also mention that they created a series of Focal Modulation Networks (FocalNets) that significantly outperform their Self-Attention counterparts, with a fraction of the parameters and pretraining data.
The FocalNets architecture has the potential to deliver impressive results while offering a simple implementation. Its promising performance and ease of use make it an attractive alternative to Self-Attention for researchers to explore in their own projects. It could potentially become widely adopted by the Deep Learning community in the near future.
We would like to thank PyImageSearch for providing a Colab Pro account, JarvisLabs.ai for the GPU credits, and Microsoft Research for the official implementation of their paper. We would also like to extend our gratitude to the first author of the paper, Jianwei Yang, who reviewed this tutorial extensively.