GroupedQueryAttention
类keras.layers.GroupQueryAttention(
head_dim,
num_query_heads,
num_key_value_heads,
dropout=0.0,
use_bias=True,
flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs
)
分组查询注意力(Grouped Query Attention)层。
这是由 Ainslie et al., 2023 提出的分组查询注意力(grouped-query attention)的实现。此处 num_key_value_heads
表示分组数量,将 num_key_value_heads
设置为 1 等同于多查询注意力(multi-query attention),当 num_key_value_heads
等于 num_query_heads
时,它等同于多头注意力(multi-head attention)。
该层首先对 query
、key
和 value
张量进行投影。然后,重复 key
和 value
以匹配 query
的头数。
接着,对 query
进行缩放并与 key
张量进行点积。对结果进行 Softmax 运算以获得注意力概率。然后,根据这些概率对 value 张量进行插值,并拼接回单个张量。
参数
None
,则层在可能的情况下会尝试使用 Flash Attention 以进行更快、更省内存的注意力计算。此行为可通过 keras.config.enable_flash_attention()
或 keras.config.disable_flash_attention()
进行配置。Call 参数
(batch_dim, target_seq_len, feature_dim)
,其中 batch_dim
是批量大小,target_seq_len
是目标序列长度,feature_dim
是特征维度。(batch_dim, source_seq_len, feature_dim)
,其中 batch_dim
是批量大小,source_seq_len
是源序列长度,feature_dim
是特征维度。(batch_dim, source_seq_len, feature_dim)
。如果未给出,将使用 value
作为 key
和 value
,这是最常见的情况。(batch_dim, target_seq_len, source_seq_len)
,用于阻止对某些位置的注意力计算。该布尔掩码指定哪些 query 元素可以关注哪些 key 元素,其中 1 表示允许注意力计算,0 表示不允许。对于缺失的批量维度和头维度,可以进行广播。True
时输出是否应为 (attention_output, attention_scores)
,当为 False
时输出是否应为 attention_output
。默认为 False
。False
(推理)。返回值
(batch_dim, target_seq_len, feature_dim)
,其中 target_seq_len
是目标序列长度,feature_dim
是 query 输入的最后一个维度。(batch_dim, num_query_heads, target_seq_len, source_seq_len)
。