Keras 3 API 文档 / 层 API / 注意力层 / GroupQueryAttention

GroupQueryAttention

[源代码]

GroupedQueryAttention

keras.layers.GroupQueryAttention(
    head_dim,
    num_query_heads,
    num_key_value_heads,
    dropout=0.0,
    use_bias=True,
    flash_attention=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    seed=None,
    **kwargs
)

分组查询注意力层。

这是 Ainslie 等人 2023 年提出的分组查询注意力的实现,参见 论文。 其中,num_key_value_heads 表示组的数量,将 num_key_value_heads 设置为 1 等价于多查询注意力,当 num_key_value_heads 等于 num_query_heads 时,等价于多头注意力。

此层首先投影 querykeyvalue 张量。然后,keyvalue 会重复以匹配 query 的头数。

接着,对 query 进行缩放,并与 key 张量进行点积。这些结果会经过 softmax 函数得到注意力概率。然后,通过这些概率对 value 张量进行插值,并将其连接回单个张量。

参数

  • head_dim:每个注意力头的尺寸。
  • num_query_heads:查询注意力头的数量。
  • num_key_value_heads:键和值注意力头的数量。
  • dropout:丢弃概率。
  • use_bias:布尔值,指示密集层是否使用偏置向量/矩阵。
  • flash_attention:如果为 None,则该层会在可能的情况下尝试使用 flash attention 来实现更快、更节省内存的注意力计算。此行为可以通过 keras.config.enable_flash_attention()keras.config.disable_flash_attention() 进行配置。
  • kernel_initializer:密集层核的初始化器。
  • bias_initializer:密集层偏置的初始化器。
  • kernel_regularizer:密集层核的正则化器。
  • bias_regularizer:密集层偏置的正则化器。
  • activity_regularizer:密集层激活的正则化器。
  • kernel_constraint:密集层核的约束。
  • bias_constraint:密集层核的约束。
  • seed:可选整数,用于为丢弃层设置种子。

调用参数

  • query:查询张量,形状为 (batch_dim, target_seq_len, feature_dim),其中 batch_dim 是批次大小,target_seq_len 是目标序列的长度,feature_dim 是特征的维度。
  • value:值张量,形状为 (batch_dim, source_seq_len, feature_dim),其中 batch_dim 是批次大小,source_seq_len 是源序列的长度,feature_dim 是特征的维度。
  • key:可选的键张量,形状为 (batch_dim, source_seq_len, feature_dim)。如果未给出,则会同时使用 value 作为 keyvalue,这是最常见的情况。
  • attention_mask:形状为 (batch_dim, target_seq_len, source_seq_len) 的布尔掩码,用于阻止对某些位置的注意力。布尔掩码指定哪些查询元素可以关注哪些键元素,其中 1 表示关注,0 表示不关注。对于缺失的批次维度和头维度,可以进行广播。
  • return_attention_scores:布尔值,指示输出是否应为 (attention_output, attention_scores)(如果为 True),或 attention_output(如果为 False)。默认为 False
  • training:Python 布尔值,指示层是否应该处于训练模式(添加丢弃)或推理模式(不添加丢弃)。如果存在父层,则会使用父层/模型的训练模式,否则使用 False(推理)。
  • use_causal_mask:布尔值,指示是否应用因果掩码以防止标记关注未来的标记(例如,在解码器 Transformer 中使用)。

返回值

  • attention_output:计算结果,形状为 (batch_dim, target_seq_len, feature_dim),其中 target_seq_len 是目标序列长度,feature_dim 是查询输入的最后一个维度。
  • attention_scores:(可选)注意力系数,形状为 (batch_dim, num_query_heads, target_seq_len, source_seq_len)