Author: Okba Bekhelifi
Date created: 2025/01/08
Last modified: 2025/01/08
Description: A Transformer-based classification of EEG signals for BCI.
This tutorial explains how to build a Transformer-based neural network for classifying brain-computer interface (BCI) electroencephalography (EEG) data recorded in a steady-state visual evoked potential (SSVEP) experiment, for the application of a brain-controlled speller.
The tutorial reproduces an experiment from the SSVEPFormer study [1] (arXiv preprint / peer-reviewed paper). This model was the first Transformer-based model to be introduced for SSVEP data classification; we will test it on the publicly available dataset from Nakanishi et al. [2] (Dataset 1 in the paper).
The process follows a cross-subject classification experiment: given data from N subjects in the dataset, the training partition contains data from N-1 subjects, and the remaining single subject's data is used for testing. The training set contains no samples from the test subject, so we build a truly subject-independent model. We keep the same parameters and settings as the original paper in all processing operations, from preprocessing to training.
The tutorial begins with a quick BCI and dataset description, then we go through the technicalities in the following sections: - Setup and imports. - Dataset download and extraction. - Data preprocessing: EEG data filtering and segmentation, visualization of raw and filtered data, and the frequency response of a well-performing participant. - Layer and model creation. - Evaluation: a single-participant data classification example, followed by classification for all participants. - Visualization: a comparison of training and inference times across the available Keras 3 backends (JAX, TensorFlow, and PyTorch) on three different GPUs. - Conclusion: final discussion and remarks.
BCI offers the capability of communicating using only brain activity. This can be achieved through exogenous stimuli that generate specific responses indicating the intent of the subject; the responses are elicited when the user focuses their attention on the target stimulus. With visual stimuli, we can present the subject with a set of options on a monitor, arranged in a grid, so that one command is selected at a time. Each stimulus flickers at a fixed frequency and phase, and the EEG recorded over the occipital and occipito-parietal areas of the cortex (the visual cortex) shows higher power at the frequency associated with the stimulus the subject is attending to. This type of BCI paradigm is called steady-state visual evoked potentials (SSVEP) and is widely used for multiple applications due to its reliability, high classification performance, and rapidity (a 1-second EEG segment is enough to issue a command). Other types of brain responses exist that do not require external stimulation, but they are less reliable. Demo video
This tutorial uses a public SSVEP dataset [2] with 12 commands (classes), whose interface emulates a phone dialing pad.
The dataset was recorded from 10 participants, each facing the 12 SSVEP stimuli described above (A). The stimulation frequencies range from 9.25 Hz to 14.75 Hz in steps of 0.5 Hz, and the phases range from 0 to 1.5 π in steps of 0.5 π for each row (B). The EEG signals were acquired with 8 electrodes (channels) (PO7, PO3, POz, PO4, PO8, O1, Oz, O2) at a sampling frequency of 2048 Hz, and the stored data were downsampled to 256 Hz. Subjects completed 15 recording blocks, each containing 12 randomly ordered stimulations (1 per class) lasting 4 seconds each. In total, each subject performed 180 trials.
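To make the stimulation design concrete, the frequencies and per-row phases can be listed in a couple of lines (an illustrative sketch; the exact on-screen grid layout is given in [2]):
import numpy as np
# 12 stimulation frequencies: 9.25 Hz to 14.75 Hz in 0.5 Hz steps
frequencies = np.arange(9.25, 15.0, 0.5)
# per-row phases: 0 to 1.5 pi in 0.5 pi steps
phases = np.arange(0.0, 2.0, 0.5) * np.pi
print(frequencies)  # [ 9.25  9.75 ... 14.75]
print(phases / np.pi)  # [0.  0.5 1.  1.5]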
import os
os.environ["KERAS_BACKEND"] = "jax"
!pip install -q numpy
!pip install -q scipy
!pip install -q matplotlib
# deep learning libraries
from keras import backend as K
from keras import layers
import keras
# visualization and signal processing imports
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from scipy.signal import butter, filtfilt
from scipy.io import loadmat
# setting the backend, seed and Keras channel format
K.set_image_data_format("channels_first")
keras.utils.set_random_seed(42)
!curl -O https://sccn.ucsd.edu/download/cca_ssvep.zip
!unzip cca_ssvep.zip
Archive: cca_ssvep.zip
creating: cca_ssvep/
inflating: cca_ssvep/s4.mat
inflating: cca_ssvep/s5.mat
inflating: cca_ssvep/s3.mat
inflating: cca_ssvep/s7.mat
inflating: cca_ssvep/chan_locs.pdf
inflating: cca_ssvep/readme.txt
inflating: cca_ssvep/s2.mat
inflating: cca_ssvep/s8.mat
inflating: cca_ssvep/s10.mat
inflating: cca_ssvep/s9.mat
inflating: cca_ssvep/s6.mat
inflating: cca_ssvep/s1.mat
The preprocessing step starts by reading each subject's EEG data, then filtering the raw data in a frequency interval where the most useful information lies. We then select a fixed-duration segment of the signal starting from stimulation onset (adding 135 milliseconds to the onset to account for the latency of the visual system). Finally, all subjects' data are concatenated into a single tensor of shape [subjects x samples x channels x trials]. The data labels are also concatenated, following the order of the trials in the experiment, into a matrix of shape [subjects x trials]. (Here, channels means electrodes; we use this notation throughout the tutorial.)
def raw_signal(folder, fs=256, duration=1.0, onset=0.135):
"""selecting a 1-second segment of the raw EEG signal for
subject 1.
"""
onset = 38 + int(onset * fs)
end = int(duration * fs)
data = loadmat(f"{folder}/s1.mat")
# samples, channels, trials, targets
eeg = data["eeg"].transpose((2, 1, 3, 0))
# segment data
eeg = eeg[onset : onset + end, :, :, :]
return eeg
def segment_eeg(
folder, elecs=None, fs=256, duration=1.0, band=[5.0, 45.0], order=4, onset=0.135
):
"""Filtering and segmenting EEG signals for all subjects."""
    n_subjects = 10
onset = 38 + int(onset * fs)
end = int(duration * fs)
X, Y = [], [] # empty data and labels
    for subj in range(1, n_subjects + 1):
        data = loadmat(f"{folder}/s{subj}.mat")
# samples, channels, trials, targets
eeg = data["eeg"].transpose((2, 1, 3, 0))
# filter data
eeg = filter_eeg(eeg, fs=fs, band=band, order=order)
# segment data
eeg = eeg[onset : onset + end, :, :, :]
# reshape labels
samples, channels, blocks, targets = eeg.shape
y = np.tile(np.arange(1, targets + 1), (blocks, 1))
y = y.reshape((1, blocks * targets), order="F")
X.append(eeg.reshape((samples, channels, blocks * targets), order="F"))
Y.append(y)
X = np.array(X, dtype=np.float32, order="F")
Y = np.array(Y, dtype=np.float32).squeeze()
return X, Y
def filter_eeg(data, fs=256, band=[5.0, 45.0], order=4):
"""Filter EEG signal using a zero-phase IIR filter"""
B, A = butter(order, np.array(band) / (fs / 2), btype="bandpass")
return filtfilt(B, A, data, axis=0)
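As an optional sanity check (not part of the original tutorial), the magnitude response of this band-pass design can be inspected with scipy.signal.freqz; note that filtfilt applies the filter forward and backward, so the effective attenuation in dB is doubled:
from scipy.signal import freqz
# Sketch: response of the 4th-order Butterworth band-pass (8-64 Hz, fs=256)
b, a = butter(4, np.array([8.0, 64.0]) / (256 / 2), btype="bandpass")
w, h = freqz(b, a, worN=2048, fs=256)
plt.plot(w, 20 * np.log10(np.maximum(np.abs(h), 1e-12)))
plt.xlabel("Frequency (Hz)")
plt.ylabel("Magnitude (dB)")
plt.show()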
data_folder = os.path.abspath("./cca_ssvep")
band = [8, 64]  # low-frequency / high-frequency cutoffs
order = 4 # filter order
fs = 256 # sampling frequency
duration = 1.0 # 1 second
# raw signal
X_raw = raw_signal(data_folder, fs=fs, duration=duration)
print(
f"A single subject raw EEG (X_raw) shape: {X_raw.shape} [Samples x Channels x Blocks x Targets]"
)
# segmented signal
X, Y = segment_eeg(data_folder, band=band, order=order, fs=fs, duration=duration)
print(
f"Full training data (X) shape: {X.shape} [Subject x Samples x Channels x Trials]"
)
print(f"data labels (Y) shape: {Y.shape} [Subject x Trials]")
samples = X.shape[1]
time = np.linspace(0.0, samples / fs, samples) * 1000
A single subject raw EEG (X_raw) shape: (256, 8, 15, 12) [Samples x Channels x Blocks x Targets]
Full training data (X) shape: (10, 256, 8, 180) [Subject x Samples x Channels x Trials]
data labels (Y) shape: (10, 180) [Subject x Trials]
Raw EEG vs. filtered EEG. The plots show the same 1-second recording from subject s1 at Oz (the central electrode over the visual cortex, at the back of the head). Left: the raw EEG as recorded; right: the EEG filtered in the [8, 64] Hz band. We see less noise and amplitude values normalized to the natural EEG range.
elec = 6 # Oz channel
x_label = "Time (ms)"
y_label = "Voltage (uV)"
# Create subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
# Plot data on the first subplot
ax1.plot(time, X_raw[:, elec, 0, 0], "r-")
ax1.set_xlabel(x_label)
ax1.set_ylabel(y_label)
ax1.set_title("Raw EEG : 1 second at Oz ")
# Plot data on the second subplot
ax2.plot(time, X[0, :, elec, 0], "b-")
ax2.set_xlabel(x_label)
ax2.set_ylabel(y_label)
ax2.set_title("Filtered EEG between 8-64 Hz: 1 second at Oz")
# Adjust spacing between subplots
plt.tight_layout()
# Show the plot
plt.show()
Using the Welch method, we visualize the frequency power over the entire 4-second EEG recording at the Oz electrode for each stimulus, for a well-performing subject. The red peaks indicate the fundamental stimulation frequency and the second harmonic (twice the fundamental). We see clear peaks, indicating a strong response from this subject, which makes them a good candidate for SSVEP BCI control. In many cases the peaks are weak or absent, meaning the subject did not correctly perform the task.
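The figure can be reproduced with scipy's Welch periodogram. A minimal sketch follows (illustrative only and not part of the original code; it assumes the full-length trials from s1.mat and reuses elec, fs, and data_folder from above):
from scipy.signal import welch
# Load one subject's full-length trials and estimate the PSD at Oz.
data = loadmat(f"{data_folder}/s1.mat")
eeg = data["eeg"].transpose((2, 1, 3, 0))  # samples x channels x blocks x targets
oz = eeg[:, elec, 0, 0]  # first block, first target, Oz electrode
freqs, psd = welch(oz, fs=fs, nperseg=min(4 * fs, oz.shape[0]))  # ~0.25 Hz resolution
plt.plot(freqs, psd)
plt.xlabel("Frequency (Hz)")
plt.ylabel("PSD (uV^2 / Hz)")
plt.xlim(0, 50)
plt.show()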
We create the layers in a cross-framework custom component fashion. In the SSVEPFormer, the data is first transformed to the frequency domain through a fast Fourier transform (FFT) to construct a complex spectrum representation: the concatenation of frequency and phase information over a fixed frequency band. To keep the model in an end-to-end format, we implement the complex spectrum transformation as a non-trainable layer.
Unlike the Transformer architecture, the SSVEPFormer has no positional encoding/embedding layer; it is replaced by a channel combination block that has a Conv1D layer with kernel size 1 and twice the number of input channels (twice the number of electrodes) as filters, followed by LayerNorm, GELU activation, and dropout. Another difference from Transformers is the absence of multi-head attention layers with an attention mechanism. The model encoder contains two identical, successive blocks. Each block has two sub-blocks: a CNN module and an MLP module. The CNN module consists of LayerNorm, a Conv1D layer (with the same number of filters as the channel combination), GELU activation, dropout, and a residual connection. The MLP module consists of LayerNorm, a Dense layer, GELU, dropout, and a residual connection; the Dense layer is applied to each channel separately. The final block of the model is the MLP head, with a Flatten layer, dropout, Dense, LayerNorm, GELU, dropout, and a Dense layer with softmax activation. All trainable weights are initialized from a normal distribution with mean 0 and standard deviation 0.01, as stated in the original paper.
class ComplexSpectrum(keras.layers.Layer):
def __init__(self, nfft=512, fft_start=8, fft_end=64):
super().__init__()
self.nfft = nfft
self.fft_start = fft_start
self.fft_end = fft_end
def call(self, x):
samples = x.shape[-1]
x = keras.ops.rfft(x, fft_length=self.nfft)
real = x[0] / samples
imag = x[1] / samples
real = real[:, :, self.fft_start : self.fft_end]
imag = imag[:, :, self.fft_start : self.fft_end]
x = keras.ops.concatenate((real, imag), axis=-1)
return x
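A quick shape check can make the layer concrete (an illustrative sketch, not from the original code). With fs = 256 and resolution = 0.25, the model built below uses nfft = 1024, fft_start = 32, and fft_end = 257, so each channel yields 225 real plus 225 imaginary values, i.e. 450 features:
# Verify the layer's output shape on a dummy batch.
spec = ComplexSpectrum(nfft=1024, fft_start=32, fft_end=257)
dummy = keras.ops.ones((4, 8, 256))  # batch x channels x samples
print(spec(dummy).shape)  # expected: (4, 8, 450)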
class ChannelComb(keras.layers.Layer):
def __init__(self, n_channels, drop_rate=0.5):
super().__init__()
self.conv = layers.Conv1D(
2 * n_channels,
1,
padding="same",
kernel_initializer=keras.initializers.RandomNormal(
mean=0.0, stddev=0.01, seed=None
),
)
self.normalization = layers.LayerNormalization()
self.activation = layers.Activation(activation="gelu")
self.drop = layers.Dropout(drop_rate)
def call(self, x):
x = self.conv(x)
x = self.normalization(x)
x = self.activation(x)
x = self.drop(x)
return x
class ConvAttention(keras.layers.Layer):
def __init__(self, n_channels, drop_rate=0.5):
super().__init__()
self.norm = layers.LayerNormalization()
self.conv = layers.Conv1D(
2 * n_channels,
31,
padding="same",
kernel_initializer=keras.initializers.RandomNormal(
mean=0.0, stddev=0.01, seed=None
),
)
self.activation = layers.Activation(activation="gelu")
self.drop = layers.Dropout(drop_rate)
def call(self, x):
input = x
x = self.norm(x)
x = self.conv(x)
x = self.activation(x)
x = self.drop(x)
x = x + input
return x
class ChannelMLP(keras.layers.Layer):
def __init__(self, n_features, drop_rate=0.5):
super().__init__()
self.norm = layers.LayerNormalization()
self.mlp = layers.Dense(
2 * n_features,
kernel_initializer=keras.initializers.RandomNormal(
mean=0.0, stddev=0.01, seed=None
),
)
self.activation = layers.Activation(activation="gelu")
self.drop = layers.Dropout(drop_rate)
self.cat = layers.Concatenate(axis=1)
def call(self, x):
input = x
channels = x.shape[1] # x shape : NCF
x = self.norm(x)
output_channels = []
for i in range(channels):
c = self.mlp(x[:, :, i])
c = layers.Reshape([1, -1])(c)
output_channels.append(c)
x = self.cat(output_channels)
x = self.activation(x)
x = self.drop(x)
x = x + input
return x
class Encoder(keras.layers.Layer):
def __init__(self, n_channels, n_features, drop_rate=0.5):
super().__init__()
self.attention1 = ConvAttention(n_channels, drop_rate=drop_rate)
self.mlp1 = ChannelMLP(n_features, drop_rate=drop_rate)
self.attention2 = ConvAttention(n_channels, drop_rate=drop_rate)
self.mlp2 = ChannelMLP(n_features, drop_rate=drop_rate)
def call(self, x):
x = self.attention1(x)
x = self.mlp1(x)
x = self.attention2(x)
x = self.mlp2(x)
return x
class MlpHead(keras.layers.Layer):
def __init__(self, n_classes, drop_rate=0.5):
super().__init__()
self.flatten = layers.Flatten()
self.drop = layers.Dropout(drop_rate)
self.linear1 = layers.Dense(
6 * n_classes,
kernel_initializer=keras.initializers.RandomNormal(
mean=0.0, stddev=0.01, seed=None
),
)
self.norm = layers.LayerNormalization()
self.activation = layers.Activation(activation="gelu")
self.drop2 = layers.Dropout(drop_rate)
self.linear2 = layers.Dense(
n_classes,
kernel_initializer=keras.initializers.RandomNormal(
mean=0.0, stddev=0.01, seed=None
),
)
def call(self, x):
x = self.flatten(x)
x = self.drop(x)
x = self.linear1(x)
x = self.norm(x)
x = self.activation(x)
x = self.drop2(x)
x = self.linear2(x)
return x
def create_ssvepformer(
input_shape, fs, resolution, fq_band, n_channels, n_classes, drop_rate
):
nfft = round(fs / resolution)
fft_start = int(fq_band[0] / resolution)
fft_end = int(fq_band[1] / resolution) + 1
n_features = fft_end - fft_start
model = keras.Sequential(
[
keras.Input(shape=input_shape),
ComplexSpectrum(nfft, fft_start, fft_end),
ChannelComb(n_channels=n_channels, drop_rate=drop_rate),
Encoder(n_channels=n_channels, n_features=n_features, drop_rate=drop_rate),
Encoder(n_channels=n_channels, n_features=n_features, drop_rate=drop_rate),
MlpHead(n_classes=n_classes, drop_rate=drop_rate),
layers.Activation(activation="softmax"),
]
)
return model
# Training settings same as the original paper
BATCH_SIZE = 128
EPOCHS = 100
LR = 0.001 # learning rate
WD = 0.001 # weight decay
MOMENTUM = 0.9
DROP_RATE = 0.5
resolution = 0.25
From the entire dataset, we select the folds for each subject's evaluation. We then construct tf.data Dataset objects for the training and testing data, create the model, and launch training with the SGD optimizer.
def concatenate_subjects(x, y, fold):
X = np.concatenate([x[idx] for idx in fold], axis=-1)
Y = np.concatenate([y[idx] for idx in fold], axis=-1)
X = X.transpose((2, 1, 0)) # trials x channels x samples
return X, Y - 1 # transform labels to values from 0...11
def evaluate_subject(
x_train,
y_train,
x_val,
y_val,
input_shape,
fs=256,
resolution=0.25,
band=[8, 64],
channels=8,
n_classes=12,
drop_rate=DROP_RATE,
):
train_dataset = (
tf.data.Dataset.from_tensor_slices((x_train, y_train))
.batch(BATCH_SIZE)
.prefetch(tf.data.AUTOTUNE)
)
test_dataset = (
tf.data.Dataset.from_tensor_slices((x_val, y_val))
.batch(BATCH_SIZE)
.prefetch(tf.data.AUTOTUNE)
)
model = create_ssvepformer(
input_shape, fs, resolution, band, channels, n_classes, drop_rate
)
sgd = keras.optimizers.SGD(learning_rate=LR, momentum=MOMENTUM, weight_decay=WD)
model.compile(
loss="sparse_categorical_crossentropy",
optimizer=sgd,
metrics=["accuracy"],
jit_compile=True,
)
history = model.fit(
train_dataset,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
validation_data=test_dataset,
verbose=0,
)
loss, acc = model.evaluate(test_dataset)
return acc * 100
channels = X.shape[2]
samples = X.shape[1]
input_shape = (channels, samples)
n_classes = 12
model = create_ssvepformer(
input_shape, fs, resolution, band, channels, n_classes, DROP_RATE
)
model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ Layer (type)                       ┃ Output Shape        ┃   Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ complex_spectrum (ComplexSpectrum) │ (None, 8, 450)      │         0 │
├────────────────────────────────────┼─────────────────────┼───────────┤
│ channel_comb (ChannelComb)         │ (None, 16, 450)     │     1,044 │
├────────────────────────────────────┼─────────────────────┼───────────┤
│ encoder (Encoder)                  │ (None, 16, 450)     │    34,804 │
├────────────────────────────────────┼─────────────────────┼───────────┤
│ encoder_1 (Encoder)                │ (None, 16, 450)     │    34,804 │
├────────────────────────────────────┼─────────────────────┼───────────┤
│ mlp_head (MlpHead)                 │ (None, 12)          │   519,492 │
├────────────────────────────────────┼─────────────────────┼───────────┤
│ activation_10 (Activation)         │ (None, 12)          │         0 │
└────────────────────────────────────┴─────────────────────┴───────────┘
Total params: 590,144 (2.25 MB)
Trainable params: 590,144 (2.25 MB)
Non-trainable params: 0 (0.00 B)
accs = np.zeros(10)
for subject in range(10):
print(f"Testing subject: {subject+ 1}")
# create train / test folds
folds = np.delete(np.arange(10), subject)
train_index = folds
test_index = [subject]
# create data split for each subject
x_train, y_train = concatenate_subjects(X, Y, train_index)
x_val, y_val = concatenate_subjects(X, Y, test_index)
# train and evaluate a fold and compute the time it takes
acc = evaluate_subject(x_train, y_train, x_val, y_val, input_shape)
accs[subject] = acc
print(f"\nAccuracy Across Subjects: {accs.mean()} % std: {np.std(accs)}")
Testing subject: 1
1/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.5938 - loss: 1.2483
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 39ms/step - accuracy: 0.5868 - loss: 1.3010
Testing subject: 2
1/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.5469 - loss: 1.5270
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.5119 - loss: 1.5974
Testing subject: 3
1/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.7266 - loss: 0.8867
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.6903 - loss: 1.0117
Testing subject: 4
1/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9688 - loss: 0.1574
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.9637 - loss: 0.1487
Testing subject: 5
1/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9375 - loss: 0.2184
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9458 - loss: 0.1887
Testing subject: 6
1/2 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - accuracy: 0.9688 - loss: 0.1117
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.9674 - loss: 0.1018
Testing subject: 7
1/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9141 - loss: 0.2639
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9158 - loss: 0.2592
Testing subject: 8
1/2 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - accuracy: 0.9922 - loss: 0.0562
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - accuracy: 0.9937 - loss: 0.0518
Testing subject: 9
1/2 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9844 - loss: 0.0669
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.9837 - loss: 0.0701
Testing subject: 10
1/2 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.9219 - loss: 0.3438
2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - accuracy: 0.8999 - loss: 0.4543
Accuracy Across Subjects: 84.11111384630203 % std: 17.575586372993953
That's it! We see that some subjects whose data were never in the training set can still achieve almost 100% correct commands, while others perform poorly, at around 50% accuracy. In the original paper, using PyTorch, the average accuracy was 84.04% with a standard deviation of 17.37%. Given the stochastic nature of deep learning, we reach the same values.
We compare training and inference times across the different backends (JAX, TensorFlow, and PyTorch) on the three GPUs offered by Colab Free/Pro/Pro+: T4, L4, and A100.
The JAX backend was the best on all GPUs for both training and inference. PyTorch was extremely slow because the jit compilation option was disabled: the complex data type produced by the FFT computation is not supported by the PyTorch jit compiler.
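For reference, running the same pipeline with KERAS_BACKEND="torch" only requires compiling without jit (a sketch under that assumption; the rest of evaluate_subject is unchanged):
# Hypothetical torch-backend variant: torch's jit compiler does not
# support the complex dtype produced by the FFT, so disable jit_compile.
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=sgd,
    metrics=["accuracy"],
    jit_compile=False,
)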
Thanks to Chris Perry X @GoogleColab for supporting this work with GPU compute.
[1] Chen, J. et al. (2023) 'A transformer-based deep neural network model for SSVEP classification', Neural Networks, 164, pp. 521–534. Available at: https://doi.org/10.1016/j.neunet.2023.04.045.
[2] Nakanishi, M. et al. (2015) 'A Comparison Study of Canonical Correlation Analysis Based Methods for Detecting Steady-State Visual Evoked Potentials', PLOS ONE, 10(10), e0140703. Available at: https://doi.org/10.1371/journal.pone.0140703.