GroupedQueryAttention
类keras.layers.GroupQueryAttention(
head_dim,
num_query_heads,
num_key_value_heads,
dropout=0.0,
use_bias=True,
flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs
)
分组查询注意力层。
这是 Ainslie 等人在 2023 年提出的分组查询注意力的实现。这里 num_key_value_heads
表示组的数量,将 num_key_value_heads
设置为 1 等同于多查询注意力,当 num_key_value_heads
等于 num_query_heads
时,它等同于多头注意力。
此层首先投影 query
、key
和 value
张量。然后,重复 key
和 value
以匹配 query
的头数。
然后,将 query
缩放并与 key
张量进行点积运算。这些结果通过 softmax 函数获得注意力概率。然后,值张量通过这些概率进行插值,并连接回单个张量。
参数
None
,则该层会尝试在可能的情况下使用 flash attention 以获得更快、更节省内存的注意力计算。可以使用 keras.config.enable_flash_attention()
或 keras.config.disable_flash_attention()
配置此行为。调用参数
(batch_dim, target_seq_len, feature_dim)
的查询张量,其中 batch_dim
是批大小,target_seq_len
是目标序列的长度,feature_dim
是特征的维度。(batch_dim, source_seq_len, feature_dim)
的值张量,其中 batch_dim
是批大小,source_seq_len
是源序列的长度,feature_dim
是特征的维度。(batch_dim, source_seq_len, feature_dim)
。如果未给出,将使用 value
同时作为 key
和 value
,这是最常见的情况。(batch_dim, target_seq_len, source_seq_len)
的布尔掩码,用于阻止对某些位置的注意力。布尔掩码指定哪些查询元素可以关注哪些键元素,其中 1 表示关注,0 表示不关注。广播可以发生在缺失的批次维度和头维度上。True
,输出应为 (attention_output, attention_scores)
,如果为 False
,则为 attention_output
。默认为 False
。False
(推理)。返回
(batch_dim, target_seq_len, feature_dim)
,其中 target_seq_len
是目标序列长度,feature_dim
是查询输入最后一个维度。(batch_dim, num_query_heads, target_seq_len, source_seq_len)
。