
When Recurrent Neural Networks meet Transformers

Authors: Aritra Roy Gosthipaty, Suvaditya Mukherjee
Date created: 2023/03/12
Last modified: 2023/03/12
Description: Image Classification using Temporal Latent Bottleneck Networks.

ⓘ This example uses Keras 2



Introduction

A simple Recurrent Neural Network (RNN) displays a strong inductive bias towards learning **temporally compressed representations**. **Equation 1** shows the recurrence formula, where h_t is the compressed representation (a single vector) of the entire input sequence x.

Equation of RNN
Equation 1: The recurrence equation. (Source: Aritra and Suvaditya)
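
For reference, here is a minimal LaTeX sketch of the recurrence (reconstructed from the description above rather than copied from the figure):

$$ h_t = f(x_t, h_{t-1}) $$

where h_t summarizes everything observed in the sequence x up to step t.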

On the other hand, Transformers (Vaswani et al.) have little inductive bias towards learning temporally compressed representations. Transformers have achieved state-of-the-art results in Natural Language Processing (NLP) and vision tasks with their pairwise attention mechanism.

While a Transformer is able to **attend** to different sections of the input sequence, the computation of attention is quadratic in nature.
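
To make the quadratic claim concrete (a standard complexity estimate, not a quote from the paper): for a sequence of T tokens with model dimension d, self-attention materializes a T x T score matrix, so its cost scales as

$$ \mathcal{O}(T^2 \cdot d) $$

which is what makes a compressed recurrent summary of history attractive for long sequences.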

Didolkar et al. argue that having a more compressed representation of a sequence may be beneficial for *generalization*, as it can be easily **re-used** and **re-purposed** without carrying around irrelevant details. While compression is good, they also note that too much of it can hurt expressiveness.

The authors propose a solution that divides the computation into **two streams**: a slow stream that is recurrent in nature, and a fast stream that is parameterized as a Transformer. While this method has the novelty of introducing different processing streams in order to preserve and process latent states, it has parallels to other works such as the Perceiver mechanism (by Jaegle et al.) and Grounded Language Learning Fast and Slow (by Hill et al.).

The following example explores how we can make use of the new Temporal Latent Bottleneck mechanism to perform image classification on the CIFAR-10 dataset. We implement this model with a custom RNNCell implementation in order to make a **performant** and **vectorized** design.

Note: This example makes use of TensorFlow 2.12.0, which must be installed on our system.
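
If TensorFlow 2.12.0 is not already available, it can be installed first (a notebook-style command; adapt it to your environment):

!pip install -q tensorflow==2.12.0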


Setup imports

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision
from tensorflow.keras.optimizers import AdamW

import random
from matplotlib import pyplot as plt

# Set seed for reproducibility.
keras.utils.set_random_seed(42)

AUTO = tf.data.AUTOTUNE

Setting required configuration

We set a few configuration parameters that are needed within the pipeline we have designed. The current parameters are for use with the CIFAR10 dataset.

The model also supports mixed-precision settings, which let the model use 16-bit floating-point numbers where possible, while keeping some parameters in 32-bit as needed for numerical stability. This brings performance benefits, as the footprint of the model decreases significantly while bringing speed gains at inference time.

config = {
    "mixed_precision": True,
    "dataset": "cifar10",
    "train_slice": 40_000,
    "batch_size": 2048,
    "buffer_size": 2048 * 2,
    "input_shape": [32, 32, 3],
    "image_size": 48,
    "num_classes": 10,
    "learning_rate": 1e-4,
    "weight_decay": 1e-4,
    "epochs": 30,
    "patch_size": 4,
    "embed_dim": 64,
    "chunk_size": 8,
    "r": 2,
    "num_layers": 4,
    "ffn_drop": 0.2,
    "attn_drop": 0.2,
    "num_heads": 1,
}

if config["mixed_precision"]:
    policy = mixed_precision.Policy("mixed_float16")
    mixed_precision.set_global_policy(policy)
INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA A100-PCIE-40GB, compute capability 8.0

Loading the CIFAR-10 dataset

We are going to use the CIFAR10 dataset for running our experiments. This dataset contains a training set of 50,000 images for 10 classes with the standard image size of (32, 32, 3).

It also has a separate set of 10,000 images with similar characteristics. More information about the dataset can be found at its official site as well as in the keras.datasets.cifar10 API reference.

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
    (x_train[: config["train_slice"]], y_train[: config["train_slice"]]),
    (x_train[config["train_slice"] :], y_train[config["train_slice"] :]),
)

Define data augmentation for the training and validation/test pipelines

We define separate pipelines for performing image augmentation on our data. This step is important to make the model more robust to changes, helping it generalize better. The preprocessing and augmentation steps we perform are as follows:

  • Rescaling (train, test): This step is performed to normalize all image pixel values from the [0, 255] range to [0, 1). This helps maintain numerical stability later during training.
  • Resizing (train, test): We resize the image from its original size of (32, 32) to (52, 52). This is done to account for the Random Crop, as well as to conform with the data specifications given in the paper.
  • RandomCrop (train): This layer randomly selects a crop/sub-region of the image with size (48, 48).
  • RandomFlip (train): This layer randomly flips all the images horizontally, keeping image sizes the same.
# Build the `train` augmentation pipeline.
train_augmentation = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0, dtype="float32"),
        layers.Resizing(
            config["input_shape"][0] + 20,
            config["input_shape"][0] + 20,
            dtype="float32",
        ),
        layers.RandomCrop(config["image_size"], config["image_size"], dtype="float32"),
        layers.RandomFlip("horizontal", dtype="float32"),
    ],
    name="train_data_augmentation",
)

# Build the `val` and `test` data pipeline.
test_augmentation = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0, dtype="float32"),
        layers.Resizing(config["image_size"], config["image_size"], dtype="float32"),
    ],
    name="test_data_augmentation",
)

# We define functions in place of simple lambda functions to run through the
# `keras.Sequential` pipelines in order to solve this warning:
# (https://github.com/tensorflow/tensorflow/issues/56089)


def train_map_fn(image, label):
    return train_augmentation(image), label


def test_map_fn(image, label):
    return test_augmentation(image), label

Load the datasets into tf.data.Dataset objects

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = (
    train_ds.map(train_map_fn, num_parallel_calls=AUTO)
    .shuffle(config["buffer_size"])
    .batch(config["batch_size"], num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = (
    val_ds.map(test_map_fn, num_parallel_calls=AUTO)
    .batch(config["batch_size"], num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = (
    test_ds.map(test_map_fn, num_parallel_calls=AUTO)
    .batch(config["batch_size"], num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

Temporal Latent Bottleneck

An excerpt from the paper:

In the brain, short-term and long-term memory have developed in a specialized way. Short-term memory is allowed to change very quickly to react to immediate sensory inputs and perception. By contrast, long-term memory changes slowly, is highly selective and involves repeated consolidation.

Inspired by short-term and long-term memory, the authors introduce the fast stream and slow stream computation. The fast stream has a short-term memory with a high capacity that reacts quickly to sensory input (a Transformer). The slow stream has long-term memory which updates at a slower rate and summarizes the most relevant information (recurrence).

To implement this idea we need to:

  • Take a sequence of data.
  • Divide the sequence into fixed-size chunks (a minimal shape sketch follows this list).
  • The fast stream operates within each chunk. It provides fine-grained local information.
  • The slow stream consolidates and aggregates information across chunks. It provides coarse-grained distant information.
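
To make the chunking step concrete, here is a minimal, illustrative sketch (not part of the model code; the numbers mirror the configuration used later, i.e. 144 patch tokens grouped into chunks of 8):

import tensorflow as tf

# Toy token sequence: (batch, num_tokens, embed_dim) = (1, 144, 64).
tokens = tf.random.normal((1, 144, 64))

# Group the 144 tokens into 144 // 8 = 18 chunks of 8 tokens each.
chunks = tf.reshape(tokens, (1, 144 // 8, 8, 64))

# (1, 18, 8, 64): the fast stream runs inside each chunk of 8 tokens,
# while the slow stream is updated once per chunk.
print(chunks.shape)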

The fast and slow streams induce what is called information asymmetry. The two streams interact with each other through a bottleneck of attention. Figure 1 shows the architecture of the model.

Architecture of the model
Figure 1: Architecture of the model. (Source: https://arxiv.org/abs/2205.14794)

The authors also provide PyTorch-style pseudocode, as shown in Algorithm 1.

Pseudocode of the model
Algorithm 1: PyTorch-style pseudocode. (Source: https://arxiv.org/abs/2205.14794)

PatchEmbedding

This custom keras.layers.Layer generates patches from the image and transforms them into a higher-dimensional embedding space using keras.layers.Embedding. The patching operation is done with a keras.layers.Conv2D instance instead of the traditional tf.image.extract_patches in order to allow for vectorization.

Once the patching of images is complete, we reshape the image patches in order to get a flattened representation where the number of dimensions is the embedding dimension. At this stage, we also inject positional information into the tokens.

After we obtain the tokens, we chunk them. The chunking operation involves taking fixed-size sequences from the embedding output to create "chunks", which will then be used as the final input to the model.

class PatchEmbedding(layers.Layer):
    """Image to Patch Embedding.
    Args:
        image_size (`Tuple[int]`): Size of the input image.
        patch_size (`Tuple[int]`): Size of the patch.
        embed_dim (`int`): Dimension of the embedding.
        chunk_size (`int`): Number of patches to be chunked.
    """

    def __init__(
        self,
        image_size,
        patch_size,
        embed_dim,
        chunk_size,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Compute the patch resolution.
        patch_resolution = [
            image_size[0] // patch_size[0],
            image_size[1] // patch_size[1],
        ]

        # Store the parameters.
        self.image_size = image_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_resolution = patch_resolution
        self.num_patches = patch_resolution[0] * patch_resolution[1]

        # Define the positions of the patches.
        self.positions = tf.range(start=0, limit=self.num_patches, delta=1)

        # Create the layers.
        self.projection = layers.Conv2D(
            filters=embed_dim,
            kernel_size=patch_size,
            strides=patch_size,
            name="projection",
        )
        self.flatten = layers.Reshape(
            target_shape=(-1, embed_dim),
            name="flatten",
        )
        self.position_embedding = layers.Embedding(
            input_dim=self.num_patches,
            output_dim=embed_dim,
            name="position_embedding",
        )
        self.layernorm = keras.layers.LayerNormalization(
            epsilon=1e-5,
            name="layernorm",
        )
        self.chunking_layer = layers.Reshape(
            target_shape=(self.num_patches // chunk_size, chunk_size, embed_dim),
            name="chunking_layer",
        )

    def call(self, inputs):
        # Project the inputs to the embedding dimension.
        x = self.projection(inputs)

        # Flatten the patches and add position embedding.
        x = self.flatten(x)
        x = x + self.position_embedding(self.positions)

        # Normalize the embeddings.
        x = self.layernorm(x)

        # Chunk the tokens.
        x = self.chunking_layer(x)

        return x
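
As a quick, optional sanity check (illustrative only, not part of the original example), the layer can be applied to a dummy batch to confirm the chunked output shape: a 48x48 image with 4x4 patches gives 12 * 12 = 144 tokens, which are grouped into 144 // 8 = 18 chunks of 8 tokens.

demo_patch_layer = PatchEmbedding(
    image_size=(48, 48), patch_size=(4, 4), embed_dim=64, chunk_size=8
)
print(demo_patch_layer(tf.zeros((1, 48, 48, 3))).shape)  # (1, 18, 8, 64)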

FeedForwardNetwork

This custom keras.layers.Layer instance allows us to define a generic FFN along with dropout.

class FeedForwardNetwork(layers.Layer):
    """Feed Forward Network.
    Args:
        dims (`int`): Number of units in FFN.
        dropout (`float`): Dropout probability for FFN.
    """

    def __init__(self, dims, dropout, **kwargs):
        super().__init__(**kwargs)

        # Create the layers.
        self.ffn = keras.Sequential(
            [
                layers.Dense(units=4 * dims, activation=tf.nn.gelu),
                layers.Dense(units=dims),
                layers.Dropout(rate=dropout),
            ],
            name="ffn",
        )
        self.layernorm = layers.LayerNormalization(
            epsilon=1e-5,
            name="layernorm",
        )

    def call(self, inputs):
        # Apply the FFN.
        x = self.layernorm(inputs)
        x = inputs + self.ffn(x)
        return x

BaseAttention

This custom keras.layers.Layer instance is a super/base class that wraps a keras.layers.MultiHeadAttention layer along with some other components. This provides basic common functionality for all the attention layers/modules in our model.

class BaseAttention(layers.Layer):
    """Base Attention Module.
    Args:
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for key.
        dropout (`float`): Dropout probability for attention module.
    """

    def __init__(self, num_heads, key_dim, dropout, **kwargs):
        super().__init__(**kwargs)
        self.multi_head_attention = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=key_dim,
            dropout=dropout,
            name="mha",
        )
        self.query_layernorm = layers.LayerNormalization(
            epsilon=1e-5,
            name="q_layernorm",
        )
        self.key_layernorm = layers.LayerNormalization(
            epsilon=1e-5,
            name="k_layernorm",
        )
        self.value_layernorm = layers.LayerNormalization(
            epsilon=1e-5,
            name="v_layernorm",
        )

        self.attention_scores = None

    def call(self, input_query, key, value):
        # Apply the attention module.
        query = self.query_layernorm(input_query)
        key = self.key_layernorm(key)
        value = self.value_layernorm(value)
        (attention_outputs, attention_scores) = self.multi_head_attention(
            query=query,
            key=key,
            value=value,
            return_attention_scores=True,
        )

        # Save the attention scores for later visualization.
        self.attention_scores = attention_scores

        # Add the input to the attention output.
        x = input_query + attention_outputs
        return x

Attention with FeedForwardNetwork layer

This custom keras.layers.Layer implementation combines the BaseAttention and FeedForwardNetwork components to build one block that will be used repeatedly within the model. This module is highly customizable and flexible, allowing for changes within the internal layers.

class AttentionWithFFN(layers.Layer):
    """Attention with Feed Forward Network.
    Args:
        ffn_dims (`int`): Number of units in FFN.
        ffn_dropout (`float`): Dropout probability for FFN.
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for key.
        attn_dropout (`float`): Dropout probability for attention module.
    """

    def __init__(
        self,
        ffn_dims,
        ffn_dropout,
        num_heads,
        key_dim,
        attn_dropout,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Create the layers.
        self.attention = BaseAttention(
            num_heads=num_heads,
            key_dim=key_dim,
            dropout=attn_dropout,
            name="base_attn",
        )
        self.ffn = FeedForwardNetwork(
            dims=ffn_dims,
            dropout=ffn_dropout,
            name="ffn",
        )

        self.attention_scores = None

    def call(self, query, key, value):
        # Apply the attention module.
        x = self.attention(query, key, value)

        # Save the attention scores for later visualization.
        self.attention_scores = self.attention.attention_scores

        # Apply the FFN.
        x = self.ffn(x)
        return x

Custom RNN Cell for the Temporal Latent Bottleneck and Perceptual Module

Algorithm 1 (the pseudocode) depicts recurrence with the help of for loops. Looping makes the implementation simpler, but it hurts training time. In this section we wrap the custom recurrence logic inside a CustomRecurrentCell. This custom cell is then wrapped with the Keras RNN API, which makes the entire code vectorizable.

This custom cell, implemented as a keras.layers.Layer, is the integral part of the logic for the model. The functionality of the cell can be divided into two parts:

  • Slow Stream (Temporal Latent Bottleneck): This module consists of a single AttentionWithFFN layer that parses the output of the previous Slow Stream, an intermediate hidden representation (the latent in the Temporal Latent Bottleneck) as the Query, and the output of the latest Fast Stream as the Key and Value. This layer can also be thought of as a CrossAttention layer.
  • Fast Stream (Perceptual Module): This module consists of intertwined AttentionWithFFN layers, made of n SelfAttention and CrossAttention layers arranged sequentially. Some of these layers take the chunked input as the Query, Key and Value (the SelfAttention layers). The other layers take the intermediate state output from the Temporal Latent Bottleneck module as the Query, while using the output of the SelfAttention layer before them as the Key and Value.

class CustomRecurrentCell(layers.Layer):
    """Custom Recurrent Cell.
    Args:
        chunk_size (`int`): Number of tokens in a chunk.
        r (`int`): One Cross Attention per **r** Self Attention.
        num_layers (`int`): Number of layers.
        ffn_dims (`int`): Number of units in FFN.
        ffn_dropout (`float`): Dropout probability for FFN.
        num_heads (`int`): Number of attention heads.
        key_dim (`int`): Size of each attention head for key.
        attn_dropout (`float`): Dropout probability for attention module.
    """

    def __init__(
        self,
        chunk_size,
        r,
        num_layers,
        ffn_dims,
        ffn_dropout,
        num_heads,
        key_dim,
        attn_dropout,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Save the arguments.
        self.chunk_size = chunk_size
        self.r = r
        self.num_layers = num_layers
        self.ffn_dims = ffn_dims
        self.ffn_droput = ffn_dropout
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.attn_dropout = attn_dropout

        # Create the state_size and output_size. This is important for
        # custom recurrence logic.
        self.state_size = tf.TensorShape([chunk_size, ffn_dims])
        self.output_size = tf.TensorShape([chunk_size, ffn_dims])

        self.get_attention_scores = False
        self.attention_scores = []

        # Perceptual Module
        perceptual_module = list()
        for layer_idx in range(num_layers):
            perceptual_module.append(
                AttentionWithFFN(
                    ffn_dims=ffn_dims,
                    ffn_dropout=ffn_dropout,
                    num_heads=num_heads,
                    key_dim=key_dim,
                    attn_dropout=attn_dropout,
                    name=f"pm_self_attn_{layer_idx}",
                )
            )
            if layer_idx % r == 0:
                perceptual_module.append(
                    AttentionWithFFN(
                        ffn_dims=ffn_dims,
                        ffn_dropout=ffn_dropout,
                        num_heads=num_heads,
                        key_dim=key_dim,
                        attn_dropout=attn_dropout,
                        name=f"pm_cross_attn_ffn_{layer_idx}",
                    )
                )
        self.perceptual_module = perceptual_module

        # Temporal Latent Bottleneck Module
        self.tlb_module = AttentionWithFFN(
            ffn_dims=ffn_dims,
            ffn_dropout=ffn_dropout,
            num_heads=num_heads,
            key_dim=key_dim,
            attn_dropout=attn_dropout,
            name=f"tlb_cross_attn_ffn",
        )

    def call(self, inputs, states):
        # inputs => (batch, chunk_size, dims)
        # states => [(batch, chunk_size, units)]
        slow_stream = states[0]
        fast_stream = inputs

        for layer_idx, layer in enumerate(self.perceptual_module):
            fast_stream = layer(query=fast_stream, key=fast_stream, value=fast_stream)

            if layer_idx % self.r == 0:
                fast_stream = layer(
                    query=fast_stream, key=slow_stream, value=slow_stream
                )

        slow_stream = self.tlb_module(
            query=slow_stream, key=fast_stream, value=fast_stream
        )

        # Save the attention scores for later visualization.
        if self.get_attention_scores:
            self.attention_scores.append(self.tlb_module.attention_scores)

        return fast_stream, [slow_stream]
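
As an optional shape check (a sketch using the same configuration, not part of the original example), the cell can be wrapped in keras.layers.RNN, which unrolls it over the chunk axis and returns the last fast-stream output:

demo_cell = CustomRecurrentCell(
    chunk_size=8, r=2, num_layers=4, ffn_dims=64, ffn_dropout=0.2,
    num_heads=1, key_dim=64, attn_dropout=0.2,
)
demo_rnn = layers.RNN(demo_cell)

# Input is (batch, num_chunks, chunk_size, embed_dim); each timestep is one chunk.
print(demo_rnn(tf.zeros((2, 18, 8, 64))).shape)  # (2, 8, 64)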

TemporalLatentBottleneckModel to encapsulate the full model

Here, we just wrap the full model in order to expose it for training.

class TemporalLatentBottleneckModel(keras.Model):
    """Model Trainer.
    Args:
        patch_layer (`keras.layers.Layer`): Patching layer.
        custom_cell (`keras.layers.Layer`): Custom Recurrent Cell.
    """

    def __init__(self, patch_layer, custom_cell, **kwargs):
        super().__init__(**kwargs)
        self.patch_layer = patch_layer
        self.rnn = layers.RNN(custom_cell, name="rnn")
        self.gap = layers.GlobalAveragePooling1D(name="gap")
        self.head = layers.Dense(10, activation="softmax", dtype="float32", name="head")

    def call(self, inputs):
        x = self.patch_layer(inputs)
        x = self.rnn(x)
        x = self.gap(x)
        outputs = self.head(x)
        return outputs

Build the model

To begin training, we now define the components individually and pass them as arguments to our wrapper class, which will prepare the final model for training. We define a PatchEmbedding layer and the CustomCell-based RNN.

# Build the model.
patch_layer = PatchEmbedding(
    image_size=(config["image_size"], config["image_size"]),
    patch_size=(config["patch_size"], config["patch_size"]),
    embed_dim=config["embed_dim"],
    chunk_size=config["chunk_size"],
)
custom_rnn_cell = CustomRecurrentCell(
    chunk_size=config["chunk_size"],
    r=config["r"],
    num_layers=config["num_layers"],
    ffn_dims=config["embed_dim"],
    ffn_dropout=config["ffn_drop"],
    num_heads=config["num_heads"],
    key_dim=config["embed_dim"],
    attn_dropout=config["attn_drop"],
)
model = TemporalLatentBottleneckModel(
    patch_layer=patch_layer,
    custom_cell=custom_rnn_cell,
)

Metrics and Callbacks

We use the AdamW optimizer since it has been shown to perform very well on several benchmark tasks from an optimization perspective. It is a version of the keras.optimizers.Adam optimizer with Weight Decay in place.

For a loss function, we make use of keras.losses.SparseCategoricalCrossentropy, which uses simple cross-entropy between the predictions and the actual logits. We also compute accuracy on our data as a sanity check.

optimizer = AdamW(
    learning_rate=config["learning_rate"], weight_decay=config["weight_decay"]
)
model.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

Train the model with model.fit()

We pass the training dataset and run the training.

history = model.fit(
    train_ds,
    epochs=config["epochs"],
    validation_data=val_ds,
)
Epoch 1/30
20/20 [==============================] - 104s 3s/step - loss: 2.6284 - accuracy: 0.1010 - val_loss: 2.2835 - val_accuracy: 0.1251
Epoch 2/30
20/20 [==============================] - 35s 2s/step - loss: 2.2797 - accuracy: 0.1542 - val_loss: 2.1721 - val_accuracy: 0.1846
Epoch 3/30
20/20 [==============================] - 34s 2s/step - loss: 2.1989 - accuracy: 0.1883 - val_loss: 2.1288 - val_accuracy: 0.2241
Epoch 4/30
20/20 [==============================] - 34s 2s/step - loss: 2.1267 - accuracy: 0.2192 - val_loss: 2.0919 - val_accuracy: 0.2477
Epoch 5/30
20/20 [==============================] - 33s 2s/step - loss: 2.0653 - accuracy: 0.2393 - val_loss: 2.0134 - val_accuracy: 0.2671
Epoch 6/30
20/20 [==============================] - 34s 2s/step - loss: 2.0327 - accuracy: 0.2524 - val_loss: 2.0258 - val_accuracy: 0.2665
Epoch 7/30
20/20 [==============================] - 34s 2s/step - loss: 2.0047 - accuracy: 0.2598 - val_loss: 1.9871 - val_accuracy: 0.2831
Epoch 8/30
20/20 [==============================] - 34s 2s/step - loss: 1.9765 - accuracy: 0.2781 - val_loss: 1.9550 - val_accuracy: 0.2968
Epoch 9/30
20/20 [==============================] - 34s 2s/step - loss: 1.9432 - accuracy: 0.2883 - val_loss: 1.9559 - val_accuracy: 0.2969
Epoch 10/30
20/20 [==============================] - 33s 2s/step - loss: 1.9062 - accuracy: 0.3020 - val_loss: 1.8967 - val_accuracy: 0.3200
Epoch 11/30
20/20 [==============================] - 33s 2s/step - loss: 1.8741 - accuracy: 0.3158 - val_loss: 1.8648 - val_accuracy: 0.3330
Epoch 12/30
20/20 [==============================] - 33s 2s/step - loss: 1.8336 - accuracy: 0.3282 - val_loss: 1.7863 - val_accuracy: 0.3464
Epoch 13/30
20/20 [==============================] - 33s 2s/step - loss: 1.7931 - accuracy: 0.3434 - val_loss: 1.7364 - val_accuracy: 0.3669
Epoch 14/30
20/20 [==============================] - 34s 2s/step - loss: 1.7491 - accuracy: 0.3558 - val_loss: 1.7104 - val_accuracy: 0.3710
Epoch 15/30
20/20 [==============================] - 34s 2s/step - loss: 1.7182 - accuracy: 0.3686 - val_loss: 1.6883 - val_accuracy: 0.3866
Epoch 16/30
20/20 [==============================] - 33s 2s/step - loss: 1.6819 - accuracy: 0.3790 - val_loss: 1.6493 - val_accuracy: 0.3933
Epoch 17/30
20/20 [==============================] - 33s 2s/step - loss: 1.6594 - accuracy: 0.3873 - val_loss: 1.6021 - val_accuracy: 0.4161
Epoch 18/30
20/20 [==============================] - 33s 2s/step - loss: 1.6279 - accuracy: 0.3946 - val_loss: 1.5949 - val_accuracy: 0.4170
Epoch 19/30
20/20 [==============================] - 34s 2s/step - loss: 1.6127 - accuracy: 0.4015 - val_loss: 1.5672 - val_accuracy: 0.4239
Epoch 20/30
20/20 [==============================] - 33s 2s/step - loss: 1.5995 - accuracy: 0.4079 - val_loss: 1.5795 - val_accuracy: 0.4223
Epoch 21/30
20/20 [==============================] - 34s 2s/step - loss: 1.5809 - accuracy: 0.4167 - val_loss: 1.5294 - val_accuracy: 0.4390
Epoch 22/30
20/20 [==============================] - 34s 2s/step - loss: 1.5572 - accuracy: 0.4254 - val_loss: 1.5192 - val_accuracy: 0.4455
Epoch 23/30
20/20 [==============================] - 33s 2s/step - loss: 1.5468 - accuracy: 0.4291 - val_loss: 1.5243 - val_accuracy: 0.4424
Epoch 24/30
20/20 [==============================] - 34s 2s/step - loss: 1.5347 - accuracy: 0.4335 - val_loss: 1.4920 - val_accuracy: 0.4532
Epoch 25/30
20/20 [==============================] - 33s 2s/step - loss: 1.5245 - accuracy: 0.4387 - val_loss: 1.4805 - val_accuracy: 0.4584
Epoch 26/30
20/20 [==============================] - 33s 2s/step - loss: 1.5057 - accuracy: 0.4469 - val_loss: 1.4754 - val_accuracy: 0.4592
Epoch 27/30
20/20 [==============================] - 34s 2s/step - loss: 1.5013 - accuracy: 0.4457 - val_loss: 1.4688 - val_accuracy: 0.4619
Epoch 28/30
20/20 [==============================] - 33s 2s/step - loss: 1.4852 - accuracy: 0.4548 - val_loss: 1.4543 - val_accuracy: 0.4704
Epoch 29/30
20/20 [==============================] - 34s 2s/step - loss: 1.4728 - accuracy: 0.4570 - val_loss: 1.4437 - val_accuracy: 0.4751
Epoch 30/30
20/20 [==============================] - 34s 2s/step - loss: 1.4652 - accuracy: 0.4606 - val_loss: 1.4546 - val_accuracy: 0.4726

Visualize training metrics

model.fit() returns a history object, which stores the values of the metrics generated during the training run (but it is ephemeral and needs to be saved manually).
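
If you want to keep these metrics around after the session ends, one simple option (an optional sketch, not part of the original example) is to dump the history dictionary to disk:

import json

with open("history.json", "w") as f:
    json.dump(history.history, f)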

We now display the loss and accuracy curves for the training and validation sets.

plt.plot(history.history["loss"], label="loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.legend()
plt.show()

plt.plot(history.history["accuracy"], label="accuracy")
plt.plot(history.history["val_accuracy"], label="val_accuracy")
plt.legend()
plt.show()



Visualize attention maps from the Temporal Latent Bottleneck

Now that we have trained our model, it is time for some visualizations. The Fast Stream (Transformer) processes a chunk of tokens. The Slow Stream processes each chunk and attends to tokens that are useful for the task.

In this section we visualize the attention map of the Slow Stream. This is done by extracting the attention scores from the TLB layer at each chunk's intersection and storing them within the RNN's state. This is followed by "ballooning" them up and returning these values.

def score_to_viz(chunk_score):
    # get the most attended token
    chunk_viz = tf.math.reduce_max(chunk_score, axis=-2)
    # get the mean across heads
    chunk_viz = tf.math.reduce_mean(chunk_viz, axis=1)
    return chunk_viz


# Get a batch of images and labels from the testing dataset
images, labels = next(iter(test_ds))

# Set the get_attn_scores flag to True
model.rnn.cell.get_attention_scores = True

# Run the model with the testing images and grab the
# attention scores.
outputs = model(images)
list_chunk_scores = model.rnn.cell.attention_scores

# Process the attention scores in order to visualize them
list_chunk_viz = [score_to_viz(x) for x in list_chunk_scores]
chunk_viz = tf.concat(list_chunk_viz[1:], axis=-1)
chunk_viz = tf.reshape(
    chunk_viz,
    (
        config["batch_size"],
        config["image_size"] // config["patch_size"],
        config["image_size"] // config["patch_size"],
        1,
    ),
)
upsampled_heat_map = layers.UpSampling2D(
    size=(4, 4), interpolation="bilinear", dtype="float32"
)(chunk_viz)

Run the following code snippet to get different images and their attention maps.

# Sample a random image
index = random.randint(0, config["batch_size"] - 1)
orig_image = images[index]
overlay_image = upsampled_heat_map[index, ..., 0]

# Plot the visualization
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

ax[0].imshow(orig_image)
ax[0].set_title("Original:")
ax[0].axis("off")

image = ax[1].imshow(orig_image)
ax[1].imshow(
    overlay_image,
    cmap="inferno",
    alpha=0.6,
    extent=image.get_extent(),
)
ax[1].set_title("TLB Attention:")

plt.show()



Conclusion

This example has demonstrated an implementation of the Temporal Latent Bottleneck mechanism. It highlights compressing and storing historical states in the form of a Temporal Latent Bottleneck, with periodic updates from a Perceptual Module, as an effective approach.

In the original paper, the authors conduct extensive tests across different modalities, ranging from supervised image classification to applications in Reinforcement Learning.

While we have only shown one way to apply this mechanism to image classification, it can be extended to other modalities with minimal changes.

Note: while building this example we did not have the official code to refer to. This means that our implementation is inspired by the paper, with no claims of being a complete reproduction. For more details on the training process, one can head over to our GitHub repository.


Acknowledgements

Thanks to Aniket Didolkar (the first author) and Anirudh Goyal (the third author) for reviewing our work.

We would like to thank PyImageSearch for providing a Colab Pro account and JarvisLabs.ai for the GPU credits.