MultiHeadAttention
类keras.layers.MultiHeadAttention(
num_heads,
key_dim,
value_dim=None,
dropout=0.0,
use_bias=True,
output_shape=None,
attention_axes=None,
flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs
)
MultiHeadAttention 层。
这是论文“Attention is all you Need”Vaswani 等人,2017 中描述的多头注意力(multi-headed attention)的实现。如果 query
、key
、value
相同,则这是自注意力(self-attention)。query
中的每个时间步都会关注 key
中对应的序列,并返回一个固定宽度的向量。
此层首先对 query
、key
和 value
进行投影。这些(实际上)是长度为 num_attention_heads
的张量列表,其对应形状分别为 (batch_size, <query dimensions>, key_dim)
、(batch_size, <key/value dimensions>, key_dim)
、(batch_size, <key/value dimensions>, value_dim)
。
然后,对 query 和 key 张量进行点积并缩放。这些张量经过 softmax 处理以获得注意力概率。然后,值张量由这些概率进行插值,再拼接回单个张量。
最后,最后一维为 value_dim
的结果张量可以进行线性投影并返回。
参数
None
表示在除了批次、头和特征之外的所有轴上应用注意力。None
,则该层会在可能时尝试使用 Flash Attention 以实现更快、更节省内存的注意力计算。可以使用 keras.config.enable_flash_attention()
或 keras.config.disable_flash_attention()
配置此行为。调用参数
(B, T, dim)
的 query 张量,其中 B
是批次大小,T
是目标序列长度,dim 是特征维度。(B, S, dim)
的 value 张量,其中 B
是批次大小,S
是源序列长度,dim 是特征维度。(B, S, dim)
的可选 key 张量。如果未给出,则将 value
同时用于 key
和 value
,这是最常见的情况。(B, T, S)
的布尔掩码,用于阻止对某些位置的注意力。布尔掩码指定了哪些 query 元素可以关注哪些 key 元素,1 表示关注,0 表示不关注。对于缺失的批次维度和头维度,可能会发生广播。True
,输出是否应为 (attention_output, attention_scores)
;如果为 False
,则仅为 attention_output
。默认为 False
。False
(推理模式)。返回
(B, T, E)
,其中 T
表示目标序列形状,如果 output_shape
为 None
,则 E
是 query 输入的最后一个维度。否则,多头输出将被投影到由 output_shape
指定的形状。