Keras 3 API 文档 / 层 API / 注意力层 / GroupQueryAttention

GroupQueryAttention

[源代码]

GroupedQueryAttention

keras.layers.GroupQueryAttention(
    head_dim,
    num_query_heads,
    num_key_value_heads,
    dropout=0.0,
    use_bias=True,
    flash_attention=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    seed=None,
    **kwargs
)

分组查询注意力层。

这是 Ainslie 等人在 2023 年提出的分组查询注意力的实现。这里 num_key_value_heads 表示组的数量,将 num_key_value_heads 设置为 1 等同于多查询注意力,当 num_key_value_heads 等于 num_query_heads 时,它等同于多头注意力。

此层首先投影 querykeyvalue 张量。然后,重复 keyvalue 以匹配 query 的头数。

然后,将 query 缩放并与 key 张量进行点积运算。这些结果通过 softmax 函数获得注意力概率。然后,值张量通过这些概率进行插值,并连接回单个张量。

参数

  • head_dim: 每个注意力头的大小。
  • num_query_heads: 查询注意力头的数量。
  • num_key_value_heads: 键和值注意力头的数量。
  • dropout: Dropout 概率。
  • use_bias: 布尔值,指示密集层是否使用偏置向量/矩阵。
  • flash_attention: 如果为 None,则该层会尝试在可能的情况下使用 flash attention 以获得更快、更节省内存的注意力计算。可以使用 keras.config.enable_flash_attention()keras.config.disable_flash_attention() 配置此行为。
  • kernel_initializer: 密集层内核的初始化器。
  • bias_initializer: 密集层偏置的初始化器。
  • kernel_regularizer: 密集层内核的正则化器。
  • bias_regularizer: 密集层偏置的正则化器。
  • activity_regularizer: 密集层活动的正则化器。
  • kernel_constraint: 密集层内核的约束。
  • bias_constraint: 密集层偏置的约束。
  • seed: 可选整数,用于为 dropout 层设定种子。

调用参数

  • query: 形状为 (batch_dim, target_seq_len, feature_dim) 的查询张量,其中 batch_dim 是批大小,target_seq_len 是目标序列的长度,feature_dim 是特征的维度。
  • value: 形状为 (batch_dim, source_seq_len, feature_dim) 的值张量,其中 batch_dim 是批大小,source_seq_len 是源序列的长度,feature_dim 是特征的维度。
  • key: 可选的键张量,形状为 (batch_dim, source_seq_len, feature_dim)。如果未给出,将使用 value 同时作为 keyvalue,这是最常见的情况。
  • attention_mask: 形状为 (batch_dim, target_seq_len, source_seq_len) 的布尔掩码,用于阻止对某些位置的注意力。布尔掩码指定哪些查询元素可以关注哪些键元素,其中 1 表示关注,0 表示不关注。广播可以发生在缺失的批次维度和头维度上。
  • return_attention_scores: 布尔值,指示如果为 True,输出应为 (attention_output, attention_scores),如果为 False,则为 attention_output。默认为 False
  • training: Python 布尔值,指示层应在训练模式(添加 dropout)还是推理模式(无 dropout)下运行。如果没有父层,将使用父层/模型的训练模式或 False(推理)。
  • use_causal_mask: 布尔值,指示是否应用因果掩码以防止 token 关注未来的 token(例如,在解码器 Transformer 中使用)。

返回

  • attention_output: 计算结果,形状为 (batch_dim, target_seq_len, feature_dim),其中 target_seq_len 是目标序列长度,feature_dim 是查询输入最后一个维度。
  • attention_scores: (可选) 注意力系数,形状为 (batch_dim, num_query_heads, target_seq_len, source_seq_len)