
Electroencephalogram Signal Classification for Brain-Computer Interface

Author: Okba Bekhelifi
Date created: 2025/01/08
Last modified: 2025/01/08
Description: A Transformer-based EEG signal classification model for brain-computer interfaces.

ⓘ This example uses Keras 3


Introduction

This tutorial explains how to build a Transformer-based neural network to classify brain-computer interface (BCI) electroencephalography (EEG) data recorded in a steady-state visually evoked potentials (SSVEPs) experiment, applied to a brain-controlled speller.

The tutorial reproduces an experiment from the SSVEPFormer study [1] (arXiv preprint / peer-reviewed paper). This model was the first Transformer-based model introduced for SSVEP data classification; we will test it on the publicly available dataset from Nakanishi et al. [2] (dataset 1 in the paper).

The process follows an inter-subject classification experiment. Given data from N subjects in the dataset, the training partition contains data from N-1 subjects, and the remaining single subject's data is used for testing. The training set does not contain any sample from the test subject, so we build a true subject-independent model. We keep the same parameters and settings as the original paper in all processing operations, from preprocessing to training.

The tutorial begins with a quick introduction to BCIs and the dataset, and then goes through the technicalities in the following sections:
- Setup and imports.
- Dataset download and extraction.
- Data preprocessing: EEG data filtering and segmentation, visualization of raw and filtered data, and the frequency response of a well-performing participant.
- Layers and model creation.
- Evaluation: classification of a single participant's data as an example, followed by classification of all participants' data.
- Visualization: a comparison of training and inference times among the available Keras 3 backends (JAX, TensorFlow, and PyTorch) on three different GPUs.
- Conclusion: final discussion and remarks.

Dataset description


BCI and SSVEP

BCIs offer the ability to communicate using only brain activity. This can be achieved through external stimuli that generate specific responses indicating the subject's intent. These responses are elicited when the user focuses their attention on the target stimulus. We can use visual stimuli, usually presented to the subject as a grid of options on a monitor, to select one command at a time. Each stimulus flickers at a fixed frequency and phase; the EEG recorded over the occipital and occipito-parietal areas of the cortex (the visual cortex) will then show higher power at the frequency associated with the stimulus the subject is gazing at. This type of BCI paradigm is called steady-state visually evoked potentials (SSVEPs), and it has been widely adopted for multiple applications due to its reliability, high classification performance, and rapidity (a 1-second EEG segment is enough to issue a command). Other types of brain responses exist that do not require external stimulation, but they are less reliable. Demo video
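
The frequency-tagging principle behind SSVEP is easy to verify on synthetic data. The following minimal sketch (an illustration added for this point; the 12.5 Hz flicker frequency and noise level are made up, not taken from the dataset) shows that the FFT magnitude of a noisy sinusoid peaks at the flicker frequency:

import numpy as np

fs = 256  # sampling frequency (Hz), same as the dataset after downsampling
t = np.arange(2 * fs) / fs  # 2 seconds -> 0.5 Hz frequency resolution
stim_freq = 12.5  # hypothetical flicker frequency (Hz)
# a noisy sinusoid standing in for the SSVEP response over the visual cortex
signal = np.sin(2 * np.pi * stim_freq * t) + 0.5 * np.random.randn(t.size)
spectrum = np.abs(np.fft.rfft(signal))
freqs = np.fft.rfftfreq(t.size, d=1 / fs)
print(f"Peak at {freqs[spectrum.argmax()]} Hz")  # expected: 12.5 Hz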

This tutorial uses the publicly available SSVEP dataset [2] with 12 commands (classes), whose interface emulates a phone dialing pad. Dataset

The dataset was recorded from 10 participants, each facing the 12 SSVEP stimuli described above (A). The stimulation frequencies range from 9.25 Hz to 14.75 Hz in 0.5 Hz steps, and the phases range from 0 to 1.5 π in 0.5 π steps for each row (B). The EEG signals were acquired with 8 electrodes (channels) (PO7, PO3, POz, PO4, PO8, O1, Oz, O2) at a sampling frequency of 2048 Hz, and the stored data were then downsampled to 256 Hz. The subjects completed 15 blocks of recordings, each consisting of the 12 stimuli in random order (one per class), with each stimulus presented for 4 seconds. In total, each subject performed 180 trials.
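
As a quick orientation, the stimulation parameters quoted above can be enumerated directly (a small sketch added for illustration; it reproduces only the frequency and phase values, not the on-screen layout of the interface):

import numpy as np

# 12 stimulation frequencies: 9.25 Hz to 14.75 Hz in 0.5 Hz steps
frequencies = np.arange(9.25, 15.0, 0.5)
# phases: 0 to 1.5 pi in 0.5 pi steps, repeated across rows
phases = np.arange(0.0, 2.0, 0.5) * np.pi
print(frequencies)     # one frequency per command (class)
print(phases / np.pi)  # [0.  0.5 1.  1.5], in units of pi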

Setup


Select JAX backend

import os

os.environ["KERAS_BACKEND"] = "jax"

Install dependencies

!pip install -q numpy
!pip install -q scipy
!pip install -q matplotlib

Imports

# deep learning libraries
from keras import backend as K
from keras import layers
import keras

# visualization and signal processing imports
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from scipy.signal import butter, filtfilt
from scipy.io import loadmat

# setting the backend, seed and Keras channel format
K.set_image_data_format("channels_first")
keras.utils.set_random_seed(42)

Download and extract the dataset


Nakanishi et al. 2015 dataset repository

!curl -O https://sccn.ucsd.edu/download/cca_ssvep.zip
!unzip cca_ssvep.zip

Archive:  cca_ssvep.zip
   creating: cca_ssvep/
  inflating: cca_ssvep/s4.mat        
  inflating: cca_ssvep/s5.mat        
  inflating: cca_ssvep/s3.mat        
  inflating: cca_ssvep/s7.mat        
  inflating: cca_ssvep/chan_locs.pdf  
  inflating: cca_ssvep/readme.txt    
  inflating: cca_ssvep/s2.mat        
  inflating: cca_ssvep/s8.mat        
  inflating: cca_ssvep/s10.mat       
  inflating: cca_ssvep/s9.mat        
  inflating: cca_ssvep/s6.mat        
  inflating: cca_ssvep/s1.mat        

Preprocessing

The preprocessing steps are as follows: first, we read the EEG data for each subject; then we filter the raw data in the frequency interval containing most of the useful information; then we select a fixed-duration segment of the signal starting from the stimulus onset (we add 135 milliseconds to the onset to account for the visual system delay). Finally, all subjects' data are concatenated into a single tensor of shape [subjects x samples x channels x trials]. The data labels are also concatenated, following the order of the trials in the experiments, into a matrix of shape [subjects x trials] (here, channels refers to electrodes; we use this notation throughout the tutorial).

def raw_signal(folder, fs=256, duration=1.0, onset=0.135):
    """selecting a 1-second segment of the raw EEG signal for
    subject 1.
    """
    onset = 38 + int(onset * fs)
    end = int(duration * fs)
    data = loadmat(f"{folder}/s1.mat")
    # samples, channels, trials, targets
    eeg = data["eeg"].transpose((2, 1, 3, 0))
    # segment data
    eeg = eeg[onset : onset + end, :, :, :]
    return eeg


def segment_eeg(
    folder, elecs=None, fs=256, duration=1.0, band=[5.0, 45.0], order=4, onset=0.135
):
    """Filtering and segmenting EEG signals for all subjects."""
    n_subjects = 10
    onset = 38 + int(onset * fs)
    end = int(duration * fs)
    X, Y = [], []  # empty data and labels

    for subj in range(1, n_subjects + 1):
        data = loadmat(f"{folder}/s{subj}.mat")
        # samples, channels, trials, targets
        eeg = data["eeg"].transpose((2, 1, 3, 0))
        # filter data
        eeg = filter_eeg(eeg, fs=fs, band=band, order=order)
        # segment data
        eeg = eeg[onset : onset + end, :, :, :]
        # reshape labels
        samples, channels, blocks, targets = eeg.shape
        y = np.tile(np.arange(1, targets + 1), (blocks, 1))
        y = y.reshape((1, blocks * targets), order="F")

        X.append(eeg.reshape((samples, channels, blocks * targets), order="F"))
        Y.append(y)

    X = np.array(X, dtype=np.float32, order="F")
    Y = np.array(Y, dtype=np.float32).squeeze()

    return X, Y


def filter_eeg(data, fs=256, band=[5.0, 45.0], order=4):
    """Filter EEG signal using a zero-phase IIR filter"""
    B, A = butter(order, np.array(band) / (fs / 2), btype="bandpass")
    return filtfilt(B, A, data, axis=0)
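
To see what this filter does to the signal, we can plot its frequency response (a sketch added for illustration, using the same order and the [8, 64] Hz band applied below; note that filtfilt runs the filter forward and backward, so the effective attenuation is doubled):

from scipy.signal import freqz

# frequency response of the 4th-order band-pass Butterworth filter
B, A = butter(4, np.array([8.0, 64.0]) / (256 / 2), btype="bandpass")
w, h = freqz(B, A, worN=1024, fs=256)
plt.plot(w, 20 * np.log10(np.maximum(np.abs(h), 1e-8)))
plt.xlabel("Frequency (Hz)")
plt.ylabel("Gain (dB)")
plt.title("Band-pass Butterworth (8-64 Hz), single pass")
plt.show()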

Segment data into epochs

data_folder = os.path.abspath("./cca_ssvep")
band = [8, 64]  # low-frequency / high-frequency cutoffs
order = 4  # filter order
fs = 256  # sampling frequency
duration = 1.0  # 1 second

# raw signal
X_raw = raw_signal(data_folder, fs=fs, duration=duration)
print(
    f"A single subject raw EEG (X_raw) shape: {X_raw.shape} [Samples x Channels x Blocks x Targets]"
)

# segmented signal
X, Y = segment_eeg(data_folder, band=band, order=order, fs=fs, duration=duration)
print(
    f"Full training data (X) shape: {X.shape} [Subject x Samples x Channels x Trials]"
)
print(f"data labels (Y) shape:        {Y.shape} [Subject x Trials]")

samples = X.shape[1]
time = np.linspace(0.0, samples / fs, samples) * 1000
A single subject raw EEG (X_raw) shape: (256, 8, 15, 12) [Samples x Channels x Blocks x Targets]

Full training data (X) shape: (10, 256, 8, 180) [Subject x Samples x Channels x Trials]
data labels (Y) shape:        (10, 180) [Subject x Trials]

Visualize EEG signals


EEG in the time domain

Raw EEG vs filtered EEG. The plots show the same 1-second recording from subject s1 at Oz (the central electrode over the visual cortex, at the back of the head). On the left is the raw EEG as recorded; on the right is the EEG after band-pass filtering in the [8, 64] Hz band. We see less noise, and the amplitude values are normalized to a natural EEG range.

elec = 6  # Oz channel

x_label = "Time (ms)"
y_label = "Voltage (uV)"
# Create subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Plot data on the first subplot
ax1.plot(time, X_raw[:, elec, 0, 0], "r-")
ax1.set_xlabel(x_label)
ax1.set_ylabel(y_label)
ax1.set_title("Raw EEG : 1 second at Oz ")

# Plot data on the second subplot
ax2.plot(time, X[0, :, elec, 0], "b-")
ax2.set_xlabel(x_label)
ax2.set_ylabel(y_label)
ax2.set_title("Filtered EEG between 8-64 Hz: 1 second at Oz")

# Adjust spacing between subplots
plt.tight_layout()

# Show the plot
plt.show()

(figure: raw vs. filtered EEG, 1 second at Oz)


EEG frequency representation

Using the Welch method, we visualize the frequency power at the Oz electrode for the full 4-second EEG recording of each stimulus, for a well-performing subject. The red peaks indicate the stimulus fundamental frequency and its second harmonic (twice the fundamental). We see clear peaks, showing a strong response from this subject, which makes them a good candidate for SSVEP BCI control. In many cases the peaks are weak or absent, meaning the subject did not perform the task correctly.

(figure: frequency power at Oz for each stimulus)
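
A figure like this can be produced with scipy.signal.welch. The sketch below is an added illustration, not the code behind the original figure: it reuses the 1-second filtered segments in X (the original figure used the full 4-second recordings), and the subject and trial indices are arbitrary examples:

from scipy.signal import welch

elec = 6  # Oz
subject, trial = 3, 0  # arbitrary example indices
f, pxx = welch(X[subject, :, elec, trial], fs=fs, nperseg=samples)
plt.semilogy(f, pxx)
plt.xlabel("Frequency (Hz)")
plt.ylabel("PSD (uV^2 / Hz)")
plt.title("Welch PSD at Oz, one trial")
plt.show()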

Create layers and model

We create the layers in a cross-framework custom component fashion. In the SSVEPFormer, the data is first transformed to the frequency domain through a fast Fourier transform (FFT) to create a complex spectrum representation, consisting of the concatenation of frequency and phase information in a fixed frequency band. To keep the model in an end-to-end format, we implement the complex spectrum transformation as a non-trainable layer.

The SSVEPFormer model differs from the Transformer architecture in that it does not contain a positional encoding/embedding layer; this is replaced by a channel combination block, which consists of a Conv1D layer with a kernel size of 1 and twice the number of input channels (electrodes) as filters, followed by LayerNorm, GELU activation, and dropout. Another difference from Transformers is the absence of multi-head attention layers with an attention mechanism. The model encoder contains two identical, successive blocks. Each block has two sub-blocks: a CNN module and an MLP module. The CNN module consists of a LayerNorm, a Conv1D layer with the same number of filters as the channel combination block, a GELU activation, dropout, and a residual connection. The MLP module consists of a LayerNorm, a Dense layer, GELU, dropout, and a residual connection; the Dense layer is applied to each channel separately. The last block of the model is the MLP head, with a Flatten layer, dropout, a Dense layer, LayerNorm, GELU, dropout, and a Dense layer with softmax activation. All trainable weights are initialized from a normal distribution with mean 0 and standard deviation 0.01, as stated in the original paper.

class ComplexSpectrum(keras.layers.Layer):
    def __init__(self, nfft=512, fft_start=8, fft_end=64):
        super().__init__()
        self.nfft = nfft
        self.fft_start = fft_start
        self.fft_end = fft_end

    def call(self, x):
        samples = x.shape[-1]
        x = keras.ops.rfft(x, fft_length=self.nfft)
        real = x[0] / samples
        imag = x[1] / samples
        real = real[:, :, self.fft_start : self.fft_end]
        imag = imag[:, :, self.fft_start : self.fft_end]
        x = keras.ops.concatenate((real, imag), axis=-1)
        return x


class ChannelComb(keras.layers.Layer):
    def __init__(self, n_channels, drop_rate=0.5):
        super().__init__()
        self.conv = layers.Conv1D(
            2 * n_channels,
            1,
            padding="same",
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )
        self.normalization = layers.LayerNormalization()
        self.activation = layers.Activation(activation="gelu")
        self.drop = layers.Dropout(drop_rate)

    def call(self, x):
        x = self.conv(x)
        x = self.normalization(x)
        x = self.activation(x)
        x = self.drop(x)
        return x


class ConvAttention(keras.layers.Layer):
    def __init__(self, n_channels, drop_rate=0.5):
        super().__init__()
        self.norm = layers.LayerNormalization()
        self.conv = layers.Conv1D(
            2 * n_channels,
            31,
            padding="same",
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )
        self.activation = layers.Activation(activation="gelu")
        self.drop = layers.Dropout(drop_rate)

    def call(self, x):
        input = x
        x = self.norm(x)
        x = self.conv(x)
        x = self.activation(x)
        x = self.drop(x)
        x = x + input
        return x


class ChannelMLP(keras.layers.Layer):
    def __init__(self, n_features, drop_rate=0.5):
        super().__init__()
        self.norm = layers.LayerNormalization()
        self.mlp = layers.Dense(
            2 * n_features,
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )
        self.activation = layers.Activation(activation="gelu")
        self.drop = layers.Dropout(drop_rate)
        self.cat = layers.Concatenate(axis=1)

    def call(self, x):
        input = x
        channels = x.shape[1]  # x shape : NCF
        x = self.norm(x)
        output_channels = []
        for i in range(channels):
            c = self.mlp(x[:, :, i])
            c = layers.Reshape([1, -1])(c)
            output_channels.append(c)
        x = self.cat(output_channels)
        x = self.activation(x)
        x = self.drop(x)
        x = x + input
        return x


class Encoder(keras.layers.Layer):
    def __init__(self, n_channels, n_features, drop_rate=0.5):
        super().__init__()
        self.attention1 = ConvAttention(n_channels, drop_rate=drop_rate)
        self.mlp1 = ChannelMLP(n_features, drop_rate=drop_rate)
        self.attention2 = ConvAttention(n_channels, drop_rate=drop_rate)
        self.mlp2 = ChannelMLP(n_features, drop_rate=drop_rate)

    def call(self, x):
        x = self.attention1(x)
        x = self.mlp1(x)
        x = self.attention2(x)
        x = self.mlp2(x)
        return x


class MlpHead(keras.layers.Layer):
    def __init__(self, n_classes, drop_rate=0.5):
        super().__init__()
        self.flatten = layers.Flatten()
        self.drop = layers.Dropout(drop_rate)
        self.linear1 = layers.Dense(
            6 * n_classes,
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )
        self.norm = layers.LayerNormalization()
        self.activation = layers.Activation(activation="gelu")
        self.drop2 = layers.Dropout(drop_rate)
        self.linear2 = layers.Dense(
            n_classes,
            kernel_initializer=keras.initializers.RandomNormal(
                mean=0.0, stddev=0.01, seed=None
            ),
        )

    def call(self, x):
        x = self.flatten(x)
        x = self.drop(x)
        x = self.linear1(x)
        x = self.norm(x)
        x = self.activation(x)
        x = self.drop2(x)
        x = self.linear2(x)
        return x
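
As a quick sanity check of the non-trainable spectral transformation (a sketch added here; the parameters match those computed below, nfft = fs / resolution = 1024, fft_start = 32, fft_end = 257), we can pass a dummy batch through the layer:

# 8-64 Hz at 0.25 Hz resolution -> 225 bins; real + imaginary parts -> 450
layer = ComplexSpectrum(nfft=1024, fft_start=32, fft_end=257)
dummy = keras.ops.zeros((2, 8, 256))  # batch x channels x samples
print(layer(dummy).shape)  # expected: (2, 8, 450)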

Create a sequential model with the layers above

def create_ssvepformer(
    input_shape, fs, resolution, fq_band, n_channels, n_classes, drop_rate
):
    nfft = round(fs / resolution)
    fft_start = int(fq_band[0] / resolution)
    fft_end = int(fq_band[1] / resolution) + 1
    n_features = fft_end - fft_start

    model = keras.Sequential(
        [
            keras.Input(shape=input_shape),
            ComplexSpectrum(nfft, fft_start, fft_end),
            ChannelComb(n_channels=n_channels, drop_rate=drop_rate),
            Encoder(n_channels=n_channels, n_features=n_features, drop_rate=drop_rate),
            Encoder(n_channels=n_channels, n_features=n_features, drop_rate=drop_rate),
            MlpHead(n_classes=n_classes, drop_rate=drop_rate),
            layers.Activation(activation="softmax"),
        ]
    )

    return model

Evaluation

# Training settings same as the original paper
BATCH_SIZE = 128
EPOCHS = 100
LR = 0.001  # learning rate
WD = 0.001  # weight decay
MOMENTUM = 0.9
DROP_RATE = 0.5

resolution = 0.25

From the whole dataset, we select folds for each subject's evaluation. We then construct tf.data.Dataset objects for the training and testing data, create the model, and launch training with the SGD optimizer.

def concatenate_subjects(x, y, fold):
    X = np.concatenate([x[idx] for idx in fold], axis=-1)
    Y = np.concatenate([y[idx] for idx in fold], axis=-1)
    X = X.transpose((2, 1, 0))  # trials x channels x samples
    return X, Y - 1  # transform labels to values from 0...11


def evaluate_subject(
    x_train,
    y_train,
    x_val,
    y_val,
    input_shape,
    fs=256,
    resolution=0.25,
    band=[8, 64],
    channels=8,
    n_classes=12,
    drop_rate=DROP_RATE,
):

    train_dataset = (
        tf.data.Dataset.from_tensor_slices((x_train, y_train))
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    test_dataset = (
        tf.data.Dataset.from_tensor_slices((x_val, y_val))
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    model = create_ssvepformer(
        input_shape, fs, resolution, band, channels, n_classes, drop_rate
    )
    sgd = keras.optimizers.SGD(learning_rate=LR, momentum=MOMENTUM, weight_decay=WD)

    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer=sgd,
        metrics=["accuracy"],
        jit_compile=True,
    )

    history = model.fit(
        train_dataset,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=test_dataset,
        verbose=0,
    )
    loss, acc = model.evaluate(test_dataset)
    return acc * 100

Run the evaluation

channels = X.shape[2]
samples = X.shape[1]
input_shape = (channels, samples)
n_classes = 12

model = create_ssvepformer(
    input_shape, fs, resolution, band, channels, n_classes, DROP_RATE
)
model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                         ┃ Output Shape                ┃         Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ complex_spectrum (ComplexSpectrum)   │ (None, 8, 450)              │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ channel_comb (ChannelComb)           │ (None, 16, 450)             │           1,044 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ encoder (Encoder)                    │ (None, 16, 450)             │          34,804 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ encoder_1 (Encoder)                  │ (None, 16, 450)             │          34,804 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ mlp_head (MlpHead)                   │ (None, 12)                  │         519,492 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ activation_10 (Activation)           │ (None, 12)                  │               0 │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
 Total params: 590,144 (2.25 MB)
 Trainable params: 590,144 (2.25 MB)
 Non-trainable params: 0 (0.00 B)

Evaluate all subjects following the leave-one-subject-out data-split scheme

accs = np.zeros(10)

for subject in range(10):
    print(f"Testing subject: {subject+ 1}")

    # create train / test folds
    folds = np.delete(np.arange(10), subject)
    train_index = folds
    test_index = [subject]

    # create data split for each subject
    x_train, y_train = concatenate_subjects(X, Y, train_index)
    x_val, y_val = concatenate_subjects(X, Y, test_index)

    # train and evaluate a fold and compute the time it takes
    acc = evaluate_subject(x_train, y_train, x_val, y_val, input_shape)

    accs[subject] = acc

print(f"\nAccuracy Across Subjects: {accs.mean()} % std: {np.std(accs)}")
Testing subject: 1

2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 39ms/step - accuracy: 0.5868 - loss: 1.3010

Testing subject: 2

2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.5119 - loss: 1.5974

Testing subject: 3

2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.6903 - loss: 1.0117

Testing subject: 4

2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.9637 - loss: 0.1487

Testing subject: 5

2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9458 - loss: 0.1887

Testing subject: 6

2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.9674 - loss: 0.1018

Testing subject: 7

2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9158 - loss: 0.2592

Testing subject: 8

2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - accuracy: 0.9937 - loss: 0.0518

Testing subject: 9

2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.9837 - loss: 0.0701

Testing subject: 10

2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - accuracy: 0.8999 - loss: 0.4543

Accuracy Across Subjects: 84.11111384630203 % std: 17.575586372993953

And that's it! We see that some subjects, despite having no data in the training set, can still issue close to 100% correct commands, while others perform poorly, at around 50% accuracy. In the original paper, using PyTorch, the average accuracy was 84.04% with a standard deviation of 17.37. Given the stochasticity of deep learning, we reached the same values.
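
To see which subjects drive that large standard deviation, we can plot the per-subject accuracies collected in accs (a small added visualization, not part of the original example):

plt.bar(np.arange(1, 11), accs)
plt.axhline(accs.mean(), color="r", linestyle="--", label=f"mean: {accs.mean():.2f}%")
plt.xlabel("Subject")
plt.ylabel("Accuracy (%)")
plt.title("Leave-one-subject-out accuracy per subject")
plt.legend()
plt.show()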

Visualization

Comparison of training and inference times for the different backends (JAX, TensorFlow, and PyTorch) on the three GPUs available with Colab Free/Pro/Pro+ (T4, L4, A100).


Training time

(figure: training time comparison across backends and GPUs)

Inference time

(figure: inference time comparison across backends and GPUs)

The JAX backend delivered the best training and inference performance on all GPUs, while PyTorch was extremely slow because its JIT compilation option had to be disabled: the complex data type produced by the FFT computation is not supported by the PyTorch JIT compiler.
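
If you still want to train on the PyTorch backend, one workaround (a sketch, under the assumption that the complex-typed FFT output is the only blocker) is to disable JIT compilation only for that backend, e.g. in evaluate_subject:

# keep XLA/JIT on JAX and TensorFlow; disable it on torch, where the complex
# dtype produced by keras.ops.rfft is not supported by the JIT compiler
jit = keras.backend.backend() != "torch"
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=sgd,
    metrics=["accuracy"],
    jit_compile=jit,
)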

Acknowledgment

Thanks to Chris Perry X @GoogleColab for supporting this work with GPU compute.

References

[1] Chen, J. et al. (2023) 'A transformer-based deep neural network model for SSVEP classification', Neural Networks, 164, pp. 521-534. Available at: https://doi.org/10.1016/j.neunet.2023.04.045.

[2] Nakanishi, M. et al. (2015) 'A Comparison Study of Canonical Correlation Analysis Based Methods for Detecting Steady-State Visual Evoked Potentials', PLOS ONE, 10(10), p. e0140703. Available at: https://doi.org/10.1371/journal.pone.0140703.