KerasHub:预训练模型 / API 文档 / 建模层 / CachedMultiHeadAttention 层

CachedMultiHeadAttention 层

[源]

CachedMultiHeadAttention

keras_hub.layers.CachedMultiHeadAttention(
    num_heads,
    key_dim,
    value_dim=None,
    dropout=0.0,
    use_bias=True,
    output_shape=None,
    attention_axes=None,
    flash_attention=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    seed=None,
    **kwargs
)

支持缓存的多头注意力层。

该层适用于自回归解码。可用于缓存解码器自注意力和交叉注意力。前向传递可以有三种模式:

  • 无缓存,与普通多头注意力相同。
  • 静态缓存 (cache_update_index 为 None)。在这种情况下,将使用缓存的键/值投影,输入值将被忽略。
  • 更新缓存 (cache_update_index 不为 None)。在这种情况下,将使用输入计算新的键/值投影,并拼接(splice)到缓存中指定的位置。

请注意,缓存仅在推理期间有用,不应在训练期间使用。

我们在下面使用 BTS 的标记,其中 B 是批次维度,T 是目标序列长度,S 是源序列长度。请注意,在生成式解码期间,T 通常为 1(您正在生成长度为 1 的目标序列以预测下一个 token)。

调用参数

  • query:形状为 (B, T, dim) 的 Query Tensor
  • value:形状为 (B, S*, dim) 的 Value Tensor。如果 cache 为 None,则 S* 必须等于 S 并与 attention_mask 的形状匹配。如果 cache 不为 None,则 S* 可以是小于 S 的任何长度,计算出的 value 将被拼接(splice)到 cachecache_update_index 的位置。
  • key:可选的 key Tensor,形状为 (B, S*, dim)。如果 cacheNone,则 S* 必须等于 S 并与 attention_mask 的形状匹配。如果 cache 不为 None,则 S* 可以是小于 S 的任何长度,计算出的 value 将被拼接(splice)到 cachecache_update_index 的位置。
  • attention_mask:形状为 (B, T, S) 的布尔掩码。attention_mask 阻止对某些位置的注意力。布尔掩码指定哪些 query 元素可以关注哪些 key 元素,1 表示注意力,0 表示无注意力。对于缺失的批次维度和头维度,可以发生广播(Broadcasting)。
  • cache:一个稠密的 float Tensor。键/值缓存,形状为 [B, 2, S, num_heads, key_dims],其中 S 必须与 attention_mask 的形状一致。此参数旨在用于生成期间,以避免重新计算中间状态。
  • cache_update_index:一个 int 或 int Tensor,用于更新 cache 的索引(通常是在运行生成时正在处理的当前 token 的索引)。如果在设置了 cache 的同时 cache_update_index=None,则缓存不会更新。
  • training:一个布尔值,指示该层应在训练模式还是推理模式下运行。

返回

一个 (attention_output, cache) 元组。attention_output 是计算结果,形状为 (B, T, dim),其中 T 表示目标序列的形状,如果 output_shapeNone,则 dim 是 query 输入的最后一个维度。否则,多头输出将被投影到 output_shape 指定的形状。cache 是更新后的缓存。