Keras 3 API文档 / Layers API / Attention layers / GroupQueryAttention

分组查询注意力 (GroupQueryAttention)

[源代码]

GroupedQueryAttention

keras.layers.GroupQueryAttention(
    head_dim,
    num_query_heads,
    num_key_value_heads,
    dropout=0.0,
    use_bias=True,
    flash_attention=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    seed=None,
    **kwargs
)

分组查询注意力层。

这是 Ainslie 等人于 2023 年提出的分组查询注意力(grouped-query attention)的实现。其中 num_key_value_heads 表示组的数量,将 num_key_value_heads 设置为 1 等同于多查询注意力(multi-query attention),而当 num_key_value_heads 等于 num_query_heads 时,则等同于多头注意力(multi-head attention)。

该层首先投影 querykeyvalue 张量。然后,keyvalue 会被重复以匹配 query 的头数。

然后,对 query 进行缩放并与 key 张量进行点积运算。这些运算会经过 softmax 处理以获得注意力概率。然后,通过这些概率对 value 张量进行插值,并将其连接回单个张量。

参数

  • head_dim: 每个注意力头的尺寸。
  • num_query_heads: 查询注意力头的数量。
  • num_key_value_heads: 键和值注意力头的数量。
  • dropout:Dropout 概率。
  • use_bias:布尔值,表示密集层是否使用偏置向量/矩阵。
  • flash_attention: 如果为 None,该层会尝试使用 flash attention 来实现更快速、更节省内存的注意力计算(如果可能)。此行为可以通过 keras.config.enable_flash_attention()keras.config.disable_flash_attention() 进行配置。
  • kernel_initializer:密集层核的初始化器。
  • bias_initializer:密集层偏置的初始化器。
  • kernel_regularizer:密集层核的正则化器。
  • bias_regularizer:密集层偏置的正则化器。
  • activity_regularizer:密集层活动的正则化器。
  • kernel_constraint:密集层核的约束。
  • bias_constraint:密集层核的约束。
  • seed: 可选的整数,用于初始化 dropout 层的随机种子。

调用参数

  • query: 查询张量,形状为 (batch_dim, target_seq_len, feature_dim),其中 batch_dim 是批次大小,target_seq_len 是目标序列的长度,feature_dim 是特征的维度。
  • value: 值张量,形状为 (batch_dim, source_seq_len, feature_dim),其中 batch_dim 是批次大小,source_seq_len 是源序列的长度,feature_dim 是特征的维度。
  • key: 可选的键张量,形状为 (batch_dim, source_seq_len, feature_dim)。如果未提供,则将 value 同时用于 keyvalue,这是最常见的情况。
  • attention_mask: 一个布尔掩码,形状为 (batch_dim, target_seq_len, source_seq_len),用于防止注意力集中在某些位置。布尔掩码指定了哪些查询元素可以关注哪些键元素,其中 1 表示注意力,0 表示无注意力。缺失的批次维度和头维度可以进行广播。
  • return_attention_scores:一个布尔值,指示如果为 True,输出应为 (attention_output, attention_scores),如果为 False,则为 attention_output。默认为 False
  • training: Python 布尔值,指示该层是应该以训练模式(添加 dropout)运行还是以推理模式(无 dropout)运行。如果存在父层/模型,则使用其训练模式;否则,默认为 False(推理模式)。
  • use_causal_mask:一个布尔值,指示是否应用因果掩码以防止 token 关注未来的 token(例如,在解码器 Transformer 中使用)。

返回

  • attention_output: 计算结果,形状为 (batch_dim, target_seq_len, feature_dim),其中 target_seq_len 是目标序列长度,feature_dim 是查询输入的最后一个维度。
  • attention_scores: (可选) 注意力系数,形状为 (batch_dim, num_query_heads, target_seq_len, source_seq_len)