StartEndPacker
类keras_nlp.layers.StartEndPacker(
sequence_length,
start_value=None,
end_value=None,
pad_value=None,
return_padding_mask=False,
name=None,
**kwargs
)
在序列中添加开始和结束标记并填充到固定长度。
此层在为翻译等任务分词输入时很有用,其中每个序列都应包含开始和结束标记。它应该在分词之后调用。该层将首先修剪输入以适合,然后添加开始/结束标记,最后根据需要填充到 sequence_length
。
输入数据应作为张量、tf.RaggedTensor
或列表传递。对于批量输入,输入应为列表的列表或等级为二的张量。对于非批量输入,每个元素应为列表或等级为一的张量。
参数
None
,则不会添加开始值。None
,则不会添加结束值。None
,则会根据输入张量的数据类型添加 0 或 ""。pad_value
的所有位置的布尔型填充掩码。调用参数
tf.Tensor
、tf.RaggedTensor
或 Python 字符串列表。sequence_length
。False
以不对此输入追加开始值。False
以不对此输入追加结束值。示例
非批量输入 (int)。
>>> inputs = [5, 6, 7]
>>> start_end_packer = keras_nlp.layers.StartEndPacker(
... sequence_length=7, start_value=1, end_value=2,
... )
>>> outputs = start_end_packer(inputs)
>>> np.array(outputs)
array([1, 5, 6, 7, 2, 0, 0], dtype=int32)
批量输入 (int)。
>>> inputs = [[5, 6, 7], [8, 9, 10, 11, 12, 13, 14]]
>>> start_end_packer = keras_nlp.layers.StartEndPacker(
... sequence_length=6, start_value=1, end_value=2,
... )
>>> outputs = start_end_packer(inputs)
>>> np.array(outputs)
array([[ 1, 5, 6, 7, 2, 0],
[ 1, 8, 9, 10, 11, 2]], dtype=int32)
非批量输入 (str)。
>>> inputs = tf.constant(["this", "is", "fun"])
>>> start_end_packer = keras_nlp.layers.StartEndPacker(
... sequence_length=6, start_value="<s>", end_value="</s>",
... pad_value="<pad>"
... )
>>> outputs = start_end_packer(inputs)
>>> np.array(outputs).astype("U")
array(['<s>', 'this', 'is', 'fun', '</s>', '<pad>'], dtype='<U5')
批量输入 (str)。
>>> inputs = tf.ragged.constant([["this", "is", "fun"], ["awesome"]])
>>> start_end_packer = keras_nlp.layers.StartEndPacker(
... sequence_length=6, start_value="<s>", end_value="</s>",
... pad_value="<pad>"
... )
>>> outputs = start_end_packer(inputs)
>>> np.array(outputs).astype("U")
array([['<s>', 'this', 'is', 'fun', '</s>', '<pad>'],
['<s>', 'awesome', '</s>', '<pad>', '<pad>', '<pad>']], dtype='<U7')
多个开始标记。
>>> inputs = tf.ragged.constant([["this", "is", "fun"], ["awesome"]])
>>> start_end_packer = keras_nlp.layers.StartEndPacker(
... sequence_length=6, start_value=["</s>", "<s>"], end_value="</s>",
... pad_value="<pad>"
... )
>>> outputs = start_end_packer(inputs)
>>> np.array(outputs).astype("U")
array([['</s>', '<s>', 'this', 'is', 'fun', '</s>'],
['</s>', '<s>', 'awesome', '</s>', '<pad>', '<pad>']], dtype='<U7')