CachedMultiHeadAttention 类keras_hub.layers.CachedMultiHeadAttention(
num_heads,
key_dim,
value_dim=None,
dropout=0.0,
use_bias=True,
output_shape=None,
attention_axes=None,
flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs
)
支持缓存的多头注意力层。
此层适合在自回归解码中使用。它可以用于缓存解码器的自注意力(self-attention)和交叉注意力(cross-attention)。前向传播可以有以下三种模式之一:
cache_update_index 为 None)。在这种情况下,将使用缓存的键/值投影,并忽略输入值。cache_update_index 不为 None)。在这种情况下,将使用输入计算新的键/值投影,并将其插入到指定索引的缓存中。请注意,缓存仅在推理时有用,不应在训练时使用。
下面我们使用 B、T、S 的符号,其中 B 是批次维度,T 是目标序列长度,S 是源序列长度。请注意,在生成式解码时,T 通常为 1(你正在生成长度为 1 的目标序列来预测下一个 token)。
调用参数
Tensor,形状为 (B, T, dim)。Tensor,形状为 (B, S*, dim)。如果 cache 为 None,则 S* 必须等于 S 并与 attention_mask 的形状匹配。如果 cache 不为 None,则 S* 可以是小于 S 的任何长度,并且计算出的值将插入到 cache 的 cache_update_index 处。Tensor,形状为 (B, S*, dim)。如果 cache 为 None,则 S* 必须等于 S 并与 attention_mask 的形状匹配。如果 cache 不为 None,则 S* 可以是小于 S 的任何长度,并且计算出的值将插入到 cache 的 cache_update_index 处。(B, T, S)。attention_mask 可防止对某些位置进行注意力。布尔掩码指定了哪些查询元素可以关注哪些键元素,1 表示注意力,0 表示无注意力。对于缺失的批次维度和头维度,可以发生广播。Tensor。键/值缓存,形状为 [B, 2, S, num_heads, key_dims],其中 S 必须与 attention_mask 的形状一致。此参数旨在生成过程中使用,以避免重新计算中间状态。Tensor,用于更新 cache 的索引(通常是在运行生成时正在处理的当前 token 的索引)。如果 cache_update_index=None 且 cache 已设置,则缓存不会被更新。返回
一个 (attention_output, cache) 元组。attention_output 是计算结果,形状为 (B, T, dim),其中 T 用于目标序列形状,dim 是查询输入的最后一个维度(如果 output_shape 为 None)。否则,多头输出将被投影到 output_shape 指定的形状。cache 是更新后的缓存。