作者: Aritra Roy Gosthipaty, Suvaditya Mukherjee
创建日期 2023/03/12
最后修改日期 2024/11/12
描述: 使用时序潜在瓶颈网络进行图像分类。
一个简单的循环神经网络(RNN)在学习时间压缩表示方面表现出强大的归纳偏置。公式1显示了循环公式,其中h_t
是整个输入序列x
的压缩表示(一个单一向量)。
![]() |
---|
公式1:循环方程。(来源:Aritra 和 Suvaditya) |
另一方面,Transformer (Vaswani 等人) 在学习时间压缩表示方面几乎没有归纳偏置。Transformer凭借其成对注意力机制在自然语言处理(NLP)和视觉任务中取得了SoTA成果。
虽然Transformer具有关注输入序列不同部分的能力,但注意力的计算本质上是二次的。
Didolkar 等人认为,序列的更压缩表示可能对泛化有益,因为它可以更容易地被重复使用和重新利用,同时减少无关细节。虽然压缩很好,但他们也注意到过度压缩会损害表达能力。
作者提出了一种将计算分为两个流的解决方案。一个慢流是循环性质的,一个快流被参数化为Transformer。虽然这种方法通过引入不同的处理流来保存和处理潜在状态具有新颖性,但在其他工作中也有类似的并行之处,例如感知器机制(Jaegle 等人)和慢速与快速的接地语言学习(Hill 等人)。
以下示例探讨了如何利用新的时序潜在瓶颈机制在CIFAR-10数据集上执行图像分类。我们通过自定义RNNCell
实现来构建此模型,以实现高性能和向量化设计。
import os
import keras
from keras import layers, ops, mixed_precision
from keras.optimizers import AdamW
import numpy as np
import random
from matplotlib import pyplot as plt
# Set seed for reproducibility.
keras.utils.set_random_seed(42)
我们设置了一些在所设计的管道中所需的配置参数。当前参数用于CIFAR10数据集。
该模型还支持混合精度
设置,它将模型量化为在可能的情况下使用16位
浮点数,同时根据数值稳定性需要保留某些参数为32位
。这带来了性能优势,因为模型的内存占用显著减少,同时在推理时带来了速度提升。
config = {
"mixed_precision": True,
"dataset": "cifar10",
"train_slice": 40_000,
"batch_size": 2048,
"buffer_size": 2048 * 2,
"input_shape": [32, 32, 3],
"image_size": 48,
"num_classes": 10,
"learning_rate": 1e-4,
"weight_decay": 1e-4,
"epochs": 30,
"patch_size": 4,
"embed_dim": 64,
"chunk_size": 8,
"r": 2,
"num_layers": 4,
"ffn_drop": 0.2,
"attn_drop": 0.2,
"num_heads": 1,
}
if config["mixed_precision"]:
policy = mixed_precision.Policy("mixed_float16")
mixed_precision.set_global_policy(policy)
我们将使用CIFAR10数据集进行实验。该数据集包含一个训练集,其中包含50,000张图像,分为10个类别,标准图像大小为(32, 32, 3)。
它还包含一个包含10,000张具有相似特征的图像的单独集合。有关数据集的更多信息,请参阅数据集官方网站以及keras.datasets.cifar10
API参考。
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
(x_train[: config["train_slice"]], y_train[: config["train_slice"]]),
(x_train[config["train_slice"] :], y_train[config["train_slice"] :]),
)
我们定义了独立的管道,用于对数据执行图像增强。此步骤对于使模型对变化更具鲁棒性,帮助其更好地泛化至关重要。我们执行的预处理和增强步骤如下:
Rescaling
(训练、测试):此步骤用于将所有图像像素值从[0,255]
范围归一化到[0,1)
。这有助于在训练后期保持数值稳定性。Resizing
(训练、测试):我们将图像从原始大小(32, 32)调整为(52, 52)。这样做是为了考虑到随机裁剪,并符合论文中给出的数据规范。RandomCrop
(训练):此层随机选择图像中大小为(48, 48)
的裁剪/子区域。RandomFlip
(训练):此层随机水平翻转所有图像,保持图像大小不变。# Build the `train` augmentation pipeline.
train_augmentation = keras.Sequential(
[
layers.Rescaling(1 / 255.0, dtype="float32"),
layers.Resizing(
config["input_shape"][0] + 20,
config["input_shape"][0] + 20,
dtype="float32",
),
layers.RandomCrop(config["image_size"], config["image_size"], dtype="float32"),
layers.RandomFlip("horizontal", dtype="float32"),
],
name="train_data_augmentation",
)
# Build the `val` and `test` data pipeline.
test_augmentation = keras.Sequential(
[
layers.Rescaling(1 / 255.0, dtype="float32"),
layers.Resizing(config["image_size"], config["image_size"], dtype="float32"),
],
name="test_data_augmentation",
)
# We define functions in place of simple lambda functions to run through the
# [`keras.Sequential`](/api/models/sequential#sequential-class)in order to solve this warning:
# (https://github.com/tensorflow/tensorflow/issues/56089)
def train_map_fn(image, label):
return train_augmentation(image), label
def test_map_fn(image, label):
return test_augmentation(image), label
PyDataset
对象中np.ndarray
实例并围绕它封装一个类,封装一个keras.utils.PyDataset
,并使用keras预处理层应用增强。class Dataset(keras.utils.PyDataset):
def __init__(
self, x_data, y_data, batch_size, preprocess_fn=None, shuffle=False, **kwargs
):
if shuffle:
perm = np.random.permutation(len(x_data))
x_data = x_data[perm]
y_data = y_data[perm]
self.x_data = x_data
self.y_data = y_data
self.preprocess_fn = preprocess_fn
self.batch_size = batch_size
super().__init__(*kwargs)
def __len__(self):
return len(self.x_data) // self.batch_size
def __getitem__(self, idx):
batch_x, batch_y = [], []
for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):
x, y = self.x_data[i], self.y_data[i]
if self.preprocess_fn:
x, y = self.preprocess_fn(x, y)
batch_x.append(x)
batch_y.append(y)
batch_x = ops.stack(batch_x, axis=0)
batch_y = ops.stack(batch_y, axis=0)
return batch_x, batch_y
train_ds = Dataset(
x_train, y_train, config["batch_size"], preprocess_fn=train_map_fn, shuffle=True
)
val_ds = Dataset(x_val, y_val, config["batch_size"], preprocess_fn=test_map_fn)
test_ds = Dataset(x_test, y_test, config["batch_size"], preprocess_fn=test_map_fn)
论文节选
在大脑中,短期记忆和长期记忆以专门化的方式发展。短期记忆被允许快速变化以响应即时感官输入和感知。相比之下,长期记忆变化缓慢,高度选择性,并涉及重复巩固。
受短期记忆和长期记忆的启发,作者引入了快流和慢流计算。快流具有容量大、对感官输入(Transformer)反应迅速的短期记忆。慢流具有更新速度较慢、总结最相关信息(循环)的长期记忆。
为了实现这个想法,我们需要
快流和慢流导致了所谓的**信息不对称**。这两个流通过注意力的瓶颈相互作用。**图1**展示了模型的架构。
![]() |
---|
图1:模型架构。(来源:https://arxiv.org/abs/2205.14794) |
作者还提出了一种PyTorch风格的伪代码,如**算法1**所示。
![]() |
---|
算法1:PyTorch风格伪代码。(来源:https://arxiv.org/abs/2205.14794) |
PatchEmbedding
层这个自定义的keras.layers.Layer
对于从图像生成块并使用keras.layers.Embedding
将它们转换到更高维的嵌入空间非常有用。分块操作是使用keras.layers.Conv2D
实例完成的。
图像块拼接完成后,我们对图像块进行重塑,以获得一个扁平化的表示,其中维度数量是嵌入维度。在此阶段,我们还将位置信息注入到token中。
我们获得token后,对其进行分块。分块操作包括从嵌入输出中获取固定大小的序列以创建“块”,然后这些块将用作模型的最终输入。
class PatchEmbedding(layers.Layer):
"""Image to Patch Embedding.
Args:
image_size (`Tuple[int]`): Size of the input image.
patch_size (`Tuple[int]`): Size of the patch.
embed_dim (`int`): Dimension of the embedding.
chunk_size (`int`): Number of patches to be chunked.
"""
def __init__(
self,
image_size,
patch_size,
embed_dim,
chunk_size,
**kwargs,
):
super().__init__(**kwargs)
# Compute the patch resolution.
patch_resolution = [
image_size[0] // patch_size[0],
image_size[1] // patch_size[1],
]
# Store the parameters.
self.image_size = image_size
self.patch_size = patch_size
self.embed_dim = embed_dim
self.patch_resolution = patch_resolution
self.num_patches = patch_resolution[0] * patch_resolution[1]
# Define the positions of the patches.
self.positions = ops.arange(start=0, stop=self.num_patches, step=1)
# Create the layers.
self.projection = layers.Conv2D(
filters=embed_dim,
kernel_size=patch_size,
strides=patch_size,
name="projection",
)
self.flatten = layers.Reshape(
target_shape=(-1, embed_dim),
name="flatten",
)
self.position_embedding = layers.Embedding(
input_dim=self.num_patches,
output_dim=embed_dim,
name="position_embedding",
)
self.layernorm = keras.layers.LayerNormalization(
epsilon=1e-5,
name="layernorm",
)
self.chunking_layer = layers.Reshape(
target_shape=(self.num_patches // chunk_size, chunk_size, embed_dim),
name="chunking_layer",
)
def call(self, inputs):
# Project the inputs to the embedding dimension.
x = self.projection(inputs)
# Flatten the pathces and add position embedding.
x = self.flatten(x)
x = x + self.position_embedding(self.positions)
# Normalize the embeddings.
x = self.layernorm(x)
# Chunk the tokens.
x = self.chunking_layer(x)
return x
FeedForwardNetwork
层此自定义keras.layers.Layer
实例允许我们定义一个通用FFN以及一个dropout。
class FeedForwardNetwork(layers.Layer):
"""Feed Forward Network.
Args:
dims (`int`): Number of units in FFN.
dropout (`float`): Dropout probability for FFN.
"""
def __init__(self, dims, dropout, **kwargs):
super().__init__(**kwargs)
# Create the layers.
self.ffn = keras.Sequential(
[
layers.Dense(units=4 * dims, activation="gelu"),
layers.Dense(units=dims),
layers.Dropout(rate=dropout),
],
name="ffn",
)
self.layernorm = layers.LayerNormalization(
epsilon=1e-5,
name="layernorm",
)
def call(self, inputs):
# Apply the FFN.
x = self.layernorm(inputs)
x = inputs + self.ffn(x)
return x
BaseAttention
层此自定义keras.layers.Layer
实例是一个super
/base
类,它封装了keras.layers.MultiHeadAttention
层以及其他一些组件。这为我们模型中所有注意力层/模块提供了基本的公分母功能。
class BaseAttention(layers.Layer):
"""Base Attention Module.
Args:
num_heads (`int`): Number of attention heads.
key_dim (`int`): Size of each attention head for key.
dropout (`float`): Dropout probability for attention module.
"""
def __init__(self, num_heads, key_dim, dropout, **kwargs):
super().__init__(**kwargs)
self.multi_head_attention = layers.MultiHeadAttention(
num_heads=num_heads,
key_dim=key_dim,
dropout=dropout,
name="mha",
)
self.query_layernorm = layers.LayerNormalization(
epsilon=1e-5,
name="q_layernorm",
)
self.key_layernorm = layers.LayerNormalization(
epsilon=1e-5,
name="k_layernorm",
)
self.value_layernorm = layers.LayerNormalization(
epsilon=1e-5,
name="v_layernorm",
)
self.attention_scores = None
def call(self, input_query, key, value):
# Apply the attention module.
query = self.query_layernorm(input_query)
key = self.key_layernorm(key)
value = self.value_layernorm(value)
(attention_outputs, attention_scores) = self.multi_head_attention(
query=query,
key=key,
value=value,
return_attention_scores=True,
)
# Save the attention scores for later visualization.
self.attention_scores = attention_scores
# Add the input to the attention output.
x = input_query + attention_outputs
return x
Attention
与 FeedForwardNetwork
层这个自定义的keras.layers.Layer
实现结合了BaseAttention
和FeedForwardNetwork
组件,以开发一个将在模型中重复使用的块。该模块高度可定制且灵活,允许内部层进行更改。
class AttentionWithFFN(layers.Layer):
"""Attention with Feed Forward Network.
Args:
ffn_dims (`int`): Number of units in FFN.
ffn_dropout (`float`): Dropout probability for FFN.
num_heads (`int`): Number of attention heads.
key_dim (`int`): Size of each attention head for key.
attn_dropout (`float`): Dropout probability for attention module.
"""
def __init__(
self,
ffn_dims,
ffn_dropout,
num_heads,
key_dim,
attn_dropout,
**kwargs,
):
super().__init__(**kwargs)
# Create the layers.
self.fast_stream_attention = BaseAttention(
num_heads=num_heads,
key_dim=key_dim,
dropout=attn_dropout,
name="base_attn",
)
self.slow_stream_attention = BaseAttention(
num_heads=num_heads,
key_dim=key_dim,
dropout=attn_dropout,
name="base_attn",
)
self.ffn = FeedForwardNetwork(
dims=ffn_dims,
dropout=ffn_dropout,
name="ffn",
)
self.attention_scores = None
def build(self, input_shape):
self.built = True
def call(self, query, key, value, stream="fast"):
# Apply the attention module.
attention_layer = {
"fast": self.fast_stream_attention,
"slow": self.slow_stream_attention,
}[stream]
if len(query.shape) == 2:
query = ops.expand_dims(query, -1)
if len(key.shape) == 2:
key = ops.expand_dims(key, -1)
if len(value.shape) == 2:
value = ops.expand_dims(value, -1)
x = attention_layer(query, key, value)
# Save the attention scores for later visualization.
self.attention_scores = attention_layer.attention_scores
# Apply the FFN.
x = self.ffn(x)
return x
**算法1**(伪代码)通过for循环描述了递归。循环确实简化了实现,但也影响了训练时间。在本节中,我们将自定义递归逻辑封装在CustomRecurrentCell
中。然后,这个自定义单元将通过Keras RNN API封装,使整个代码可向量化。
这个自定义单元,作为keras.layers.Layer
实现,是模型逻辑的组成部分。单元的功能可以分为两部分:- 慢流(时序潜在瓶颈):
AttentionWithFFN
层组成,该层解析前一个慢流的输出,一个中间隐藏表示(即时序潜在瓶颈中的*潜在*)作为查询,以及最新快流的输出作为键和值。此层也可以被理解为*交叉注意力*层。AttentionWithFFN
层组成。该流以顺序方式包含n个SelfAttention
和CrossAttention
层。class CustomRecurrentCell(layers.Layer):
"""Custom Recurrent Cell.
Args:
chunk_size (`int`): Number of tokens in a chunk.
r (`int`): One Cross Attention per **r** Self Attention.
num_layers (`int`): Number of layers.
ffn_dims (`int`): Number of units in FFN.
ffn_dropout (`float`): Dropout probability for FFN.
num_heads (`int`): Number of attention heads.
key_dim (`int`): Size of each attention head for key.
attn_dropout (`float`): Dropout probability for attention module.
"""
def __init__(
self,
chunk_size,
r,
num_layers,
ffn_dims,
ffn_dropout,
num_heads,
key_dim,
attn_dropout,
**kwargs,
):
super().__init__(**kwargs)
# Save the arguments.
self.chunk_size = chunk_size
self.r = r
self.num_layers = num_layers
self.ffn_dims = ffn_dims
self.ffn_droput = ffn_dropout
self.num_heads = num_heads
self.key_dim = key_dim
self.attn_dropout = attn_dropout
# Create state_size. This is important for
# custom recurrence logic.
self.state_size = chunk_size * ffn_dims
self.get_attention_scores = False
self.attention_scores = []
# Perceptual Module
perceptual_module = list()
for layer_idx in range(num_layers):
perceptual_module.append(
AttentionWithFFN(
ffn_dims=ffn_dims,
ffn_dropout=ffn_dropout,
num_heads=num_heads,
key_dim=key_dim,
attn_dropout=attn_dropout,
name=f"pm_self_attn_{layer_idx}",
)
)
if layer_idx % r == 0:
perceptual_module.append(
AttentionWithFFN(
ffn_dims=ffn_dims,
ffn_dropout=ffn_dropout,
num_heads=num_heads,
key_dim=key_dim,
attn_dropout=attn_dropout,
name=f"pm_cross_attn_ffn_{layer_idx}",
)
)
self.perceptual_module = perceptual_module
# Temporal Latent Bottleneck Module
self.tlb_module = AttentionWithFFN(
ffn_dims=ffn_dims,
ffn_dropout=ffn_dropout,
num_heads=num_heads,
key_dim=key_dim,
attn_dropout=attn_dropout,
name=f"tlb_cross_attn_ffn",
)
def build(self, input_shape):
self.built = True
def call(self, inputs, states):
# inputs => (batch, chunk_size, dims)
# states => [(batch, chunk_size, units)]
slow_stream = ops.reshape(states[0], (-1, self.chunk_size, self.ffn_dims))
fast_stream = inputs
for layer_idx, layer in enumerate(self.perceptual_module):
fast_stream = layer(
query=fast_stream, key=fast_stream, value=fast_stream, stream="fast"
)
if layer_idx % self.r == 0:
fast_stream = layer(
query=fast_stream, key=slow_stream, value=slow_stream, stream="slow"
)
slow_stream = self.tlb_module(
query=slow_stream, key=fast_stream, value=fast_stream
)
# Save the attention scores for later visualization.
if self.get_attention_scores:
self.attention_scores.append(self.tlb_module.attention_scores)
return fast_stream, [
ops.reshape(slow_stream, (-1, self.chunk_size * self.ffn_dims))
]
TemporalLatentBottleneckModel
用于封装完整模型在这里,我们只是封装了完整的模型,以便将其暴露用于训练。
class TemporalLatentBottleneckModel(keras.Model):
"""Model Trainer.
Args:
patch_layer ([`keras.layers.Layer`](/api/layers/base_layer#layer-class)): Patching layer.
custom_cell ([`keras.layers.Layer`](/api/layers/base_layer#layer-class)): Custom Recurrent Cell.
"""
def __init__(self, patch_layer, custom_cell, unroll_loops=False, **kwargs):
super().__init__(**kwargs)
self.patch_layer = patch_layer
self.rnn = layers.RNN(custom_cell, unroll=unroll_loops, name="rnn")
self.gap = layers.GlobalAveragePooling1D(name="gap")
self.head = layers.Dense(10, activation="softmax", dtype="float32", name="head")
def call(self, inputs):
x = self.patch_layer(inputs)
x = self.rnn(x)
x = self.gap(x)
outputs = self.head(x)
return outputs
现在,为了开始训练,我们单独定义组件,并将它们作为参数传递给我们的包装类,该类将为训练准备最终模型。我们定义一个 PatchEmbed
层和基于 CustomCell
的 RNN。
# Build the model.
patch_layer = PatchEmbedding(
image_size=(config["image_size"], config["image_size"]),
patch_size=(config["patch_size"], config["patch_size"]),
embed_dim=config["embed_dim"],
chunk_size=config["chunk_size"],
)
custom_rnn_cell = CustomRecurrentCell(
chunk_size=config["chunk_size"],
r=config["r"],
num_layers=config["num_layers"],
ffn_dims=config["embed_dim"],
ffn_dropout=config["ffn_drop"],
num_heads=config["num_heads"],
key_dim=config["embed_dim"],
attn_dropout=config["attn_drop"],
)
model = TemporalLatentBottleneckModel(
patch_layer=patch_layer,
custom_cell=custom_rnn_cell,
)
我们使用AdamW
优化器,因为它在优化方面已在多个基准任务上表现出色。它是keras.optimizers.Adam
优化器的一个版本,并带有权重衰减。
对于损失函数,我们使用keras.losses.SparseCategoricalCrossentropy
函数,它使用预测和实际logits之间的简单交叉熵。我们还计算数据的准确性作为健全性检查。
optimizer = AdamW(
learning_rate=config["learning_rate"], weight_decay=config["weight_decay"]
)
model.compile(
optimizer=optimizer,
loss="sparse_categorical_crossentropy",
metrics=["accuracy"],
)
model.fit()
训练模型我们传递训练数据集并运行训练。
history = model.fit(
train_ds,
epochs=config["epochs"],
validation_data=val_ds,
)
Epoch 1/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1270s 62s/step - accuracy: 0.1166 - loss: 3.1132 - val_accuracy: 0.1486 - val_loss: 2.2887
Epoch 2/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1152s 60s/step - accuracy: 0.1798 - loss: 2.2290 - val_accuracy: 0.2249 - val_loss: 2.1083
Epoch 3/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1150s 60s/step - accuracy: 0.2371 - loss: 2.0661 - val_accuracy: 0.2610 - val_loss: 2.0294
Epoch 4/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1150s 60s/step - accuracy: 0.2631 - loss: 1.9997 - val_accuracy: 0.2765 - val_loss: 2.0008
Epoch 5/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1151s 60s/step - accuracy: 0.2869 - loss: 1.9634 - val_accuracy: 0.2985 - val_loss: 1.9578
Epoch 6/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1151s 60s/step - accuracy: 0.3048 - loss: 1.9314 - val_accuracy: 0.3055 - val_loss: 1.9324
Epoch 7/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1152s 60s/step - accuracy: 0.3136 - loss: 1.8977 - val_accuracy: 0.3209 - val_loss: 1.9050
Epoch 8/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1151s 60s/step - accuracy: 0.3238 - loss: 1.8717 - val_accuracy: 0.3231 - val_loss: 1.8874
Epoch 9/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1152s 60s/step - accuracy: 0.3414 - loss: 1.8453 - val_accuracy: 0.3445 - val_loss: 1.8334
Epoch 10/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1152s 60s/step - accuracy: 0.3469 - loss: 1.8119 - val_accuracy: 0.3591 - val_loss: 1.8019
Epoch 11/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1151s 60s/step - accuracy: 0.3648 - loss: 1.7712 - val_accuracy: 0.3793 - val_loss: 1.7513
Epoch 12/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.3730 - loss: 1.7332 - val_accuracy: 0.3667 - val_loss: 1.7464
Epoch 13/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1148s 60s/step - accuracy: 0.3918 - loss: 1.6986 - val_accuracy: 0.3995 - val_loss: 1.6843
Epoch 14/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1147s 60s/step - accuracy: 0.3975 - loss: 1.6679 - val_accuracy: 0.4026 - val_loss: 1.6602
Epoch 15/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4078 - loss: 1.6400 - val_accuracy: 0.3990 - val_loss: 1.6536
Epoch 16/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4135 - loss: 1.6224 - val_accuracy: 0.4216 - val_loss: 1.6144
Epoch 17/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1147s 60s/step - accuracy: 0.4254 - loss: 1.5884 - val_accuracy: 0.4281 - val_loss: 1.5788
Epoch 18/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4383 - loss: 1.5614 - val_accuracy: 0.4294 - val_loss: 1.5731
Epoch 19/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4419 - loss: 1.5440 - val_accuracy: 0.4338 - val_loss: 1.5633
Epoch 20/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4439 - loss: 1.5268 - val_accuracy: 0.4430 - val_loss: 1.5211
Epoch 21/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1147s 60s/step - accuracy: 0.4509 - loss: 1.5108 - val_accuracy: 0.4504 - val_loss: 1.5054
Epoch 22/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4629 - loss: 1.4828 - val_accuracy: 0.4563 - val_loss: 1.4974
Epoch 23/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1145s 60s/step - accuracy: 0.4660 - loss: 1.4682 - val_accuracy: 0.4647 - val_loss: 1.4794
Epoch 24/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4680 - loss: 1.4524 - val_accuracy: 0.4640 - val_loss: 1.4681
Epoch 25/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1145s 60s/step - accuracy: 0.4786 - loss: 1.4297 - val_accuracy: 0.4663 - val_loss: 1.4496
Epoch 26/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4889 - loss: 1.4149 - val_accuracy: 0.4769 - val_loss: 1.4350
Epoch 27/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4925 - loss: 1.4009 - val_accuracy: 0.4808 - val_loss: 1.4317
Epoch 28/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1145s 60s/step - accuracy: 0.4907 - loss: 1.3994 - val_accuracy: 0.4810 - val_loss: 1.4307
Epoch 29/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.5000 - loss: 1.3832 - val_accuracy: 0.4844 - val_loss: 1.3996
Epoch 30/30
19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.5076 - loss: 1.3592 - val_accuracy: 0.4890 - val_loss: 1.3961
---
## Visualize training metrics
The `model.fit()` will return a `history` object, which stores the values of the metrics
generated during the training run (but it is ephemeral and needs to be saved manually).
We now display the Loss and Accuracy curves for the training and validation sets.
```python
plt.plot(history.history["loss"], label="loss")
plt.plot(history.history["val_loss"], label="val_loss")
plt.legend()
plt.show()
plt.plot(history.history["accuracy"], label="accuracy")
plt.plot(history.history["val_accuracy"], label="val_accuracy")
plt.legend()
plt.show()
def score_to_viz(chunk_score):
# get the most attended token
chunk_viz = ops.max(chunk_score, axis=-2)
# get the mean across heads
chunk_viz = ops.mean(chunk_viz, axis=1)
return chunk_viz
# Get a batch of images and labels from the testing dataset
images, labels = next(iter(test_ds))
# Create a new model instance that is executed eagerly to allow saving
# attention scores. This also requires unrolling loops
eager_model = TemporalLatentBottleneckModel(
patch_layer=patch_layer, custom_cell=custom_rnn_cell, unroll_loops=True
)
eager_model.compile(run_eagerly=True, jit_compile=False)
model.save("weights.keras")
eager_model.load_weights("weights.keras")
# Set the get_attn_scores flag to True
eager_model.rnn.cell.get_attention_scores = True
# Run the model with the testing images and grab the
# attention scores.
outputs = eager_model(images)
list_chunk_scores = eager_model.rnn.cell.attention_scores
# Process the attention scores in order to visualize them
num_chunks = (config["image_size"] // config["patch_size"]) ** 2 // config["chunk_size"]
list_chunk_viz = [score_to_viz(x) for x in list_chunk_scores[-num_chunks:]]
chunk_viz = ops.concatenate(list_chunk_viz, axis=-1)
chunk_viz = ops.reshape(
chunk_viz,
(
config["batch_size"],
config["image_size"] // config["patch_size"],
config["image_size"] // config["patch_size"],
1,
),
)
upsampled_heat_map = layers.UpSampling2D(
size=(4, 4), interpolation="bilinear", dtype="float32"
)(chunk_viz)
# Sample a random image
index = random.randint(0, config["batch_size"])
orig_image = images[index]
overlay_image = upsampled_heat_map[index, ..., 0]
if keras.backend.backend() == "torch":
# when using the torch backend, we are required to ensure that the
# image is copied from the GPU
orig_image = orig_image.cpu().detach().numpy()
overlay_image = overlay_image.cpu().detach().numpy()
# Plot the visualization
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
ax[0].imshow(orig_image)
ax[0].set_title("Original:")
ax[0].axis("off")
image = ax[1].imshow(orig_image)
ax[1].imshow(
overlay_image,
cmap="inferno",
alpha=0.6,
extent=image.get_extent(),
)
ax[1].set_title("TLB Attention:")
plt.show()