Keras 3 API 文档 / KerasNLP / 预处理层 / 随机交换层

随机交换层

[来源]

RandomSwap

keras_nlp.layers.RandomSwap(
    rate,
    max_swaps=None,
    skip_list=None,
    skip_fn=None,
    skip_py_fn=None,
    seed=None,
    name=None,
    dtype="int32",
    **kwargs
)

通过随机交换词语来增强输入。

当您需要使用交换增强来生成新数据时,此层非常有用,如论文 [EDA: Easy Data Augmentation Techniques for Boosting Performance on Text Classification Tasks] (https://arxiv.org/pdf/1901.11196.pdf) 中所述。该层期望输入预先拆分为令牌级别输入。这允许控制增强级别,您可以按字符拆分进行字符级交换,或按单词进行词级交换。

输入数据应作为张量、tf.RaggedTensor 或列表传递。对于批量输入,输入应为列表的列表或秩二张量。对于非批量输入,每个元素应为列表或秩一张量。

参数

  • rate: 给定令牌被选为与另一个随机令牌交换的概率。
  • max_swaps: 要执行的最大交换次数。
  • skip_list: 不应被视为删除候选者的令牌值列表。
  • skip_fn: 一个函数,它接收一个标量张量令牌作为输入,并返回一个标量张量 True/False 值作为输出。True 值表示该令牌不应被视为删除候选者。此函数必须是可追踪的 - 它应该由 tensorflow 操作组成。
  • skip_py_fn: 一个函数,它接收一个 Python 令牌值作为输入,并返回 TrueFalse 作为输出。True 值表示不应被视为删除候选者。与 skip_fn 参数不同,此参数不必是可追踪的 - 它可以是任何 Python 函数。
  • seed: 随机数生成器的种子。

示例

词级用法。

>>> keras.utils.set_random_seed(1337)
>>> x = ["Hey I like", "Keras and Tensorflow"]
>>> x = list(map(lambda x: x.split(), x))
>>> augmenter = keras_nlp.layers.RandomSwap(rate=0.4, seed=42)
>>> y = augmenter(x)
>>> list(map(lambda y: " ".join(y), y))
['like I Hey', 'and Keras Tensorflow']

字符级用法。

>>> keras.utils.set_random_seed(1337)
>>> x = ["Hey Dude", "Speed Up"]
>>> x = list(map(lambda x: list(x), x))
>>> augmenter = keras_nlp.layers.RandomSwap(rate=0.4, seed=42)
>>> y = augmenter(x)
>>> list(map(lambda y: "".join(y), y))
['deD yuHe', 'SUede pp']

使用 skip_list。

>>> keras.utils.set_random_seed(1337)
>>> x = ["Hey I like", "Keras and Tensorflow"]
>>> x = list(map(lambda x: x.split(), x))
>>> augmenter = keras_nlp.layers.RandomSwap(rate=0.4,
...     skip_list=["Keras"], seed=42)
>>> y = augmenter(x)
>>> list(map(lambda y: " ".join(y), y))
['like I Hey', 'Keras and Tensorflow']

使用 skip_fn。

>>> def skip_fn(word):
...     return tf.strings.regex_full_match(word, r"[I, a].*")
>>> keras.utils.set_random_seed(1337)
>>> x = ["Hey I like", "Keras and Tensorflow"]
>>> x = list(map(lambda x: x.split(), x))
>>> augmenter = keras_nlp.layers.RandomSwap(rate=0.9, max_swaps=3,
...     skip_fn=skip_fn, seed=11)
>>> y = augmenter(x)
>>> list(map(lambda y: " ".join(y), y))
['like I Hey', 'Keras and Tensorflow']

使用 skip_py_fn。

>>> def skip_py_fn(word):
...     return len(word) < 4
>>> keras.utils.set_random_seed(1337)
>>> x = ["He was drifting along", "With the wind"]
>>> x = list(map(lambda x: x.split(), x))
>>> augmenter = keras_nlp.layers.RandomSwap(rate=0.8, max_swaps=2,
...     skip_py_fn=skip_py_fn, seed=15)
>>> y = augmenter(x)
>>> list(map(lambda y: " ".join(y), y))
['He was along drifting', 'wind the With']