Keras 2 API 文档 / 层 API / 注意力层 / MultiHeadAttention 层

MultiHeadAttention 层

[源]

MultiHeadAttention

tf_keras.layers.MultiHeadAttention(
    num_heads,
    key_dim,
    value_dim=None,
    dropout=0.0,
    use_bias=True,
    output_shape=None,
    attention_axes=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    **kwargs
)

MultiHeadAttention 层。

这是对论文“Attention is all you Need”(Vaswani et al., 2017) 中描述的多头注意力机制的实现。如果 query, key,value 相同,则这是自注意力机制。query 中的每个时间步都会关注 key 中相应的序列,并返回一个固定宽度的向量。

该层首先对 query, keyvalue 进行投影。这些(实际上)是一个长度为 num_attention_heads 的张量列表,其对应形状分别为 (batch_size, <query dimensions>, key_dim), (batch_size, <key/value dimensions>, key_dim), (batch_size, <key/value dimensions>, value_dim)

然后,对 query 和 key 张量进行点乘并缩放。然后对其进行 softmax 操作以获得注意力概率。然后,value 张量根据这些概率进行插值,再拼接回一个单个张量。

最后,最后一个维度为 value_dim 的结果张量可以进行线性投影并输出。

在自定义层中使用 MultiHeadAttention 时,自定义层必须实现自己的 build() 方法,并在其中调用 MultiHeadAttention_build_from_signature() 方法。这使得模型加载时可以正确恢复权重。

示例

使用注意力掩码对两个序列输入执行 1D 交叉注意力。返回每个头的额外注意力权重。

>>> layer = MultiHeadAttention(num_heads=2, key_dim=2)
>>> target = tf.keras.Input(shape=[8, 16])
>>> source = tf.keras.Input(shape=[4, 16])
>>> output_tensor, weights = layer(target, source,
...                                return_attention_scores=True)
>>> print(output_tensor.shape)
(None, 8, 16)
>>> print(weights.shape)
(None, 2, 8, 4)

对一个 5D 输入张量在其轴 2 和轴 3 上执行 2D 自注意力。

>>> layer = MultiHeadAttention(
...     num_heads=2, key_dim=2, attention_axes=(2, 3))
>>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
>>> output_tensor = layer(input_tensor, input_tensor)
>>> print(output_tensor.shape)
(None, 5, 3, 4, 16)

参数

  • num_heads: 注意力头的数量。
  • key_dim: 用于 query 和 key 的每个注意力头的大小。
  • value_dim: 用于 value 的每个注意力头的大小。
  • dropout: Dropout 概率。
  • use_bias: 布尔值,表示全连接层是否使用偏置向量/矩阵。
  • output_shape: 输出张量的预期形状(不包括批次和序列维度)。如果未指定,则投影回 query 特征维度(即 query 输入的最后一个维度)。
  • attention_axes: 应用注意力机制的轴。None 表示对除批次、头和特征之外的所有轴应用注意力。
  • kernel_initializer: 全连接层权重的初始化器。
  • bias_initializer: 全连接层偏置的初始化器。
  • kernel_regularizer: 全连接层权重的正则化器。
  • bias_regularizer: 全连接层偏置的正则化器。
  • activity_regularizer: 全连接层输出的正则化器。
  • kernel_constraint: 全连接层权重的约束。
  • bias_constraint: 全连接层偏置的约束。

调用参数

  • query: 形状为 (B, T, dim) 的 Query Tensor
  • value: 形状为 (B, S, dim) 的 Value Tensor
  • key: 可选的 key Tensor,形状为 (B, S, dim)。如果未提供,则将 value 用于 keyvalue,这是最常见的情况。
  • attention_mask: 形状为 (B, T, S) 的布尔掩码,用于阻止对某些位置的注意力。布尔掩码指定了哪些 query 元素可以关注哪些 key 元素,1 表示关注,0 表示不关注。对于缺失的批次维度和头维度,可以进行广播。
  • return_attention_scores: 布尔值,表示如果为 True,则输出应为 (attention_output, attention_scores),如果为 False,则输出为 attention_output。默认为 False
  • training: Python 布尔值,指示该层应在训练模式(添加 dropout)还是推断模式(无 dropout)下运行。如果没有父层,则默认为父层/模型的训练模式,或者为 False(推断模式)。
  • use_causal_mask: 布尔值,指示是否应用因果掩码以防止 token 关注未来的 token(例如,在解码器 Transformer 中使用)。

返回值

  • attention_output: 计算结果,形状为 (B, T, E),其中 T 表示目标序列形状,如果 output_shapeNone,则 E 为 query 输入的最后一个维度。否则,多头输出会投影到由 output_shape 指定的形状。
  • attention_scores: [可选] 在注意力轴上的多头注意力系数。