Keras 3 API 文档 / 层 API / 注意力层 / 多头注意力层

多头注意力层

[来源]

MultiHeadAttention

keras.layers.MultiHeadAttention(
    num_heads,
    key_dim,
    value_dim=None,
    dropout=0.0,
    use_bias=True,
    output_shape=None,
    attention_axes=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    seed=None,
    **kwargs
)

多头注意力层。

这是多头注意力机制的实现,如论文“注意力是你所需要的一切”Vaswani 等人,2017 年 所述。如果querykeyvalue 相同,则这是自注意力机制。query 中的每个时间步都关注key 中的对应序列,并返回一个固定宽度的向量。

此层首先投影querykeyvalue。这些实际上是长度为 num_attention_heads 的张量列表,其中对应形状为 (batch_size, <query 维度>, key_dim)(batch_size, <key/value 维度>, key_dim)(batch_size, <key/value 维度>, value_dim)

然后,将 query 和 key 张量点积并缩放。对这些进行 softmax 以获得注意力概率。然后,这些概率对 value 张量进行插值,然后连接回单个张量。

最后,结果张量(其最后维度为 value_dim)可以进行线性投影并返回。

参数

  • num_heads: 注意力头的数量。
  • key_dim: query 和 key 每个注意力头的尺寸。
  • value_dim: value 每个注意力头的尺寸。
  • dropout: 丢弃概率。
  • use_bias: 布尔值,表示密集层是否使用偏差向量/矩阵。
  • output_shape: 输出张量的预期形状,除了批次和序列维度。如果未指定,则投影回查询特征维度(查询输入的最后一个维度)。
  • attention_axes: 应用注意力的轴。None 表示对所有轴应用注意力,但批次、头和特征除外。
  • kernel_initializer: 密集层内核的初始化器。
  • bias_initializer: 密集层偏差的初始化器。
  • kernel_regularizer: 密集层内核的正则化器。
  • bias_regularizer: 密集层偏差的正则化器。
  • activity_regularizer: 密集层活动的正则化器。
  • kernel_constraint: 密集层内核的约束。
  • bias_constraint: 密集层内核的约束。
  • seed: 可选整数,用于为丢弃层设置种子。

调用参数

  • query: 形状为 (B, T, dim) 的查询张量,其中 B 是批次大小,T 是目标序列长度,而 dim 是特征维度。
  • value: 形状为 (B, S, dim) 的值张量,其中 B 是批次大小,S 是源序列长度,而 dim 是特征维度。
  • key: 形状为 (B, S, dim) 的可选键张量。如果没有给出,则将使用 value 作为 keyvalue,这是最常见的情况。
  • attention_mask: 形状为 (B, T, S) 的布尔掩码,它阻止对某些位置的注意力。布尔掩码指定哪些查询元素可以关注哪些键元素,1 表示注意力,0 表示不注意。对于缺少的批次维度和头维度,可以进行广播。
  • return_attention_scores: 一个布尔值,用于指示输出是否应为 (attention_output, attention_scores)(如果为 True),或者为 attention_output(如果为 False)。默认为 False
  • training: Python 布尔值,指示层应该以训练模式(添加丢弃)还是推理模式(不添加丢弃)运行。它将使用父层/模型的训练模式,或者如果没有父层,则使用 False(推理)。
  • use_causal_mask: 一个布尔值,指示是否应用因果掩码以防止标记关注未来标记(例如,在解码器 Transformer 中使用)。

返回值

  • attention_output: 计算结果,形状为 (B, T, E),其中 T 用于目标序列形状,而 E 是查询输入的最后一个维度(如果 output_shapeNone)。否则,多头输出将投影到 output_shape 指定的形状。
  • attention_scores: (可选)对注意力轴的多头注意力系数。