
Image Classification using the Forward-Forward Algorithm

Author: Suvaditya Mukherjee
Date created: 2023/01/08
Last modified: 2023/01/08
Description: Training a Dense-layer model using the Forward-Forward algorithm.

ⓘ This example uses Keras 2



Introduction

The following example explores how to use the Forward-Forward algorithm to perform training instead of the traditionally-used method of backpropagation, as proposed by Hinton in The Forward-Forward Algorithm: Some Preliminary Investigations (2022).

The concept was inspired by the understanding behind Boltzmann Machines. Backpropagation involves calculating the difference between actual and predicted output via a cost function in order to adjust the network weights. The FF algorithm, on the other hand, suggests the analogy of neurons that get "excited" by a particular combination of an image and its correct corresponding label.

This method takes some inspiration from the biological learning process that occurs in the cortex. One significant advantage it brings is that backpropagation through the network is no longer required, and weight updates are local to the layer itself.

As this is still an experimental method, it does not achieve state-of-the-art results. With proper tuning, however, it should come close to the same level of performance. Through this example, we will examine a process that lets us implement the Forward-Forward algorithm within the layers themselves, instead of the traditional approach that relies on a global loss function and optimizer.
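
To make the idea concrete, here is a minimal NumPy sketch of the per-layer "goodness" measure and the loss that the FFDense layer below optimizes. The activation values are made up for illustration; only the threshold value and the form of the loss are taken from the implementation further down.

import numpy as np

# Hypothetical activations of one layer for a positive and a negative sample.
h_pos = np.array([1.5, 1.2, 1.8, 1.0])  # (image, correct label)
h_neg = np.array([0.3, 0.1, 0.4, 0.2])  # (image, wrong label)
threshold = 1.5  # same hyperparameter as in the FFDense layer below

# "Goodness" of a sample = mean of its squared activations.
g_pos = np.mean(h_pos**2)  # training drives this above the threshold
g_neg = np.mean(h_neg**2)  # training drives this below the threshold

# Per-layer FF loss, matching FFDense.forward_forward below:
# log(1 + exp(.)) applied to the signed distances from the threshold.
loss = np.mean(np.log1p(np.exp(np.array([-g_pos + threshold, g_neg - threshold]))))
print(f"g_pos={g_pos:.3f}, g_neg={g_neg:.3f}, loss={loss:.3f}")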

The tutorial is structured as follows:

  • Perform the necessary imports
  • Load the MNIST dataset
  • Visualize random samples from the MNIST dataset
  • Define a FFDense layer that overrides call and implements a custom forward_forward method which performs the weight updates
  • Define a FFNetwork model that overrides train_step and predict and implements two custom functions for per-sample prediction and label overlaying
  • Convert the MNIST NumPy arrays to tf.data.Dataset
  • Fit the network
  • Visualize the results
  • Perform inference on test samples

As this example requires customizing certain core functions of keras.layers.Layer and keras.models.Model, refer to the Keras guides on making new layers and models via subclassing and on customizing what happens in fit() for a primer on how to do so.


Setup imports

import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import random
from tensorflow.compiler.tf2xla.python import xla

Load the dataset and visualize the data

We use the keras.datasets.mnist.load_data() utility to pull the MNIST dataset directly in the form of NumPy arrays, and then arrange it into train and test splits.

After loading the dataset, we select 4 random samples from the training set and visualize them using matplotlib.pyplot.

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

print("4 Random Training samples and labels")
idx1, idx2, idx3, idx4 = random.sample(range(0, x_train.shape[0]), 4)

img1 = (x_train[idx1], y_train[idx1])
img2 = (x_train[idx2], y_train[idx2])
img3 = (x_train[idx3], y_train[idx3])
img4 = (x_train[idx4], y_train[idx4])

imgs = [img1, img2, img3, img4]

plt.figure(figsize=(10, 10))

for idx, item in enumerate(imgs):
    image, label = item[0], item[1]
    plt.subplot(2, 2, idx + 1)
    plt.imshow(image, cmap="gray")
    plt.title(f"Label : {label}")
plt.show()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 [==============================] - 0s 0us/step
4 Random Training samples and labels

[Figure: four random MNIST training samples with their labels]


Define the FFDense custom layer

In this custom layer, we have a base keras.layers.Dense object which acts as the underlying Dense layer. Since the weight updates happen within the layer itself, we also add a keras.optimizers.Optimizer object that is supplied by the user. Here, we use Adam as our optimizer with a rather high learning rate of 0.03.

Following the specifics of the algorithm, we must set a threshold parameter that is used to make the positive-negative decision for each prediction; it is set to 1.5 here. As the epochs are localized to the layer itself, we also set a num_epochs parameter (defaults to 50).

We override the call method to perform a normalization over the entire input space and then run it through the base Dense layer, as would happen in a regular Dense layer call.

We implement the Forward-Forward algorithm, which accepts two kinds of input tensors representing the positive and negative samples respectively. We write a custom training loop here using tf.GradientTape(), inside which we compute the per-sample loss by measuring the distance of the prediction from the threshold and then take its mean to obtain the mean_loss metric.

With the help of tf.GradientTape(), we compute the gradient updates for the trainable base Dense layer and apply them using the layer's local optimizer.

Finally, we return the call results as the Dense outputs for the positive and negative samples, along with the mean_loss metric accumulated over the full epoch run.

class FFDense(keras.layers.Layer):
    """
    A custom ForwardForward-enabled Dense layer. It has an implementation of the
    Forward-Forward network internally for use.
    This layer must be used in conjunction with the `FFNetwork` model.
    """

    def __init__(
        self,
        units,
        optimizer,
        loss_metric,
        num_epochs=50,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(
            units=units,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
        )
        self.relu = keras.layers.ReLU()
        self.optimizer = optimizer
        self.loss_metric = loss_metric
        self.threshold = 1.5
        self.num_epochs = num_epochs

    # We perform a normalization step before we run the input through the Dense
    # layer.

    def call(self, x):
        x_norm = tf.norm(x, ord=2, axis=1, keepdims=True)
        x_norm = x_norm + 1e-4
        x_dir = x / x_norm
        res = self.dense(x_dir)
        return self.relu(res)

    # The Forward-Forward algorithm is below. We first perform the Dense-layer
    # operation and then get a Mean Square value for all positive and negative
    # samples respectively.
    # The custom loss function finds the distance between the Mean-squared
    # result and the threshold value we set (a hyperparameter) that will define
    # whether the prediction is positive or negative in nature. Once the loss is
    # calculated, we get a mean across the entire batch combined and perform a
    # gradient calculation and optimization step. This does not technically
    # qualify as backpropagation since there is no gradient being
    # sent to any previous layer and is completely local in nature.

    def forward_forward(self, x_pos, x_neg):
        for i in range(self.num_epochs):
            with tf.GradientTape() as tape:
                g_pos = tf.math.reduce_mean(tf.math.pow(self.call(x_pos), 2), 1)
                g_neg = tf.math.reduce_mean(tf.math.pow(self.call(x_neg), 2), 1)

                loss = tf.math.log(
                    1
                    + tf.math.exp(
                        tf.concat([-g_pos + self.threshold, g_neg - self.threshold], 0)
                    )
                )
                mean_loss = tf.cast(tf.math.reduce_mean(loss), tf.float32)
                self.loss_metric.update_state([mean_loss])
            gradients = tape.gradient(mean_loss, self.dense.trainable_weights)
            self.optimizer.apply_gradients(zip(gradients, self.dense.trainable_weights))
        return (
            tf.stop_gradient(self.call(x_pos)),
            tf.stop_gradient(self.call(x_neg)),
            self.loss_metric.result(),
        )

Define the FFNetwork custom model

With our custom layer defined, we also need to override the train_step method and define a custom keras.models.Model that works with our FFDense layer.

For this algorithm, we must "embed" the labels onto the original image. To do so, we exploit the structure of MNIST images, where the top-left 10 pixels are always zero. We use that region as a label space in order to visually one-hot-encode the label within the image itself. This operation is performed by the overlay_y_on_x function.
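
As an illustration of what this overlay does, here is a minimal NumPy sketch on a toy flattened image. The values are hypothetical; the actual overlay_y_on_x below performs the same operation with xla.dynamic_update_slice inside a tf.function.

import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(size=(784,))  # a toy flattened 28x28 "image"
label = 3                     # its ground-truth class

# Zero out the first 10 pixels, then set the pixel at index `label`
# to the image's maximum value: a pixel-based one-hot encoding of the label.
x_overlaid = x.copy()
x_overlaid[:10] = 0.0
x_overlaid[label] = x.max()

print(np.round(x_overlaid[:10], 2))  # only position 3 is non-zero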

We break the prediction down into a per-sample prediction function, which is then called over the entire test set by the overridden predict() function. Prediction is performed here by measuring the excitation of the neurons in each layer for each image. This is then summed over all layers to compute a network-wide "goodness score". The label with the highest goodness score is chosen as the prediction for that sample.

The train_step function is overridden to act as the main controlling loop that runs training on each layer for that layer's number of epochs.

class FFNetwork(keras.Model):
    """
    A `keras.Model` that supports a `FFDense` network creation. This model
    can work for any kind of classification task. It has an internal
    implementation with some details specific to the MNIST dataset which can be
    changed as per the use-case.
    """

    # Since each layer runs gradient-calculation and optimization locally, each
    # layer has its own optimizer that we pass. As a standard choice, we pass
    # the `Adam` optimizer with a default learning rate of 0.03 as that was
    # found to be the best rate after experimentation.
    # Loss is tracked using `loss_var` and `loss_count` variables.
    # Use legacy optimizer for Layer Optimizer to fix issue
    # https://github.com/keras-team/keras-io/issues/1241

    def __init__(
        self,
        dims,
        layer_optimizer=keras.optimizers.legacy.Adam(learning_rate=0.03),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.layer_optimizer = layer_optimizer
        self.loss_var = tf.Variable(0.0, trainable=False, dtype=tf.float32)
        self.loss_count = tf.Variable(0.0, trainable=False, dtype=tf.float32)
        self.layer_list = [keras.Input(shape=(dims[0],))]
        for d in range(len(dims) - 1):
            self.layer_list += [
                FFDense(
                    dims[d + 1],
                    optimizer=self.layer_optimizer,
                    loss_metric=keras.metrics.Mean(),
                )
            ]

    # This function makes a dynamic change to the image wherein the labels are
    # put on top of the original image (for this example, as MNIST has 10
    # unique labels, we take the top-left corner's first 10 pixels). This
    # function returns the original data tensor with the first 10 pixels being
    # a pixel-based one-hot representation of the labels.

    @tf.function(reduce_retracing=True)
    def overlay_y_on_x(self, data):
        X_sample, y_sample = data
        max_sample = tf.reduce_max(X_sample, axis=0, keepdims=True)
        max_sample = tf.cast(max_sample, dtype=tf.float64)
        X_zeros = tf.zeros([10], dtype=tf.float64)
        X_update = xla.dynamic_update_slice(X_zeros, max_sample, [y_sample])
        X_sample = xla.dynamic_update_slice(X_sample, X_update, [0])
        return X_sample, y_sample

    # A custom `predict_one_sample` performs predictions by passing the images
    # through the network, measures the results produced by each layer (i.e.
    # how high/low the output values are with respect to the set threshold for
    # each label) and then simply finding the label with the highest values.
    # In such a case, the images are tested for their 'goodness' with all
    # labels.

    @tf.function(reduce_retracing=True)
    def predict_one_sample(self, x):
        goodness_per_label = []
        x = tf.reshape(x, [tf.shape(x)[0] * tf.shape(x)[1]])
        for label in range(10):
            h, label = self.overlay_y_on_x(data=(x, label))
            h = tf.reshape(h, [-1, tf.shape(h)[0]])
            goodness = []
            for layer_idx in range(1, len(self.layer_list)):
                layer = self.layer_list[layer_idx]
                h = layer(h)
                goodness += [tf.math.reduce_mean(tf.math.pow(h, 2), 1)]
            goodness_per_label += [
                tf.expand_dims(tf.reduce_sum(goodness, keepdims=True), 1)
            ]
        goodness_per_label = tf.concat(goodness_per_label, 1)
        return tf.cast(tf.argmax(goodness_per_label, 1), tf.float64)

    def predict(self, data):
        x = data
        preds = list()
        preds = tf.map_fn(fn=self.predict_one_sample, elems=x)
        return np.asarray(preds, dtype=int)

    # This custom `train_step` function overrides the internal `train_step`
    # implementation. We take all the input image tensors, flatten them and
    # subsequently produce positive and negative samples on the images.
    # A positive sample is an image that has the right label encoded on it with
    # the `overlay_y_on_x` function. A negative sample is an image that has an
    # erroneous label present on it.
    # With the samples ready, we pass them through each `FFLayer` and perform
    # the Forward-Forward computation on it. The returned loss is the final
    # loss value over all the layers.

    @tf.function(jit_compile=True)
    def train_step(self, data):
        x, y = data

        # Flatten op
        x = tf.reshape(x, [-1, tf.shape(x)[1] * tf.shape(x)[2]])

        x_pos, y = tf.map_fn(fn=self.overlay_y_on_x, elems=(x, y))

        random_y = tf.random.shuffle(y)
        x_neg, y = tf.map_fn(fn=self.overlay_y_on_x, elems=(x, random_y))

        h_pos, h_neg = x_pos, x_neg

        for idx, layer in enumerate(self.layers):
            if isinstance(layer, FFDense):
                print(f"Training layer {idx+1} now : ")
                h_pos, h_neg, loss = layer.forward_forward(h_pos, h_neg)
                self.loss_var.assign_add(loss)
                self.loss_count.assign_add(1.0)
            else:
                print(f"Passing layer {idx+1} now : ")
                x = layer(x)
        mean_res = tf.math.divide(self.loss_var, self.loss_count)
        return {"FinalLoss": mean_res}

Convert the MNIST NumPy arrays to tf.data.Dataset

We now perform some preliminary processing on the NumPy arrays and then convert them into the tf.data.Dataset format, which allows for optimized loading.

x_train = x_train.astype(float) / 255
x_test = x_test.astype(float) / 255
y_train = y_train.astype(int)
y_test = y_test.astype(int)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

train_dataset = train_dataset.batch(60000)
test_dataset = test_dataset.batch(10000)

Fit the network and visualize the results

Having completed all the previous setup, we now run model.fit() for 250 model epochs, which amounts to 50*250 epochs on each layer. We get to see the plotted loss curve as each layer is trained.

model = FFNetwork(dims=[784, 500, 500])

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.03),
    loss="mse",
    jit_compile=True,
    metrics=[keras.metrics.Mean()],
)

epochs = 250
history = model.fit(train_dataset, epochs=epochs)
Epoch 1/250
Training layer 1 now : 
Training layer 2 now : 
Training layer 1 now : 
Training layer 2 now : 
1/1 [==============================] - 72s 72s/step - FinalLoss: 0.7279
Epoch 2/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.7082
Epoch 3/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.7031
Epoch 4/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.6806
Epoch 5/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.6564
Epoch 6/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.6333
Epoch 7/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.6126
Epoch 8/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.5946
Epoch 9/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.5786
Epoch 10/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.5644
Epoch 11/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.5518
Epoch 12/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.5405
Epoch 13/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.5301
Epoch 14/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.5207
Epoch 15/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.5122
Epoch 16/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.5044
Epoch 17/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4972
Epoch 18/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4906
Epoch 19/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4845
Epoch 20/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4787
Epoch 21/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4734
Epoch 22/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4685
Epoch 23/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4639
Epoch 24/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4596
Epoch 25/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4555
Epoch 26/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4516
Epoch 27/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4479
Epoch 28/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4445
Epoch 29/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4411
Epoch 30/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4380
Epoch 31/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4350
Epoch 32/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4322
Epoch 33/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4295
Epoch 34/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4269
Epoch 35/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4245
Epoch 36/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4222
Epoch 37/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4199
Epoch 38/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4178
Epoch 39/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4157
Epoch 40/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4136
Epoch 41/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4117
Epoch 42/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4098
Epoch 43/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4079
Epoch 44/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4062
Epoch 45/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4045
Epoch 46/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4028
Epoch 47/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4012
Epoch 48/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3996
Epoch 49/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3982
Epoch 50/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3967
Epoch 51/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3952
Epoch 52/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3938
Epoch 53/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3925
Epoch 54/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3912
Epoch 55/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3899
Epoch 56/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3886
Epoch 57/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3874
Epoch 58/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3862
Epoch 59/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3851
Epoch 60/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3840
Epoch 61/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3829
Epoch 62/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3818
Epoch 63/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3807
Epoch 64/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3797
Epoch 65/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3787
Epoch 66/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3777
Epoch 67/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3767
Epoch 68/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3758
Epoch 69/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3748
Epoch 70/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3739
Epoch 71/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3730
Epoch 72/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3721
Epoch 73/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3712
Epoch 74/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3704
Epoch 75/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3695
Epoch 76/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3688
Epoch 77/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3680
Epoch 78/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3671
Epoch 79/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3664
Epoch 80/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3656
Epoch 81/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3648
Epoch 82/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3641
Epoch 83/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3634
Epoch 84/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3627
Epoch 85/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3620
Epoch 86/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3613
Epoch 87/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3606
Epoch 88/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3599
Epoch 89/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3593
Epoch 90/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3586
Epoch 91/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3580
Epoch 92/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3574
Epoch 93/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3568
Epoch 94/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3561
Epoch 95/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3555
Epoch 96/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3549
Epoch 97/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3544
Epoch 98/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3538
Epoch 99/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3532
Epoch 100/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3526
Epoch 101/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3521
Epoch 102/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3515
Epoch 103/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3510
Epoch 104/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3505
Epoch 105/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3499
Epoch 106/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3494
Epoch 107/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3489
Epoch 108/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3484
Epoch 109/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3478
Epoch 110/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3474
Epoch 111/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3468
Epoch 112/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3464
Epoch 113/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3459
Epoch 114/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3454
Epoch 115/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3450
Epoch 116/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3445
Epoch 117/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3440
Epoch 118/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3436
Epoch 119/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3432
Epoch 120/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3427
Epoch 121/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3423
Epoch 122/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3419
Epoch 123/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3414
Epoch 124/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3410
Epoch 125/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3406
Epoch 126/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3402
Epoch 127/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3398
Epoch 128/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3394
Epoch 129/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3390
Epoch 130/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3386
Epoch 131/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3382
Epoch 132/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3378
Epoch 133/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3375
Epoch 134/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3371
Epoch 135/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3368
Epoch 136/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3364
Epoch 137/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3360
Epoch 138/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3357
Epoch 139/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3353
Epoch 140/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3350
Epoch 141/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3346
Epoch 142/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3343
Epoch 143/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3339
Epoch 144/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3336
Epoch 145/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3333
Epoch 146/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3329
Epoch 147/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3326
Epoch 148/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3323
Epoch 149/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3320
Epoch 150/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3317
Epoch 151/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3313
Epoch 152/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3310
Epoch 153/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3307
Epoch 154/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3304
Epoch 155/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3302
Epoch 156/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3299
Epoch 157/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3296
Epoch 158/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3293
Epoch 159/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3290
Epoch 160/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3287
Epoch 161/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3284
Epoch 162/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3281
Epoch 163/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3279
Epoch 164/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3276
Epoch 165/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3273
Epoch 166/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3270
Epoch 167/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3268
Epoch 168/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3265
Epoch 169/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3262
Epoch 170/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3260
Epoch 171/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3257
Epoch 172/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3255
Epoch 173/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3252
Epoch 174/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3250
Epoch 175/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3247
Epoch 176/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3244
Epoch 177/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3242
Epoch 178/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3240
Epoch 179/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3237
Epoch 180/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3235
Epoch 181/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3232
Epoch 182/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3230
Epoch 183/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3228
Epoch 184/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3225
Epoch 185/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3223
Epoch 186/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3221
Epoch 187/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3219
Epoch 188/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3216
Epoch 189/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3214
Epoch 190/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3212
Epoch 191/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3210
Epoch 192/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3208
Epoch 193/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3205
Epoch 194/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3203
Epoch 195/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3201
Epoch 196/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3199
Epoch 197/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3197
Epoch 198/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3195
Epoch 199/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3193
Epoch 200/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3191
Epoch 201/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3189
Epoch 202/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3187
Epoch 203/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3185
Epoch 204/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3183
Epoch 205/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3181
Epoch 206/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3179
Epoch 207/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3177
Epoch 208/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3175
Epoch 209/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3174
Epoch 210/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3172
Epoch 211/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3170
Epoch 212/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3168
Epoch 213/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3166
Epoch 214/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3165
Epoch 215/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3163
Epoch 216/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3161
Epoch 217/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3159
Epoch 218/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3157
Epoch 219/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3155
Epoch 220/250
1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3154
Epoch 221/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3152
Epoch 222/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3150
Epoch 223/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3148
Epoch 224/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3147
Epoch 225/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3145
Epoch 226/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3143
Epoch 227/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3142
Epoch 228/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3140
Epoch 229/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3139
Epoch 230/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3137
Epoch 231/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3135
Epoch 232/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3134
Epoch 233/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3132
Epoch 234/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3131
Epoch 235/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3129
Epoch 236/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3127
Epoch 237/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3126
Epoch 238/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3124
Epoch 239/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3123
Epoch 240/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3121
Epoch 241/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3120
Epoch 242/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3118
Epoch 243/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3117
Epoch 244/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3116
Epoch 245/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3114
Epoch 246/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3113
Epoch 247/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3111
Epoch 248/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3110
Epoch 249/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3108
Epoch 250/250
1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3107

Perform inference and testing

Having trained the model to a large extent, we now see how it performs on the test set. We calculate the accuracy score to take a close look at the results.

preds = model.predict(tf.convert_to_tensor(x_test))

preds = preds.reshape((preds.shape[0], preds.shape[1]))

results = accuracy_score(preds, y_test)

print(f"Test Accuracy score : {results*100}%")

plt.plot(range(len(history.history["FinalLoss"])), history.history["FinalLoss"])
plt.title("Loss over training")
plt.show()
Test Accuracy score : 97.64%

[Figure: loss curve over training]


Conclusion

This example has demonstrated how the Forward-Forward algorithm works using the TensorFlow and Keras packages. While the investigations presented by Prof. Hinton in his paper are currently limited to smaller models and datasets like MNIST and Fashion-MNIST, follow-up results on larger models such as LLMs are expected in future papers.

In the paper, Prof. Hinton reported a test accuracy error of 1.36% with a 2000-unit, 4-hidden-layer, fully-connected network run over 60 epochs (while noting that backpropagation needs only 20 epochs to reach similar performance). Another run that doubled the learning rate and trained for 40 epochs yielded a slightly worse error rate of 1.46%.

The current example does not yield state-of-the-art results. But with proper tuning of the learning rate and the model architecture (number of units in the Dense layers, kernel activations, initializations, regularization, etc.), the results can be improved to match the claims of the paper.
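
As a starting point for such tuning, the sketch below builds a hypothetical configuration closer to the paper's 4 hidden layers of 2000 units, reusing the FFNetwork class and train_dataset defined above. The learning rate and epoch count shown are illustrative guesses, not values verified to reproduce the paper's numbers.

# Hypothetical, untested configuration: 4 hidden layers of 2000 units,
# with a lower learning rate than the 0.03 used above. All values are illustrative.
paper_like_model = FFNetwork(
    dims=[784, 2000, 2000, 2000, 2000],
    layer_optimizer=keras.optimizers.legacy.Adam(learning_rate=0.01),
)
paper_like_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.01),
    loss="mse",
    jit_compile=True,
    metrics=[keras.metrics.Mean()],
)
paper_like_history = paper_like_model.fit(train_dataset, epochs=60)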