Keras 3 API 文档 / KerasHub / 分词器 / SentencePieceTokenizer

SentencePieceTokenizer

[来源]

SentencePieceTokenizer

keras_hub.tokenizers.SentencePieceTokenizer(
    proto=None,
    sequence_length=None,
    dtype="int32",
    add_bos=False,
    add_eos=False,
    **kwargs
)

一个 SentencePiece 分词层。

此层提供了 SentencePiece 分词的实现,如 SentencePiece 论文SentencePiece 包 中所述。分词将在 Tensorflow 图表中完全运行,并且可以保存到 keras.Model 中。

默认情况下,该层将输出一个 tf.RaggedTensor,其中输出的最后一个维度在空格分割和子词分词后是不规则的。如果设置了 sequence_length,则该层将输出一个密集的 tf.Tensor,其中所有输入都已填充或截断到 sequence_length。输出数据类型可以通过 dtype 参数控制,该参数应为整数或字符串类型。

参数

  • proto: SentencePiece proto 文件的 string 路径,或包含序列化 SentencePiece proto 的 bytes 对象。有关格式的更多详细信息,请参阅 SentencePiece 存储库
  • sequence_length: 如果设置,输出将转换为密集张量并填充/修剪,以便所有输出都为 sequence_length
  • add_bos: 在结果中添加句子开始标记。
  • add_eos: 在结果中添加句子结束标记。如果输出超过指定的 sequence_length,则标记将始终被截断。

参考文献

示例

从字节。

def train_sentence_piece_bytes(ds, size):
    bytes_io = io.BytesIO()
    sentencepiece.SentencePieceTrainer.train(
        sentence_iterator=ds.as_numpy_iterator(),
        model_writer=bytes_io,
        vocab_size=size,
    )
    return bytes_io.getvalue()

# Train a sentencepiece proto.
ds = tf.data.Dataset.from_tensor_slices(["the quick brown fox."])
proto = train_sentence_piece_bytes(ds, 20)
# Tokenize inputs.
tokenizer = keras_hub.tokenizers.SentencePieceTokenizer(proto=proto)
ds = ds.map(tokenizer)

从文件。

def train_sentence_piece_file(ds, path, size):
    with open(path, "wb") as model_file:
        sentencepiece.SentencePieceTrainer.train(
            sentence_iterator=ds.as_numpy_iterator(),
            model_writer=model_file,
            vocab_size=size,
        )

# Train a sentencepiece proto.
ds = tf.data.Dataset.from_tensor_slices(["the quick brown fox."])
proto = train_sentence_piece_file(ds, "model.spm", 20)
# Tokenize inputs.
tokenizer = keras_hub.tokenizers.SentencePieceTokenizer(proto="model.spm")
ds = ds.map(tokenizer)

[来源]

tokenize 方法

SentencePieceTokenizer.tokenize(inputs)

将输入字符串张量转换为输出标记。

参数

  • inputs: 输入张量,或输入张量的字典/列表/元组。
  • *args: 其他位置参数。
  • **kwargs: 其他关键字参数。

[来源]

detokenize 方法

SentencePieceTokenizer.detokenize(inputs)

将标记转换回字符串。

参数

  • inputs: 输入张量,或输入张量的字典/列表/元组。
  • *args: 其他位置参数。
  • **kwargs: 其他关键字参数。

[来源]

get_vocabulary 方法

SentencePieceTokenizer.get_vocabulary()

获取分词器词汇表。


[来源]

vocabulary_size 方法

SentencePieceTokenizer.vocabulary_size()

获取分词器词汇表的大小。


[来源]

token_to_id 方法

SentencePieceTokenizer.token_to_id(token)

将字符串标记转换为整数 ID。


[来源]

id_to_token 方法

SentencePieceTokenizer.id_to_token(id)

将整数 ID 转换为字符串标记。