MultiHeadAttention 类keras.layers.MultiHeadAttention(
num_heads,
key_dim,
value_dim=None,
dropout=0.0,
use_bias=True,
output_shape=None,
attention_axes=None,
flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs
)
MultiHeadAttention 层。
这是根据论文 "Attention is all you Need" Vaswani et al., 2017 实现的多头注意力。如果 query、key、value 相同,则为自注意力。query 中的每个时间步都关注 key 中的相应序列,并返回一个固定宽度的向量。
该层首先对 query、key 和 value 进行投影。这些(有效地)是由 num_attention_heads 长度的张量列表组成的,其中相应的形状为 (batch_size, <query dimensions>, key_dim)、(batch_size, <key/value dimensions>, key_dim)、(batch_size, <key/value dimensions>, value_dim)。
然后,查询和键张量进行点积和缩放。然后对它们进行 softmax 操作以获得注意力概率。然后,值张量通过这些概率进行插值,然后连接回单个张量。
最后,最后一个维度为 value_dim 的结果张量可以进行线性投影并返回。
参数
None 表示在除 batch、heads 和 features 之外的所有轴上应用注意力。None,则该层尝试使用 flash attention 来进行更快、更节省内存的注意力计算(如果可能)。此行为可以通过 keras.config.enable_flash_attention() 或 keras.config.disable_flash_attention() 进行配置。调用参数
(B, T, dim),其中 B 是 batch 大小,T 是目标序列长度,dim 是特征维度。(B, S, dim),其中 B 是 batch 大小,S 是源序列长度,dim 是特征维度。(B, S, dim)。如果未给出,则将使用 value 作为 key 和 value,这是最常见的情况。(B, T, S) 的布尔掩码,用于阻止注意力指向特定位置。布尔掩码指定了哪些查询元素可以关注哪些键元素,1 表示允许注意力,0 表示不允许注意力。对于缺失的 batch 维度和 head 维度,可以进行广播。True,输出应为 (attention_output, attention_scores),如果为 False,则为 attention_output。默认为 False。False(推理)。返回
(B, T, E),其中 T 是目标序列形状,E 是查询输入的最后一个维度(如果 output_shape 为 None)。否则,多头注意力输出将被投影到 output_shape 指定的形状。