MultiSegmentPacker
类keras_hub.layers.MultiSegmentPacker(
sequence_length,
start_value,
end_value,
sep_value=None,
pad_value=None,
truncate="round_robin",
**kwargs
)
将多个序列打包成单个固定宽度的模型输入。
此层将多个输入序列打包成一个包含开始和结束分隔符的单个固定宽度序列,形成适合 BERT 和类似 BERT 模型的分类任务的密集输入。
接收一个 token 片段元组作为输入。每个元组元素都应该包含一个片段的 tokens,作为张量、tf.RaggedTensor
或列表传递。对于批量输入,片段元组中的每个元素都应该是一个列表的列表或一个秩为 2 的张量。对于非批量输入,每个元素都应该是一个列表或秩为 1 的张量。
该层将按如下方式处理输入: - 根据 truncate
策略截断所有输入片段以适合 sequence_length
。 - 连接所有输入片段,在整个序列的开头添加单个 start_value
,并在每个片段的末尾添加多个 end_value
。 - 使用 pad_tokens
将生成的序列填充到 sequence_length
。 - 计算一个单独的“片段 ID”张量,其整数类型和形状与打包的 token 输出相同,其中每个整数索引表示 token 所来的片段。start_value
的片段 ID 始终为 0,每个 end_value
的片段 ID 为其前面的片段。
参数
None
,则使用 end_value
。数据类型必须与输入到层的张量的数据类型匹配。"round_robin"
或 "waterfall"
"round_robin"
:可用的空间每次分配一个 token,以循环方式分配给仍然需要一些空间的输入,直到达到限制。"waterfall"
:预算的分配使用“瀑布”算法完成,该算法以从左到右的方式分配配额,并填充存储桶,直到预算用完。它支持任意数量的片段。返回值
一个包含两个元素的元组。第一个是密集的、打包的 token 序列。第二个是相同形状的整数张量,包含片段 ID。
示例
打包单个输入以进行分类。
>>> seq1 = [1, 2, 3, 4]
>>> packer = keras_hub.layers.MultiSegmentPacker(
... sequence_length=8, start_value=101, end_value=102
... )
>>> token_ids, segment_ids = packer((seq1,))
>>> np.array(token_ids)
array([101, 1, 2, 3, 4, 102, 0, 0], dtype=int32)
>>> np.array(segment_ids)
array([0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)
打包多个输入以进行分类。
>>> seq1 = [1, 2, 3, 4]
>>> seq2 = [11, 12, 13, 14]
>>> packer = keras_hub.layers.MultiSegmentPacker(
... sequence_length=8, start_value=101, end_value=102
... )
>>> token_ids, segment_ids = packer((seq1, seq2))
>>> np.array(token_ids)
array([101, 1, 2, 3, 102, 11, 12, 102], dtype=int32)
>>> np.array(segment_ids)
array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32)
打包多个输入以进行分类,并使用不同的分隔符 token。
>>> seq1 = [1, 2, 3, 4]
>>> seq2 = [11, 12, 13, 14]
>>> packer = keras_hub.layers.MultiSegmentPacker(
... sequence_length=8,
... start_value=101,
... end_value=102,
... sep_value=[102, 102],
... )
>>> token_ids, segment_ids = packer((seq1, seq2))
>>> np.array(token_ids)
array([101, 1, 2, 102, 102, 11, 12, 102], dtype=int32)
>>> np.array(segment_ids)
array([0, 0, 0, 0, 0, 1, 1, 1], dtype=int32)
参考文献