Keras 3 API 文档 / 层 API / 注意力层 / MultiHeadAttention 层

MultiHeadAttention 层

[源代码]

MultiHeadAttention

keras.layers.MultiHeadAttention(
    num_heads,
    key_dim,
    value_dim=None,
    dropout=0.0,
    use_bias=True,
    output_shape=None,
    attention_axes=None,
    flash_attention=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    seed=None,
    **kwargs
)

MultiHeadAttention 层。

这是论文“Attention is all you Need”Vaswani 等人,2017 中描述的多头注意力(multi-headed attention)的实现。如果 querykeyvalue 相同,则这是自注意力(self-attention)。query 中的每个时间步都会关注 key 中对应的序列,并返回一个固定宽度的向量。

此层首先对 querykeyvalue 进行投影。这些(实际上)是长度为 num_attention_heads 的张量列表,其对应形状分别为 (batch_size, <query dimensions>, key_dim)(batch_size, <key/value dimensions>, key_dim)(batch_size, <key/value dimensions>, value_dim)

然后,对 query 和 key 张量进行点积并缩放。这些张量经过 softmax 处理以获得注意力概率。然后,值张量由这些概率进行插值,再拼接回单个张量。

最后,最后一维为 value_dim 的结果张量可以进行线性投影并返回。

参数

  • num_heads:注意力头的数量。
  • key_dim:query 和 key 的每个注意力头的大小。
  • value_dim:value 的每个注意力头的大小。
  • dropout:Dropout 概率。
  • use_bias:布尔值,表示全连接层是否使用偏置向量/矩阵。
  • output_shape:输出张量除了批次和序列维度以外的预期形状。如果未指定,则会投影回 query 特征维度(query 输入的最后一个维度)。
  • attention_axes:应用注意力的轴。None 表示在除了批次、头和特征之外的所有轴上应用注意力。
  • flash_attention:如果为 None,则该层会在可能时尝试使用 Flash Attention 以实现更快、更节省内存的注意力计算。可以使用 keras.config.enable_flash_attention()keras.config.disable_flash_attention() 配置此行为。
  • kernel_initializer:全连接层权重的初始化器。
  • bias_initializer:全连接层偏置的初始化器。
  • kernel_regularizer:全连接层权重的正则化器。
  • bias_regularizer:全连接层偏置的正则化器。
  • activity_regularizer:全连接层输出的正则化器。
  • kernel_constraint:全连接层权重的约束。
  • bias_constraint:全连接层偏置的约束。
  • seed:用于设置 Dropout 层种子的可选整数。

调用参数

  • query:形状为 (B, T, dim) 的 query 张量,其中 B 是批次大小,T 是目标序列长度,dim 是特征维度。
  • value:形状为 (B, S, dim) 的 value 张量,其中 B 是批次大小,S 是源序列长度,dim 是特征维度。
  • key:形状为 (B, S, dim) 的可选 key 张量。如果未给出,则将 value 同时用于 keyvalue,这是最常见的情况。
  • attention_mask:形状为 (B, T, S) 的布尔掩码,用于阻止对某些位置的注意力。布尔掩码指定了哪些 query 元素可以关注哪些 key 元素,1 表示关注,0 表示不关注。对于缺失的批次维度和头维度,可能会发生广播。
  • return_attention_scores:一个布尔值,指示如果为 True,输出是否应为 (attention_output, attention_scores);如果为 False,则仅为 attention_output。默认为 False
  • training:Python 布尔值,指示该层是否应在训练模式下(添加 dropout)或推理模式下(无 dropout)运行。如果没有父层,则将遵循父层/模型的训练模式,或使用 False(推理模式)。
  • use_causal_mask:一个布尔值,指示是否应用因果掩码以阻止 token 关注未来的 token(例如,用于解码器 Transformer)。

返回

  • attention_output:计算结果,形状为 (B, T, E),其中 T 表示目标序列形状,如果 output_shapeNone,则 E 是 query 输入的最后一个维度。否则,多头输出将被投影到由 output_shape 指定的形状。
  • attention_scores:(可选)注意力轴上的多头注意力系数。