CachedMultiHeadAttention
类keras_hub.layers.CachedMultiHeadAttention(
num_heads,
key_dim,
value_dim=None,
dropout=0.0,
use_bias=True,
output_shape=None,
attention_axes=None,
flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs
)
支持缓存的多头注意力层。
该层适用于自回归解码。可用于缓存解码器自注意力和交叉注意力。前向传递可以有三种模式:
cache_update_index
为 None)。在这种情况下,将使用缓存的键/值投影,输入值将被忽略。cache_update_index
不为 None)。在这种情况下,将使用输入计算新的键/值投影,并拼接(splice)到缓存中指定的位置。请注意,缓存仅在推理期间有用,不应在训练期间使用。
我们在下面使用 B
、T
、S
的标记,其中 B
是批次维度,T
是目标序列长度,S
是源序列长度。请注意,在生成式解码期间,T
通常为 1(您正在生成长度为 1 的目标序列以预测下一个 token)。
调用参数
(B, T, dim)
的 Query Tensor
。(B, S*, dim)
的 Value Tensor
。如果 cache
为 None,则 S*
必须等于 S
并与 attention_mask
的形状匹配。如果 cache
不为 None
,则 S*
可以是小于 S
的任何长度,计算出的 value 将被拼接(splice)到 cache
中 cache_update_index
的位置。Tensor
,形状为 (B, S*, dim)
。如果 cache
为 None
,则 S*
必须等于 S
并与 attention_mask
的形状匹配。如果 cache
不为 None
,则 S*
可以是小于 S
的任何长度,计算出的 value 将被拼接(splice)到 cache
中 cache_update_index
的位置。(B, T, S)
的布尔掩码。attention_mask
阻止对某些位置的注意力。布尔掩码指定哪些 query 元素可以关注哪些 key 元素,1 表示注意力,0 表示无注意力。对于缺失的批次维度和头维度,可以发生广播(Broadcasting)。[B, 2, S, num_heads, key_dims]
,其中 S
必须与 attention_mask
的形状一致。此参数旨在用于生成期间,以避免重新计算中间状态。cache
的索引(通常是在运行生成时正在处理的当前 token 的索引)。如果在设置了 cache
的同时 cache_update_index=None
,则缓存不会更新。返回
一个 (attention_output, cache)
元组。attention_output
是计算结果,形状为 (B, T, dim)
,其中 T
表示目标序列的形状,如果 output_shape
为 None
,则 dim
是 query 输入的最后一个维度。否则,多头输出将被投影到 output_shape
指定的形状。cache
是更新后的缓存。