GroupedQueryAttention
类keras.layers.GroupQueryAttention(
head_dim,
num_query_heads,
num_key_value_heads,
dropout=0.0,
use_bias=True,
flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs
)
分组查询注意力层。
这是 Ainslie 等人 2023 年提出的分组查询注意力的实现,参见 论文。 其中,num_key_value_heads
表示组的数量,将 num_key_value_heads
设置为 1 等价于多查询注意力,当 num_key_value_heads
等于 num_query_heads
时,等价于多头注意力。
此层首先投影 query
、key
和 value
张量。然后,key
和 value
会重复以匹配 query
的头数。
接着,对 query
进行缩放,并与 key
张量进行点积。这些结果会经过 softmax 函数得到注意力概率。然后,通过这些概率对 value
张量进行插值,并将其连接回单个张量。
参数
None
,则该层会在可能的情况下尝试使用 flash attention 来实现更快、更节省内存的注意力计算。此行为可以通过 keras.config.enable_flash_attention()
或 keras.config.disable_flash_attention()
进行配置。调用参数
(batch_dim, target_seq_len, feature_dim)
,其中 batch_dim
是批次大小,target_seq_len
是目标序列的长度,feature_dim
是特征的维度。(batch_dim, source_seq_len, feature_dim)
,其中 batch_dim
是批次大小,source_seq_len
是源序列的长度,feature_dim
是特征的维度。(batch_dim, source_seq_len, feature_dim)
。如果未给出,则会同时使用 value
作为 key
和 value
,这是最常见的情况。(batch_dim, target_seq_len, source_seq_len)
的布尔掩码,用于阻止对某些位置的注意力。布尔掩码指定哪些查询元素可以关注哪些键元素,其中 1 表示关注,0 表示不关注。对于缺失的批次维度和头维度,可以进行广播。(attention_output, attention_scores)
(如果为 True
),或 attention_output
(如果为 False
)。默认为 False
。False
(推理)。返回值
(batch_dim, target_seq_len, feature_dim)
,其中 target_seq_len
是目标序列长度,feature_dim
是查询输入的最后一个维度。(batch_dim, num_query_heads, target_seq_len, source_seq_len)
。