GroupedQueryAttention 类keras.layers.GroupQueryAttention(
head_dim,
num_query_heads,
num_key_value_heads,
dropout=0.0,
use_bias=True,
flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs
)
分组查询注意力层。
这是 Ainslie 等人于 2023 年提出的分组查询注意力(grouped-query attention)的实现。其中 num_key_value_heads 表示组的数量,将 num_key_value_heads 设置为 1 等同于多查询注意力(multi-query attention),而当 num_key_value_heads 等于 num_query_heads 时,则等同于多头注意力(multi-head attention)。
该层首先投影 query、key 和 value 张量。然后,key 和 value 会被重复以匹配 query 的头数。
然后,对 query 进行缩放并与 key 张量进行点积运算。这些运算会经过 softmax 处理以获得注意力概率。然后,通过这些概率对 value 张量进行插值,并将其连接回单个张量。
参数
None,该层会尝试使用 flash attention 来实现更快速、更节省内存的注意力计算(如果可能)。此行为可以通过 keras.config.enable_flash_attention() 或 keras.config.disable_flash_attention() 进行配置。调用参数
(batch_dim, target_seq_len, feature_dim),其中 batch_dim 是批次大小,target_seq_len 是目标序列的长度,feature_dim 是特征的维度。(batch_dim, source_seq_len, feature_dim),其中 batch_dim 是批次大小,source_seq_len 是源序列的长度,feature_dim 是特征的维度。(batch_dim, source_seq_len, feature_dim)。如果未提供,则将 value 同时用于 key 和 value,这是最常见的情况。(batch_dim, target_seq_len, source_seq_len),用于防止注意力集中在某些位置。布尔掩码指定了哪些查询元素可以关注哪些键元素,其中 1 表示注意力,0 表示无注意力。缺失的批次维度和头维度可以进行广播。True,输出应为 (attention_output, attention_scores),如果为 False,则为 attention_output。默认为 False。False(推理模式)。返回
(batch_dim, target_seq_len, feature_dim),其中 target_seq_len 是目标序列长度,feature_dim 是查询输入的最后一个维度。(batch_dim, num_query_heads, target_seq_len, source_seq_len)。