MultiHeadAttention
类tf_keras.layers.MultiHeadAttention(
num_heads,
key_dim,
value_dim=None,
dropout=0.0,
use_bias=True,
output_shape=None,
attention_axes=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs
)
MultiHeadAttention 层。
这是对论文“Attention is all you Need”(Vaswani et al., 2017) 中描述的多头注意力机制的实现。如果 query
, key,
和 value
相同,则这是自注意力机制。query
中的每个时间步都会关注 key
中相应的序列,并返回一个固定宽度的向量。
该层首先对 query
, key
和 value
进行投影。这些(实际上)是一个长度为 num_attention_heads
的张量列表,其对应形状分别为 (batch_size, <query dimensions>, key_dim)
, (batch_size, <key/value dimensions>, key_dim)
, (batch_size, <key/value dimensions>, value_dim)
。
然后,对 query 和 key 张量进行点乘并缩放。然后对其进行 softmax 操作以获得注意力概率。然后,value 张量根据这些概率进行插值,再拼接回一个单个张量。
最后,最后一个维度为 value_dim 的结果张量可以进行线性投影并输出。
在自定义层中使用 MultiHeadAttention
时,自定义层必须实现自己的 build()
方法,并在其中调用 MultiHeadAttention
的 _build_from_signature()
方法。这使得模型加载时可以正确恢复权重。
示例
使用注意力掩码对两个序列输入执行 1D 交叉注意力。返回每个头的额外注意力权重。
>>> layer = MultiHeadAttention(num_heads=2, key_dim=2)
>>> target = tf.keras.Input(shape=[8, 16])
>>> source = tf.keras.Input(shape=[4, 16])
>>> output_tensor, weights = layer(target, source,
... return_attention_scores=True)
>>> print(output_tensor.shape)
(None, 8, 16)
>>> print(weights.shape)
(None, 2, 8, 4)
对一个 5D 输入张量在其轴 2 和轴 3 上执行 2D 自注意力。
>>> layer = MultiHeadAttention(
... num_heads=2, key_dim=2, attention_axes=(2, 3))
>>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
>>> output_tensor = layer(input_tensor, input_tensor)
>>> print(output_tensor.shape)
(None, 5, 3, 4, 16)
参数
None
表示对除批次、头和特征之外的所有轴应用注意力。调用参数
(B, T, dim)
的 Query Tensor
。(B, S, dim)
的 Value Tensor
。Tensor
,形状为 (B, S, dim)
。如果未提供,则将 value
用于 key
和 value
,这是最常见的情况。(B, T, S)
的布尔掩码,用于阻止对某些位置的注意力。布尔掩码指定了哪些 query 元素可以关注哪些 key 元素,1 表示关注,0 表示不关注。对于缺失的批次维度和头维度,可以进行广播。True
,则输出应为 (attention_output, attention_scores)
,如果为 False
,则输出为 attention_output
。默认为 False
。False
(推断模式)。返回值
(B, T, E)
,其中 T
表示目标序列形状,如果 output_shape
为 None
,则 E
为 query 输入的最后一个维度。否则,多头输出会投影到由 output_shape
指定的形状。