作者: Aritra Roy Gosthipaty, Ritwik Raha
创建日期 2023/01/25
最后修改日期 2023/02/15
描述: 使用焦点调制网络进行图像分类。
本教程旨在提供一份关于焦点调制网络实现的全面指南,如 Yang 等人 所述。
本教程将提供一种形式化、极简主义的方法来实现焦点调制网络,并探索其在深度学习领域的潜在应用。
问题陈述
Transformer 架构(Vaswani 等人)已成为大多数自然语言处理任务中事实上的标准,并已应用于计算机视觉领域,例如 Vision Transformers(Dosovitskiy 等人)。
在 Transformer 中,自注意力(SA)可以说是其成功的关键,它实现了与输入相关的全局交互,这与卷积操作形成对比,后者将交互限制在具有共享核的局部区域内。
注意力模块的数学表示如公式 1 所示。
![]() |
---|
公式 1:注意力的数学公式(来源:Aritra 和 Ritwik) |
其中
Q
是查询K
是键V
是值d_k
是键的维度对于自注意力,查询、键和值都来源于输入序列。让我们将自注意力的注意力方程重写,如公式 2 所示。
![]() |
---|
公式 2:自注意力机制的数学公式(来源:Aritra 和 Ritwik) |
观察自注意力公式,我们发现它是一个二次方程。因此,随着 token 数量的增加,计算时间(以及成本)也会增加。为了缓解这个问题并使 Transformer 更具可解释性,Yang 等人试图用更好的组件替代自注意力模块。
解决方案
Yang 等人引入了焦点调制层作为自注意力层的无缝替代品。该层具有高度可解释性,使其成为深度学习从业者的宝贵工具。
在本教程中,我们将深入探讨该层的实际应用,通过在 CIFAR-10 数据集上训练整个模型,并直观地解释该层的性能。
注意:我们尽量使我们的实现与 官方实现保持一致。
本教程使用 tensorflow 版本 2.11.0
。
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.optimizers.experimental import AdamW
from typing import Optional, Tuple, List
from matplotlib import pyplot as plt
from random import randint
# Set seed for reproducibility.
tf.keras.utils.set_random_seed(42)
我们选择这些超参数并没有强烈的理由。请随意更改配置并训练模型。
# DATA
TRAIN_SLICE = 40000
BUFFER_SIZE = 2048
BATCH_SIZE = 1024
AUTO = tf.data.AUTOTUNE
INPUT_SHAPE = (32, 32, 3)
IMAGE_SIZE = 48
NUM_CLASSES = 10
# OPTIMIZER
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
# TRAINING
EPOCHS = 25
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
(x_train[:TRAIN_SLICE], y_train[:TRAIN_SLICE]),
(x_train[TRAIN_SLICE:], y_train[TRAIN_SLICE:]),
)
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 30s 0us/step
我们使用 keras.Sequential
API 将所有单独的增强步骤组合成一个 API。
# Build the `train` augmentation pipeline.
train_aug = keras.Sequential(
[
layers.Rescaling(1 / 255.0),
layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20),
layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),
layers.RandomFlip("horizontal"),
],
name="train_data_augmentation",
)
# Build the `val` and `test` data pipeline.
test_aug = keras.Sequential(
[
layers.Rescaling(1 / 255.0),
layers.Resizing(IMAGE_SIZE, IMAGE_SIZE),
],
name="test_data_augmentation",
)
tf.data
管道train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = (
train_ds.map(
lambda image, label: (train_aug(image), label), num_parallel_calls=AUTO
)
.shuffle(BUFFER_SIZE)
.batch(BATCH_SIZE)
.prefetch(AUTO)
)
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = (
val_ds.map(lambda image, label: (test_aug(image), label), num_parallel_calls=AUTO)
.batch(BATCH_SIZE)
.prefetch(AUTO)
)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = (
test_ds.map(lambda image, label: (test_aug(image), label), num_parallel_calls=AUTO)
.batch(BATCH_SIZE)
.prefetch(AUTO)
)
WARNING:tensorflow:From /usr/local/lib/python3.8/dist-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
我们在这里暂停一下,快速了解焦点调制网络的架构。图 1 展示了每个独立层如何编译成一个单一模型。这让我们对整个架构有一个鸟瞰图。
![]() |
---|
图 1:焦点调制模型图(来源:Aritra 和 Ritwik) |
我们将在以下部分深入探讨这些层。我们将遵循以下顺序
为了更好地理解我们熟悉的架构格式,让我们看看焦点调制网络在绘制成 Transformer 架构时会是什么样子。
图 2 显示了传统 Transformer 架构的编码器层,其中自注意力被焦点调制层取代。
蓝色的蓝色块代表焦点调制块。这些块的堆叠构成一个基本层。绿色的绿色块代表焦点调制层。
![]() |
---|
图 2:整体架构(来源:Aritra 和 Ritwik) |
补丁嵌入层用于将输入图像打补丁并将其投影到潜在空间。该层在架构中也用作下采样层。
class PatchEmbed(layers.Layer):
"""Image patch embedding layer, also acts as the down-sampling layer.
Args:
image_size (Tuple[int]): Input image resolution.
patch_size (Tuple[int]): Patch spatial resolution.
embed_dim (int): Embedding dimension.
"""
def __init__(
self,
image_size: Tuple[int] = (224, 224),
patch_size: Tuple[int] = (4, 4),
embed_dim: int = 96,
**kwargs,
):
super().__init__(**kwargs)
patch_resolution = [
image_size[0] // patch_size[0],
image_size[1] // patch_size[1],
]
self.image_size = image_size
self.patch_size = patch_size
self.embed_dim = embed_dim
self.patch_resolution = patch_resolution
self.num_patches = patch_resolution[0] * patch_resolution[1]
self.proj = layers.Conv2D(
filters=embed_dim, kernel_size=patch_size, strides=patch_size
)
self.flatten = layers.Reshape(target_shape=(-1, embed_dim))
self.norm = keras.layers.LayerNormalization(epsilon=1e-7)
def call(self, x: tf.Tensor) -> Tuple[tf.Tensor, int, int, int]:
"""Patchifies the image and converts into tokens.
Args:
x: Tensor of shape (B, H, W, C)
Returns:
A tuple of the processed tensor, height of the projected
feature map, width of the projected feature map, number
of channels of the projected feature map.
"""
# Project the inputs.
x = self.proj(x)
# Obtain the shape from the projected tensor.
height = tf.shape(x)[1]
width = tf.shape(x)[2]
channels = tf.shape(x)[3]
# B, H, W, C -> B, H*W, C
x = self.norm(self.flatten(x))
return x, height, width, channels
焦点调制块可以被视为一个 Transformer 块,其中自注意力 (SA) 模块被焦点调制模块取代,正如我们在图 2 中所见。
让我们借助图 3 回忆一下焦点调制块应该是什么样子。
![]() |
---|
图 3:焦点调制块的独立视图(来源:Aritra 和 Ritwik) |
焦点调制块包括:- 多层感知器 - 焦点调制层
def MLP(
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
mlp_drop_rate: float = 0.0,
):
hidden_features = hidden_features or in_features
out_features = out_features or in_features
return keras.Sequential(
[
layers.Dense(units=hidden_features, activation=keras.activations.gelu),
layers.Dense(units=out_features),
layers.Dropout(rate=mlp_drop_rate),
]
)
在典型的 Transformer 架构中,对于输入特征图 X in R^{HxWxC}
中的每个视觉 token (查询) x_i in R^C
,一个通用编码过程会生成一个特征表示 y_i in R^C
。
编码过程包括交互(与其周围环境,例如点积)和聚合(在上下文中,例如加权平均)。
我们将在这里讨论两种类型的编码:- 自注意力中的交互然后聚合 - 焦点调制中的聚合然后交互
自注意力
![]() |
---|
图 4:自注意力模块。(来源:Aritra 和 Ritwik) |
![]() |
---|
公式 3:自注意力中的聚合与交互(来源:Aritra 和 Ritwik) |
如图 4 所示,查询和键相互作用(在交互步骤中)以输出注意力分数。接下来是值的加权聚合,称为聚合步骤。
焦点调制
![]() |
---|
图 5:焦点调制模块。(来源:Aritra 和 Ritwik) |
![]() |
---|
公式 4:焦点调制中的聚合和交互(来源:Aritra 和 Ritwik) |
图 5 描绘了焦点调制层。q()
是查询投影函数。它是一个将查询投影到潜在空间的线性层。m()
是上下文聚合函数。与自注意力不同,聚合步骤在焦点调制中发生在交互步骤之前。
虽然 q()
很容易理解,但上下文聚合函数 m()
更复杂。因此,本节将重点关注 m()
。
![]() |
---|
图 6:上下文聚合函数 m() 。(来源:Aritra 和 Ritwik) |
上下文聚合函数 m()
由两部分组成,如图 6 所示:- 分层上下文化 - 门控聚合
![]() |
---|
图 7:分层上下文化(来源:Aritra 和 Ritwik) |
在图 7 中,我们看到输入首先被线性投影。这种线性投影产生 Z^0
。其中 Z^0
可以表示如下
![]() |
---|
公式 5:Z^0 的线性投影(来源:Aritra 和 Ritwik) |
Z^0
接着被传递给一系列深度卷积 (DWConv) 和 GeLU 层。作者将 DWConv 和 GeLU 的每个块称为级别,用 l
表示。在图 6 中,我们有两个级别。数学上表示为
![]() |
---|
公式 6:调制层的级别(来源:Aritra 和 Ritwik) |
其中 l in {1, ... , L}
最终的特征图通过一个全局平均池化层。这可以表示为
![]() |
---|
公式 7:最终特征的平均池化(来源:Aritra 和 Ritwik) |
![]() |
---|
图 8:门控聚合(来源:Aritra 和 Ritwik) |
现在我们通过分层上下文化步骤获得了 L+1
个中间特征图,我们需要一个门控机制,让一些特征通过并阻止其他特征。这可以通过注意力模块实现。在本教程的后面,我们将可视化这些门控,以更好地理解它们的用途。
首先,我们构建聚合的权重。这里我们对输入特征图应用一个线性层,将其投影到 L+1
维度。
![]() |
---|
公式 8:门控(来源:Aritra 和 Ritwik) |
接下来,我们对上下文执行加权聚合。
![]() |
---|
公式 9:最终特征图(来源:Aritra 和 Ritwik) |
为了实现不同通道之间的通信,我们使用另一个线性层 h()
来获得调制器。
![]() |
---|
公式 10:调制器(来源:Aritra 和 Ritwik) |
总结焦点调制层,我们有
![]() |
---|
公式 11:焦点调制层(来源:Aritra 和 Ritwik) |
class FocalModulationLayer(layers.Layer):
"""The Focal Modulation layer includes query projection & context aggregation.
Args:
dim (int): Projection dimension.
focal_window (int): Window size for focal modulation.
focal_level (int): The current focal level.
focal_factor (int): Factor of focal modulation.
proj_drop_rate (float): Rate of dropout.
"""
def __init__(
self,
dim: int,
focal_window: int,
focal_level: int,
focal_factor: int = 2,
proj_drop_rate: float = 0.0,
**kwargs,
):
super().__init__(**kwargs)
self.dim = dim
self.focal_window = focal_window
self.focal_level = focal_level
self.focal_factor = focal_factor
self.proj_drop_rate = proj_drop_rate
# Project the input feature into a new feature space using a
# linear layer. Note the `units` used. We will be projecting the input
# feature all at once and split the projection into query, context,
# and gates.
self.initial_proj = layers.Dense(
units=(2 * self.dim) + (self.focal_level + 1),
use_bias=True,
)
self.focal_layers = list()
self.kernel_sizes = list()
for idx in range(self.focal_level):
kernel_size = (self.focal_factor * idx) + self.focal_window
depth_gelu_block = keras.Sequential(
[
layers.ZeroPadding2D(padding=(kernel_size // 2, kernel_size // 2)),
layers.Conv2D(
filters=self.dim,
kernel_size=kernel_size,
activation=keras.activations.gelu,
groups=self.dim,
use_bias=False,
),
]
)
self.focal_layers.append(depth_gelu_block)
self.kernel_sizes.append(kernel_size)
self.activation = keras.activations.gelu
self.gap = layers.GlobalAveragePooling2D(keepdims=True)
self.modulator_proj = layers.Conv2D(
filters=self.dim,
kernel_size=(1, 1),
use_bias=True,
)
self.proj = layers.Dense(units=self.dim)
self.proj_drop = layers.Dropout(self.proj_drop_rate)
def call(self, x: tf.Tensor, training: Optional[bool] = None) -> tf.Tensor:
"""Forward pass of the layer.
Args:
x: Tensor of shape (B, H, W, C)
"""
# Apply the linear projecion to the input feature map
x_proj = self.initial_proj(x)
# Split the projected x into query, context and gates
query, context, self.gates = tf.split(
value=x_proj,
num_or_size_splits=[self.dim, self.dim, self.focal_level + 1],
axis=-1,
)
# Context aggregation
context = self.focal_layers[0](context)
context_all = context * self.gates[..., 0:1]
for idx in range(1, self.focal_level):
context = self.focal_layers[idx](context)
context_all += context * self.gates[..., idx : idx + 1]
# Build the global context
context_global = self.activation(self.gap(context))
context_all += context_global * self.gates[..., self.focal_level :]
# Focal Modulation
self.modulator = self.modulator_proj(context_all)
x_output = query * self.modulator
# Project the output and apply dropout
x_output = self.proj(x_output)
x_output = self.proj_drop(x_output)
return x_output
最后,我们拥有构建焦点调制块所需的所有组件。在这里,我们将 MLP 和焦点调制层放在一起,构建焦点调制块。
class FocalModulationBlock(layers.Layer):
"""Combine FFN and Focal Modulation Layer.
Args:
dim (int): Number of input channels.
input_resolution (Tuple[int]): Input resulotion.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float): Dropout rate.
drop_path (float): Stochastic depth rate.
focal_level (int): Number of focal levels.
focal_window (int): Focal window size at first focal level
"""
def __init__(
self,
dim: int,
input_resolution: Tuple[int],
mlp_ratio: float = 4.0,
drop: float = 0.0,
drop_path: float = 0.0,
focal_level: int = 1,
focal_window: int = 3,
**kwargs,
):
super().__init__(**kwargs)
self.dim = dim
self.input_resolution = input_resolution
self.mlp_ratio = mlp_ratio
self.focal_level = focal_level
self.focal_window = focal_window
self.norm = layers.LayerNormalization(epsilon=1e-5)
self.modulation = FocalModulationLayer(
dim=self.dim,
focal_window=self.focal_window,
focal_level=self.focal_level,
proj_drop_rate=drop,
)
mlp_hidden_dim = int(self.dim * self.mlp_ratio)
self.mlp = MLP(
in_features=self.dim,
hidden_features=mlp_hidden_dim,
mlp_drop_rate=drop,
)
def call(self, x: tf.Tensor, height: int, width: int, channels: int) -> tf.Tensor:
"""Processes the input tensor through the focal modulation block.
Args:
x (tf.Tensor): Inputs of the shape (B, L, C)
height (int): The height of the feature map
width (int): The width of the feature map
channels (int): The number of channels of the feature map
Returns:
The processed tensor.
"""
shortcut = x
# Focal Modulation
x = tf.reshape(x, shape=(-1, height, width, channels))
x = self.modulation(x)
x = tf.reshape(x, shape=(-1, height * width, channels))
# FFN
x = shortcut + x
x = x + self.mlp(self.norm(x))
return x
基本层由一组焦点调制块组成。这在图 9 中有所说明。
![]() |
---|
图 9:基本层,一组焦点调制块。(来源:Aritra 和 Ritwik) |
请注意,在图 9 中,有多个焦点调制块,用 Nx
表示。这表明基本层是焦点调制块的集合。
class BasicLayer(layers.Layer):
"""Collection of Focal Modulation Blocks.
Args:
dim (int): Dimensions of the model.
out_dim (int): Dimension used by the Patch Embedding Layer.
input_resolution (Tuple[int]): Input image resolution.
depth (int): The number of Focal Modulation Blocks.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float): Dropout rate.
downsample (tf.keras.layers.Layer): Downsampling layer at the end of the layer.
focal_level (int): The current focal level.
focal_window (int): Focal window used.
"""
def __init__(
self,
dim: int,
out_dim: int,
input_resolution: Tuple[int],
depth: int,
mlp_ratio: float = 4.0,
drop: float = 0.0,
downsample=None,
focal_level: int = 1,
focal_window: int = 1,
**kwargs,
):
super().__init__(**kwargs)
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.blocks = [
FocalModulationBlock(
dim=dim,
input_resolution=input_resolution,
mlp_ratio=mlp_ratio,
drop=drop,
focal_level=focal_level,
focal_window=focal_window,
)
for i in range(self.depth)
]
# Downsample layer at the end of the layer
if downsample is not None:
self.downsample = downsample(
image_size=input_resolution,
patch_size=(2, 2),
embed_dim=out_dim,
)
else:
self.downsample = None
def call(
self, x: tf.Tensor, height: int, width: int, channels: int
) -> Tuple[tf.Tensor, int, int, int]:
"""Forward pass of the layer.
Args:
x (tf.Tensor): Tensor of shape (B, L, C)
height (int): Height of feature map
width (int): Width of feature map
channels (int): Embed Dim of feature map
Returns:
A tuple of the processed tensor, changed height, width, and
dim of the tensor.
"""
# Apply Focal Modulation Blocks
for block in self.blocks:
x = block(x, height, width, channels)
# Except the last Basic Layer, all the layers have
# downsample at the end of it.
if self.downsample is not None:
x = tf.reshape(x, shape=(-1, height, width, channels))
x, height_o, width_o, channels_o = self.downsample(x)
else:
height_o, width_o, channels_o = height, width, channels
return x, height_o, width_o, channels_o
这个模型将所有内容联系在一起。它由一组带有分类头的基本层组成。有关其结构的概述,请参阅图 1。
class FocalModulationNetwork(keras.Model):
"""The Focal Modulation Network.
Parameters:
image_size (Tuple[int]): Spatial size of images used.
patch_size (Tuple[int]): Patch size of each patch.
num_classes (int): Number of classes used for classification.
embed_dim (int): Patch embedding dimension.
depths (List[int]): Depth of each Focal Transformer block.
mlp_ratio (float): Ratio of expansion for the intermediate layer of MLP.
drop_rate (float): The dropout rate for FM and MLP layers.
focal_levels (list): How many focal levels at all stages.
Note that this excludes the finest-grain level.
focal_windows (list): The focal window size at all stages.
"""
def __init__(
self,
image_size: Tuple[int] = (48, 48),
patch_size: Tuple[int] = (4, 4),
num_classes: int = 10,
embed_dim: int = 256,
depths: List[int] = [2, 3, 2],
mlp_ratio: float = 4.0,
drop_rate: float = 0.1,
focal_levels=[2, 2, 2],
focal_windows=[3, 3, 3],
**kwargs,
):
super().__init__(**kwargs)
self.num_layers = len(depths)
embed_dim = [embed_dim * (2**i) for i in range(self.num_layers)]
self.num_classes = num_classes
self.embed_dim = embed_dim
self.num_features = embed_dim[-1]
self.mlp_ratio = mlp_ratio
self.patch_embed = PatchEmbed(
image_size=image_size,
patch_size=patch_size,
embed_dim=embed_dim[0],
)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patch_resolution
self.patches_resolution = patches_resolution
self.pos_drop = layers.Dropout(drop_rate)
self.basic_layers = list()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=embed_dim[i_layer],
out_dim=embed_dim[i_layer + 1]
if (i_layer < self.num_layers - 1)
else None,
input_resolution=(
patches_resolution[0] // (2**i_layer),
patches_resolution[1] // (2**i_layer),
),
depth=depths[i_layer],
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
focal_level=focal_levels[i_layer],
focal_window=focal_windows[i_layer],
)
self.basic_layers.append(layer)
self.norm = keras.layers.LayerNormalization(epsilon=1e-7)
self.avgpool = layers.GlobalAveragePooling1D()
self.flatten = layers.Flatten()
self.head = layers.Dense(self.num_classes, activation="softmax")
def call(self, x: tf.Tensor) -> tf.Tensor:
"""Forward pass of the layer.
Args:
x: Tensor of shape (B, H, W, C)
Returns:
The logits.
"""
# Patch Embed the input images.
x, height, width, channels = self.patch_embed(x)
x = self.pos_drop(x)
for idx, layer in enumerate(self.basic_layers):
x, height, width, channels = layer(x, height, width, channels)
x = self.norm(x)
x = self.avgpool(x)
x = self.flatten(x)
x = self.head(x)
return x
现在所有组件都已就位并且架构已经构建完成,我们准备好将其投入使用。
在本节中,我们将在 CIFAR-10 数据集上训练我们的焦点调制模型。
焦点调制网络的一个关键特性是显式的输入依赖性。这意味着调制器是通过查看目标位置周围的局部特征来计算的,因此它依赖于输入。简单来说,这使得解释变得容易。我们可以简单地将门控值和原始图像并排放置,以查看门控机制的工作原理。
论文作者可视化了门控和调制器,以关注焦点调制层的可解释性。下面是一个可视化回调,它在模型训练时显示模型中特定层的门控和调制器。
稍后我们会注意到,随着模型训练,可视化效果会越来越好。
门控似乎选择性地允许输入图像的某些方面通过,同时轻轻地忽略其他方面,最终导致分类准确率的提高。
def display_grid(
test_images: tf.Tensor,
gates: tf.Tensor,
modulator: tf.Tensor,
):
"""Displays the image with the gates and modulator overlayed.
Args:
test_images (tf.Tensor): A batch of test images.
gates (tf.Tensor): The gates of the Focal Modualtion Layer.
modulator (tf.Tensor): The modulator of the Focal Modulation Layer.
"""
fig, ax = plt.subplots(nrows=1, ncols=5, figsize=(25, 5))
# Radomly sample an image from the batch.
index = randint(0, BATCH_SIZE - 1)
orig_image = test_images[index]
gate_image = gates[index]
modulator_image = modulator[index]
# Original Image
ax[0].imshow(orig_image)
ax[0].set_title("Original:")
ax[0].axis("off")
for index in range(1, 5):
img = ax[index].imshow(orig_image)
if index != 4:
overlay_image = gate_image[..., index - 1]
title = f"G {index}:"
else:
overlay_image = tf.norm(modulator_image, ord=2, axis=-1)
title = f"MOD:"
ax[index].imshow(
overlay_image, cmap="inferno", alpha=0.6, extent=img.get_extent()
)
ax[index].set_title(title)
ax[index].axis("off")
plt.axis("off")
plt.show()
plt.close()
# Taking a batch of test inputs to measure the model's progress.
test_images, test_labels = next(iter(test_ds))
upsampler = tf.keras.layers.UpSampling2D(
size=(4, 4),
interpolation="bilinear",
)
class TrainMonitor(keras.callbacks.Callback):
def __init__(self, epoch_interval=None):
self.epoch_interval = epoch_interval
def on_epoch_end(self, epoch, logs=None):
if self.epoch_interval and epoch % self.epoch_interval == 0:
_ = self.model(test_images)
# Take the mid layer for visualization
gates = self.model.basic_layers[1].blocks[-1].modulation.gates
gates = upsampler(gates)
modulator = self.model.basic_layers[1].blocks[-1].modulation.modulator
modulator = upsampler(modulator)
# Display the grid of gates and modulator.
display_grid(test_images=test_images, gates=gates, modulator=modulator)
# Some code is taken from:
# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
def __init__(
self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps
):
super().__init__()
self.learning_rate_base = learning_rate_base
self.total_steps = total_steps
self.warmup_learning_rate = warmup_learning_rate
self.warmup_steps = warmup_steps
self.pi = tf.constant(np.pi)
def __call__(self, step):
if self.total_steps < self.warmup_steps:
raise ValueError("Total_steps must be larger or equal to warmup_steps.")
cos_annealed_lr = tf.cos(
self.pi
* (tf.cast(step, tf.float32) - self.warmup_steps)
/ float(self.total_steps - self.warmup_steps)
)
learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)
if self.warmup_steps > 0:
if self.learning_rate_base < self.warmup_learning_rate:
raise ValueError(
"Learning_rate_base must be larger or equal to "
"warmup_learning_rate."
)
slope = (
self.learning_rate_base - self.warmup_learning_rate
) / self.warmup_steps
warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate
learning_rate = tf.where(
step < self.warmup_steps, warmup_rate, learning_rate
)
return tf.where(
step > self.total_steps, 0.0, learning_rate, name="learning_rate"
)
total_steps = int((len(x_train) / BATCH_SIZE) * EPOCHS)
warmup_epoch_percentage = 0.15
warmup_steps = int(total_steps * warmup_epoch_percentage)
scheduled_lrs = WarmUpCosine(
learning_rate_base=LEARNING_RATE,
total_steps=total_steps,
warmup_learning_rate=0.0,
warmup_steps=warmup_steps,
)
focal_mod_net = FocalModulationNetwork()
optimizer = AdamW(learning_rate=scheduled_lrs, weight_decay=WEIGHT_DECAY)
# Compile and train the model.
focal_mod_net.compile(
optimizer=optimizer,
loss="sparse_categorical_crossentropy",
metrics=["accuracy"],
)
history = focal_mod_net.fit(
train_ds,
epochs=EPOCHS,
validation_data=val_ds,
callbacks=[TrainMonitor(epoch_interval=10)],
)
Epoch 1/25
40/40 [==============================] - ETA: 0s - loss: 2.3925 - accuracy: 0.1401
40/40 [==============================] - 57s 724ms/step - loss: 2.3925 - accuracy: 0.1401 - val_loss: 2.2182 - val_accuracy: 0.1768
Epoch 2/25
40/40 [==============================] - 20s 483ms/step - loss: 2.0790 - accuracy: 0.2261 - val_loss: 2.2933 - val_accuracy: 0.1795
Epoch 3/25
40/40 [==============================] - 19s 479ms/step - loss: 2.0130 - accuracy: 0.2585 - val_loss: 2.6833 - val_accuracy: 0.2022
Epoch 4/25
40/40 [==============================] - 21s 507ms/step - loss: 1.8270 - accuracy: 0.3315 - val_loss: 1.9127 - val_accuracy: 0.3215
Epoch 5/25
40/40 [==============================] - 19s 475ms/step - loss: 1.6037 - accuracy: 0.4173 - val_loss: 1.7226 - val_accuracy: 0.3938
Epoch 6/25
40/40 [==============================] - 19s 476ms/step - loss: 1.4758 - accuracy: 0.4658 - val_loss: 1.5097 - val_accuracy: 0.4733
Epoch 7/25
40/40 [==============================] - 19s 476ms/step - loss: 1.3677 - accuracy: 0.5075 - val_loss: 1.4630 - val_accuracy: 0.4986
Epoch 8/25
40/40 [==============================] - 21s 508ms/step - loss: 1.2599 - accuracy: 0.5490 - val_loss: 1.2908 - val_accuracy: 0.5492
Epoch 9/25
40/40 [==============================] - 19s 478ms/step - loss: 1.1689 - accuracy: 0.5818 - val_loss: 1.2750 - val_accuracy: 0.5518
Epoch 10/25
40/40 [==============================] - 19s 476ms/step - loss: 1.0843 - accuracy: 0.6140 - val_loss: 1.1444 - val_accuracy: 0.6002
Epoch 11/25
39/40 [============================>.] - ETA: 0s - loss: 1.0040 - accuracy: 0.6453
40/40 [==============================] - 20s 489ms/step - loss: 1.0041 - accuracy: 0.6452 - val_loss: 1.1765 - val_accuracy: 0.5939
Epoch 12/25
40/40 [==============================] - 20s 480ms/step - loss: 0.9401 - accuracy: 0.6701 - val_loss: 1.1276 - val_accuracy: 0.6181
Epoch 13/25
40/40 [==============================] - 19s 480ms/step - loss: 0.8787 - accuracy: 0.6910 - val_loss: 0.9990 - val_accuracy: 0.6547
Epoch 14/25
40/40 [==============================] - 19s 479ms/step - loss: 0.8198 - accuracy: 0.7122 - val_loss: 1.0074 - val_accuracy: 0.6562
Epoch 15/25
40/40 [==============================] - 19s 480ms/step - loss: 0.7831 - accuracy: 0.7275 - val_loss: 0.9739 - val_accuracy: 0.6686
Epoch 16/25
40/40 [==============================] - 19s 478ms/step - loss: 0.7358 - accuracy: 0.7428 - val_loss: 0.9578 - val_accuracy: 0.6753
Epoch 17/25
40/40 [==============================] - 19s 478ms/step - loss: 0.7018 - accuracy: 0.7557 - val_loss: 0.9414 - val_accuracy: 0.6789
Epoch 18/25
40/40 [==============================] - 20s 480ms/step - loss: 0.6678 - accuracy: 0.7678 - val_loss: 0.9492 - val_accuracy: 0.6771
Epoch 19/25
40/40 [==============================] - 19s 476ms/step - loss: 0.6423 - accuracy: 0.7783 - val_loss: 0.9422 - val_accuracy: 0.6832
Epoch 20/25
40/40 [==============================] - 19s 479ms/step - loss: 0.6202 - accuracy: 0.7868 - val_loss: 0.9324 - val_accuracy: 0.6860
Epoch 21/25
40/40 [==============================] - ETA: 0s - loss: 0.6005 - accuracy: 0.7938
40/40 [==============================] - 20s 488ms/step - loss: 0.6005 - accuracy: 0.7938 - val_loss: 0.9326 - val_accuracy: 0.6880
Epoch 22/25
40/40 [==============================] - 19s 478ms/step - loss: 0.5937 - accuracy: 0.7970 - val_loss: 0.9339 - val_accuracy: 0.6875
Epoch 23/25
40/40 [==============================] - 19s 478ms/step - loss: 0.5899 - accuracy: 0.7984 - val_loss: 0.9294 - val_accuracy: 0.6894
Epoch 24/25
40/40 [==============================] - 19s 478ms/step - loss: 0.5840 - accuracy: 0.8012 - val_loss: 0.9315 - val_accuracy: 0.6881
Epoch 25/25
40/40 [==============================] - 19s 478ms/step - loss: 0.5853 - accuracy: 0.7997 - val_loss: 0.9315 - val_accuracy: 0.6880
plt.plot(history.history["loss"], label="loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.legend()
plt.show()
plt.plot(history.history["accuracy"], label="accuracy")
plt.plot(history.history["val_accuracy"], label="val_accuracy")
plt.legend()
plt.show()
让我们在一些测试图像上测试我们的模型,看看门控是什么样子。
test_images, test_labels = next(iter(test_ds))
_ = focal_mod_net(test_images)
# Take the mid layer for visualization
gates = focal_mod_net.basic_layers[1].blocks[-1].modulation.gates
gates = upsampler(gates)
modulator = focal_mod_net.basic_layers[1].blocks[-1].modulation.modulator
modulator = upsampler(modulator)
# Plot the test images with the gates and modulator overlayed.
for row in range(5):
display_grid(
test_images=test_images,
gates=gates,
modulator=modulator,
)
所提出的架构,即焦点调制网络架构,是一种允许图像不同部分以依赖于图像本身的方式相互作用的机制。它首先围绕图像的每个部分(“查询 token”)收集不同级别的上下文信息,然后使用一个门控来决定哪些上下文信息最相关,最后以简单而有效的方式组合所选信息。
这旨在取代 Transformer 架构中的自注意力机制。这项研究值得注意的关键特征不是无注意力网络的构思,而是引入了一个同样强大且可解释的架构。
作者还提到,他们创建了一系列焦点调制网络(FocalNets),这些网络在参数和预训练数据方面远少于自注意力对应的网络,但性能显著优于后者。
FocalNets 架构有可能带来令人印象深刻的结果,并提供简单的实现。其有前景的性能和易用性使其成为研究人员在其项目中探索自注意力的有吸引力的替代方案。它有可能在不久的将来被深度学习社区广泛采用。
我们要感谢 PyImageSearch 提供 Colab Pro 账户,JarvisLabs.ai 提供 GPU 积分,以及微软研究院提供其论文的 官方实现。我们还要感谢论文的第一作者 Jianwei Yang 对本教程进行了广泛审阅。