Attention
类tf_keras.layers.Attention(use_scale=False, score_mode="dot", **kwargs)
点积注意力层,也称为 Luong 风格注意力。
输入是形状为 [batch_size, Tq, dim]
的 query
张量、形状为 [batch_size, Tv, dim]
的 value
张量和形状为 [batch_size, Tv, dim]
的 key
张量。计算步骤如下:
query
与 key
进行点积,计算形状为 [batch_size, Tq, Tv]
的分数:scores = tf.matmul(query, key, transpose_b=True)
。[batch_size, Tq, Tv]
的分布:distribution = tf.nn.softmax(scores)
。distribution
对 value
进行线性组合,生成形状为 [batch_size, Tq, dim]
的输出:return tf.matmul(distribution, value)
。参数
True
,将创建一个标量变量来缩放注意力分数。{"dot", "concat"}
之一。"dot"
指查询向量和键向量之间的点积。"concat"
指查询向量和键向量的连接的双曲正切。调用参数
[batch_size, Tq, dim]
的查询 Tensor
。[batch_size, Tv, dim]
的值 Tensor
。Tensor
,形状为 [batch_size, Tv, dim]
。如果未提供,则 key
和 value
都将使用 value
,这是最常见的情况。[batch_size, Tq]
的布尔掩码 Tensor
。如果提供,则在 mask==False
的位置输出将为零。[batch_size, Tv]
的布尔掩码 Tensor
。如果提供,将应用此掩码,以便在 mask==False
的位置上的值不对结果产生贡献。True
,则将注意力分数(在应用掩码和 softmax 后)作为额外的输出参数返回。True
。添加一个掩码,使得位置 i
不能注意到位置 j > i
。这阻止了信息从未来流向过去。默认为 False
。输出
Attention outputs of shape `[batch_size, Tq, dim]`.
[Optional] Attention scores after masking and softmax with shape
`[batch_size, Tq, Tv]`.
query
、value
和 key
的含义取决于应用。例如,在文本相似性任务中,query
是第一段文本的序列嵌入,而 value
是第二段文本的序列嵌入。key
通常与 value
是相同的张量。
下面是在 CNN+Attention 网络中使用 Attention
的代码示例
# Variable-length int sequences.
query_input = tf.keras.Input(shape=(None,), dtype='int32')
value_input = tf.keras.Input(shape=(None,), dtype='int32')
# Embedding lookup.
token_embedding = tf.keras.layers.Embedding(input_dim=1000, output_dim=64)
# Query embeddings of shape [batch_size, Tq, dimension].
query_embeddings = token_embedding(query_input)
# Value embeddings of shape [batch_size, Tv, dimension].
value_embeddings = token_embedding(value_input)
# CNN layer.
cnn_layer = tf.keras.layers.Conv1D(
filters=100,
kernel_size=4,
# Use 'same' padding so outputs have the same shape as inputs.
padding='same')
# Query encoding of shape [batch_size, Tq, filters].
query_seq_encoding = cnn_layer(query_embeddings)
# Value encoding of shape [batch_size, Tv, filters].
value_seq_encoding = cnn_layer(value_embeddings)
# Query-value attention of shape [batch_size, Tq, filters].
query_value_attention_seq = tf.keras.layers.Attention()(
[query_seq_encoding, value_seq_encoding])
# Reduce over the sequence axis to produce encodings of shape
# [batch_size, filters].
query_encoding = tf.keras.layers.GlobalAveragePooling1D()(
query_seq_encoding)
query_value_attention = tf.keras.layers.GlobalAveragePooling1D()(
query_value_attention_seq)
# Concatenate query and document encodings to produce a DNN input layer.
input_layer = tf.keras.layers.Concatenate()(
[query_encoding, query_value_attention])
# Add DNN layers, and create Model.
# ...