Keras 3 API 文档 / 层 API / 注意力层 / GroupQueryAttention

GroupQueryAttention

[源]

GroupedQueryAttention

keras.layers.GroupQueryAttention(
    head_dim,
    num_query_heads,
    num_key_value_heads,
    dropout=0.0,
    use_bias=True,
    flash_attention=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    seed=None,
    **kwargs
)

分组查询注意力(Grouped Query Attention)层。

这是由 Ainslie et al., 2023 提出的分组查询注意力(grouped-query attention)的实现。此处 num_key_value_heads 表示分组数量,将 num_key_value_heads 设置为 1 等同于多查询注意力(multi-query attention),当 num_key_value_heads 等于 num_query_heads 时,它等同于多头注意力(multi-head attention)。

该层首先对 querykeyvalue 张量进行投影。然后,重复 keyvalue 以匹配 query 的头数。

接着,对 query 进行缩放并与 key 张量进行点积。对结果进行 Softmax 运算以获得注意力概率。然后,根据这些概率对 value 张量进行插值,并拼接回单个张量。

参数

  • head_dim: 每个注意力头的大小。
  • num_query_heads: 查询注意力头的数量。
  • num_key_value_heads: key 和 value 注意力头的数量。
  • dropout: Dropout 概率。
  • use_bias: 布尔值,表示全连接层是否使用偏置向量/矩阵。
  • flash_attention: 如果为 None,则层在可能的情况下会尝试使用 Flash Attention 以进行更快、更省内存的注意力计算。此行为可通过 keras.config.enable_flash_attention()keras.config.disable_flash_attention() 进行配置。
  • kernel_initializer: 全连接层权重(kernel)的初始化器。
  • bias_initializer: 全连接层偏置的初始化器。
  • kernel_regularizer: 全连接层权重(kernel)的正则化器。
  • bias_regularizer: 全连接层偏置的正则化器。
  • activity_regularizer: 全连接层输出的正则化器。
  • kernel_constraint: 全连接层权重(kernel)的约束。
  • bias_constraint: 全连接层权重(kernel)的约束。
  • seed: Dropout 层的可选整数种子。

Call 参数

  • query: query 张量,形状为 (batch_dim, target_seq_len, feature_dim),其中 batch_dim 是批量大小,target_seq_len 是目标序列长度,feature_dim 是特征维度。
  • value: value 张量,形状为 (batch_dim, source_seq_len, feature_dim),其中 batch_dim 是批量大小,source_seq_len 是源序列长度,feature_dim 是特征维度。
  • key: 可选的 key 张量,形状为 (batch_dim, source_seq_len, feature_dim)。如果未给出,将使用 value 作为 keyvalue,这是最常见的情况。
  • attention_mask: 一个布尔掩码,形状为 (batch_dim, target_seq_len, source_seq_len),用于阻止对某些位置的注意力计算。该布尔掩码指定哪些 query 元素可以关注哪些 key 元素,其中 1 表示允许注意力计算,0 表示不允许。对于缺失的批量维度和头维度,可以进行广播。
  • return_attention_scores: 一个布尔值,指示当为 True 时输出是否应为 (attention_output, attention_scores),当为 False 时输出是否应为 attention_output。默认为 False
  • training: 一个 Python 布尔值,指示层应在训练模式(添加 dropout)还是推理模式(无 dropout)下运行。如果没有父层,则将采用父层/模型的训练模式;否则为 False(推理)。
  • use_causal_mask: 一个布尔值,指示是否应用因果掩码(causal mask)以防止 token 关注未来的 token (例如,用于解码器 Transformer)。

返回值

  • attention_output: 计算结果,形状为 (batch_dim, target_seq_len, feature_dim),其中 target_seq_len 是目标序列长度,feature_dim 是 query 输入的最后一个维度。
  • attention_scores: (可选) 注意力系数,形状为 (batch_dim, num_query_heads, target_seq_len, source_seq_len)