Keras 2 API 文档 / 模型 API / 模型训练 API

模型训练 API

[源代码]

compile 方法

Model.compile(
    optimizer="rmsprop",
    loss=None,
    metrics=None,
    loss_weights=None,
    weighted_metrics=None,
    run_eagerly=None,
    steps_per_execution=None,
    jit_compile=None,
    pss_evaluation_shards=0,
    **kwargs
)

配置模型以进行训练。

示例

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=[tf.keras.metrics.BinaryAccuracy(),
                       tf.keras.metrics.FalseNegatives()])

参数

  • optimizer: String(优化器名称)或优化器实例。请参阅 tf.keras.optimizers
  • loss: 损失函数。可以是字符串(损失函数名称)或 tf.keras.losses.Loss 实例。请参阅 tf.keras.losses。损失函数是任何具有签名 loss = fn(y_true, y_pred) 的可调用对象,其中 y_true 是真实值,y_pred 是模型的预测值。y_true 的形状应为 (batch_size, d0, .. dN)(稀疏损失函数如 sparse categorical crossentropy 除外,它期望的整数数组形状为 (batch_size, d0, .. dN-1))。y_pred 的形状应为 (batch_size, d0, .. dN)。损失函数应返回一个 float 张量。如果使用了自定义的 Loss 实例且 reduction 设置为 None,则返回值形状为 (batch_size, d0, .. dN-1),即每个样本或每个时间步的损失值;否则,它是一个标量。如果模型有多个输出,您可以通过传递字典或损失列表来为每个输出使用不同的损失。模型将要最小化的损失值将是所有单个损失的总和,除非指定了 loss_weights
  • metrics: 在训练和测试期间由模型评估的指标列表。其中每个都可以是字符串(内置函数名称)、函数或 tf.keras.metrics.Metric 实例。请参阅 tf.keras.metrics。通常你会使用 metrics=['accuracy']。函数是任何具有签名 result = fn(y_true, y_pred) 的可调用对象。要为多输出模型的不同输出指定不同的指标,您也可以传递一个字典,例如 metrics={'output_a':'accuracy', 'output_b':['accuracy', 'mse']}。您还可以传递列表来为每个输出指定一个指标或一系列指标,例如 metrics=[['accuracy'], ['accuracy', 'mse']]metrics=['accuracy', ['accuracy', 'mse']]。当你传递字符串 'accuracy' 或 'acc' 时,我们会根据目标和模型输出的形状,将其转换为 tf.keras.metrics.BinaryAccuracytf.keras.metrics.CategoricalAccuracytf.keras.metrics.SparseCategoricalAccuracy 之一。我们对字符串 'crossentropy' 和 'ce' 也进行类似的转换。这里传递的指标的计算不考虑样本权重;如果你希望考虑样本权重,你可以改用 weighted_metrics 参数来指定你的指标。
  • loss_weights: 可选的列表或字典,用于指定标量系数(Python 浮点数)以加权不同模型输出的损失贡献。模型将要最小化的损失值将是所有单个损失的加权和,权重为 loss_weights 系数。如果为列表,则期望与模型的输出有一一对应关系。如果为字典,则期望将输出名称(字符串)映射到标量系数。
  • weighted_metrics: 在训练和测试期间,由 sample_weightclass_weight 加权计算的指标列表。
  • run_eagerly: Bool。如果为 True,则此 Model 的逻辑不会被包装在 tf.function 中。建议将其保留为 None,除非您的 Model 无法在 tf.function 中运行。在使用 tf.distribute.experimental.ParameterServerStrategy 时,不支持 run_eagerly=True。默认为 False
  • steps_per_execution: Int 或 'auto'。在每次 tf.function 调用期间运行的批次数。如果设置为 "auto",Keras 将在运行时自动调整 steps_per_execution。在单个 tf.function 调用中运行多个批次可以极大地提高 TPU 上的性能,当与分布式策略(如 ParameterServerStrategy)一起使用时,或者与具有大量 Python 开销的小模型一起使用时。最多,每次执行将运行一个完整的 epoch。如果传入的数字大于 epoch 的大小,则执行将被截断为 epoch 的大小。请注意,如果 steps_per_execution 设置为 NCallback.on_batch_beginCallback.on_batch_end 方法将每 N 个批次调用一次(即在每次 tf.function 执行之前/之后)。默认为 1
  • jit_compile: 如果为 True,则使用 XLA 编译模型训练步骤。XLA 是一个用于机器学习的优化编译器。jit_compile 默认不启用。请注意,jit_compile=True 不一定适用于所有模型。有关支持的操作,请参阅 XLA 文档。另请参阅 XLA 已知问题 以获取更多详细信息。
  • pss_evaluation_shards: Integer 或 'auto'。仅用于 tf.distribute.ParameterServerStrategy 训练。此参数设置要拆分数据集的分片数量,以实现评估的精确访问保证,这意味着模型将对每个数据集元素精确访问一次,即使工作节点失败。数据集必须分片以确保不同的工作节点不处理相同的数据。分片数量应至少是工作节点数量,以获得良好的性能。值为 'auto' 将启用精确评估,并根据工作节点数量使用启发式方法确定分片数量。0 表示不提供访问保证。注意:自定义的 Model.test_step 实现将在进行精确评估时被忽略。默认为 0
  • **kwargs: 仅为向后兼容而支持的参数。

[源代码]

fit 方法

Model.fit(
    x=None,
    y=None,
    batch_size=None,
    epochs=1,
    verbose="auto",
    callbacks=None,
    validation_split=0.0,
    validation_data=None,
    shuffle=True,
    class_weight=None,
    sample_weight=None,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None,
    validation_batch_size=None,
    validation_freq=1,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False,
)

将模型训练固定数量的 epoch(数据集迭代)。

参数

  • x: 输入数据。它可以是
    • 一个 Numpy 数组(或类数组),或一组数组(如果模型有多个输入)。
    • 一个 TensorFlow 张量,或一组张量(如果模型有多个输入)。
    • 一个映射输入名称到相应数组/张量的字典,如果模型有命名输入。
    • 一个 tf.data 数据集。应返回一个元组,其中包含 (inputs, targets)(inputs, targets, sample_weights)
    • 一个生成器或 keras.utils.Sequence,返回 (inputs, targets)(inputs, targets, sample_weights)
    • 一个 tf.keras.utils.experimental.DatasetCreator,它包装了一个可调用对象,该对象接受一个类型为 tf.distribute.InputContext 的单个参数,并返回一个 tf.data.Dataset。当用户希望为 Dataset 指定每个副本的批处理和分片逻辑时,应使用 DatasetCreator。有关更多信息,请参阅 tf.keras.utils.experimental.DatasetCreator 文档。有关迭代器类型(Dataset、generator、Sequence)的解包行为的更详细描述,请参阅下文。如果这些包含 sample_weights 作为第三个组件,请注意,样本权重适用于 compile() 中的 weighted_metrics 参数,但不适用于 metrics 参数。如果使用 tf.distribute.experimental.ParameterServerStrategy,则 x 只支持 DatasetCreator 类型。
  • y: 目标数据。与输入数据 x 类似,它可以是 Numpy 数组或 TensorFlow 张量。它应该与 x 一致(您不能有 Numpy 输入和张量目标,反之亦然)。如果 x 是一个数据集、生成器或 keras.utils.Sequence 实例,则不应指定 y(因为目标将从 x 中获取)。
  • batch_size: Integer 或 None。每次梯度更新的样本数。如果未指定,batch_size 将默认为 32。如果数据是数据集、生成器或 keras.utils.Sequence 实例的形式,请不要指定 batch_size(因为它们会生成批次)。
  • epochs: Integer。训练模型的 epoch 数。一个 epoch 是对提供的整个 xy 数据的一次迭代(除非 steps_per_epoch 标志设置为 None 以外的值)。请注意,与 initial_epoch 结合使用时,epochs 应理解为“最终 epoch”。模型不训练 epochs 指定的次数,而只是训练直到达到索引为 epochs 的 epoch。
  • verbose: 'auto', 0, 1, 或 2。详细程度模式。0 = 静默, 1 = 进度条, 2 = 每 epoch 一行。'auto' 在大多数情况下变为 1,但在与 ParameterServerStrategy 一起使用时变为 2。请注意,当输出到文件时,进度条不太有用,因此在非交互式运行时(例如,在生产环境中)建议使用 verbose=2。默认为 'auto'。
  • callbacks: keras.callbacks.Callback 实例列表。在训练期间应用的 callback 列表。请参阅 tf.keras.callbacks。请注意,tf.keras.callbacks.ProgbarLoggertf.keras.callbacks.History callback 会自动创建,无需传递给 model.fittf.keras.callbacks.ProgbarLogger 是否创建取决于 model.fitverbose 参数。目前不支持带批次调用的 callback 与 tf.distribute.experimental.ParameterServerStrategy,并建议用户改用带有适当 steps_per_epoch 值的 epoch 级调用。
  • validation_split: 0 到 1 之间的浮点数。用作验证数据的训练数据分数。模型将从训练数据中分离出此分数的数据,不在此数据上进行训练,并在每个 epoch 结束时在此数据上评估损失和任何模型指标。验证数据是从提供的 xy 数据中的最后一个样本中选择的,然后再进行混洗。当 x 是数据集、生成器或 keras.utils.Sequence 实例时,不支持此参数。如果同时提供了 validation_datavalidation_split,则 validation_data 将覆盖 validation_splitvalidation_split 尚未与 tf.distribute.experimental.ParameterServerStrategy 兼容。
  • validation_data: 在每个 epoch 结束时用于评估损失和任何模型指标的数据。模型不会在此数据上进行训练。因此,请注意,使用 validation_splitvalidation_data 提供的数据的验证损失不受噪声和 dropout 等正则化层的影响。validation_data 将覆盖 validation_splitvalidation_data 可以是:- 一个包含 Numpy 数组或张量的元组 (x_val, y_val)。- 一个包含 NumPy 数组的元组 (x_val, y_val, val_sample_weights)。- 一个 tf.data.Dataset。- 一个返回 (inputs, targets)(inputs, targets, sample_weights) 的 Python 生成器或 keras.utils.Sequencevalidation_data 尚未与 tf.distribute.experimental.ParameterServerStrategy 兼容。
  • shuffle: 布尔值(是否在每个 epoch 前混洗训练数据)或字符串('batch')。当 x 是生成器或 tf.data.Dataset 对象时,将忽略此参数。'batch' 是用于处理 HDF5 数据限制的特殊选项;它以批次大小的块进行混洗。当 steps_per_epoch 不为 None 时,没有效果。
  • class_weight: 可选的字典,将类索引(整数)映射到权重(浮点数)值,用于在训练期间加权损失函数。这有助于让模型“更加关注”来自代表不足类的样本。当指定 class_weight 且目标秩为 2 或更高时,y 必须是 one-hot 编码的,或者必须包含一个显式的最终维度 1 来表示稀疏类标签。
  • sample_weight: 可选的 Numpy 数组,包含训练样本的权重,用于在训练期间加权损失函数。您可以传递一个与输入样本长度相同的扁平(1D)Numpy 数组(权重与样本的 1:1 映射),或者在时间数据的情况下,您可以传递一个形状为 (samples, sequence_length) 的 2D 数组,以对每个样本的每个时间步应用不同的权重。当 x 是数据集、生成器或 keras.utils.Sequence 实例时,不支持此参数,而是将 sample_weights 作为 x 的第三个元素提供。请注意,样本权重不适用于 compile()metrics 参数指定的指标。要将样本权重应用于您的指标,您可以改用 compile() 中的 weighted_metrics 指定它们。
  • initial_epoch: Integer。开始训练的 epoch(对于恢复之前的训练运行很有用)。
  • steps_per_epoch: Integer 或 None。在宣布一个 epoch 完成并开始下一个 epoch 之前的总步数(样本批次)。当使用 TensorFlow 数据张量等输入张量进行训练时,默认的 None 等于您的数据集样本数除以批次大小,如果无法确定则为 1。如果 x 是 tf.data 数据集,并且 'steps_per_epoch' 为 None,则 epoch 将运行直到输入数据集耗尽。当传入一个无限重复的数据集时,您必须指定 steps_per_epoch 参数。如果 steps_per_epoch=-1,训练将无限期地运行,并使用无限重复的数据集。此参数不支持数组输入。使用 tf.distribute.experimental.ParameterServerStrategy 时:* 不支持 steps_per_epoch=None
  • validation_steps: 仅当提供了 validation_data 并且它是一个 tf.data 数据集时才相关。在每个 epoch 结束时进行验证时,要从中提取的总步数(样本批次),直到停止。如果 'validation_steps' 为 None,验证将运行直到 validation_data 数据集耗尽。在无限重复的数据集的情况下,它会陷入无限循环。如果指定了 'validation_steps' 并且只消耗了数据集的一部分,那么每次验证都将从数据集的开头开始。这确保了每次都使用相同的验证样本。
  • validation_batch_size: Integer 或 None。每个验证批次的样本数。如果未指定,则默认为 batch_size。如果数据是数据集、生成器或 keras.utils.Sequence 实例的形式,请不要指定 validation_batch_size(因为它们会生成批次)。
  • validation_freq: 仅当提供了验证数据时才相关。Integer 或 collections.abc.Container 实例(例如,列表、元组等)。如果为整数,则指定在执行新的验证运行之前要运行的训练 epoch 数,例如 validation_freq=2 每 2 个 epoch 运行一次验证。如果为容器,则指定运行验证的 epoch,例如 validation_freq=[1, 2, 10] 在第 1、2 和 10 个 epoch 结束时运行验证。
  • max_queue_size: Integer。仅用于生成器或 keras.utils.Sequence 输入。生成器队列的最大大小。如果未指定,max_queue_size 将默认为 10。
  • workers: Integer。仅用于生成器或 keras.utils.Sequence 输入。使用基于进程的线程时要启动的最大进程数。如果未指定,workers 将默认为 1。
  • use_multiprocessing: Boolean。仅用于生成器或 keras.utils.Sequence 输入。如果为 True,则使用基于进程的线程。如果未指定,use_multiprocessing 将默认为 False。请注意,由于此实现依赖于多进程,因此不应将不可腌制(non-pickleable)的参数传递给生成器,因为它们不容易传递给子进程。

迭代器类输入的解包行为:一个常见的模式是将 tf.data.Dataset、生成器或 tf.keras.utils.Sequence 传递给 fit 的 x 参数,该参数实际上会生成特征(x)以及可选的目标(y)和样本权重。TF-Keras 要求此类迭代器输出是明确的。迭代器应返回一个长度为 1、2 或 3 的元组,其中可选的第二个和第三个元素将分别用于 y 和 sample_weight。任何其他类型都将被包装在一个长度为 1 的元组中,有效地将所有内容视为 'x'。当生成字典时,它们仍应遵循顶层元组结构。例如 ({"x0": x0, "x1": x1}, y)。TF-Keras 不会尝试从单个字典的键中分离特征、目标和权重。一个值得注意的不支持的数据类型是命名元组(namedtuple)。原因是它同时表现得像一个有序数据类型(元组)和一个映射数据类型(字典)。因此,给定一个形式为:namedtuple("example_tuple", ["y", "x"]) 的命名元组,无法确定是将元素反序解释为 y、x 还是其他。更糟糕的是一个形式为:namedtuple("other_tuple", ["x", "y", "z"]) 的元组,它不清楚元组是打算解包为 x、y 和 sample_weight,还是作为单个元素传递给 x。结果是数据处理代码将引发 ValueError 如果遇到命名元组。(并附带纠正此问题的说明。)

返回

一个 History 对象。其 History.history 属性是一个记录,包含 successive epochs 的训练损失值和指标值,以及验证损失值和验证指标值(如果适用)。

引发

  • RuntimeError: 1. 如果模型从未编译过,或者 2. 如果 model.fit 被包装在 tf.function 中。
  • __ ValueError__: 如果提供的输入数据与模型期望的数据不匹配,或者输入数据为空。

[源代码]

evaluate 方法

Model.evaluate(
    x=None,
    y=None,
    batch_size=None,
    verbose="auto",
    sample_weight=None,
    steps=None,
    callbacks=None,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False,
    return_dict=False,
    **kwargs
)

在测试模式下返回模型的损失值和指标值。

计算以批次进行(请参阅 batch_size 参数)。

参数

  • x: 输入数据。它可以是
    • 一个 Numpy 数组(或类数组),或一组数组(如果模型有多个输入)。
    • 一个 TensorFlow 张量,或一组张量(如果模型有多个输入)。
    • 一个映射输入名称到相应数组/张量的字典,如果模型有命名输入。
    • 一个 tf.data 数据集。应返回一个元组,其中包含 (inputs, targets)(inputs, targets, sample_weights)
    • 一个生成器或 keras.utils.Sequence,返回 (inputs, targets)(inputs, targets, sample_weights)。有关迭代器类型(Dataset、generator、Sequence)的解包行为的更详细描述,请参阅 Model.fitUnpacking behavior for iterator-like inputs 部分。
  • y: 目标数据。与输入数据 x 类似,它可以是 Numpy 数组或 TensorFlow 张量。它应该与 x 一致(您不能有 Numpy 输入和张量目标,反之亦然)。如果 x 是一个数据集、生成器或 keras.utils.Sequence 实例,则不应指定 y(因为目标将从迭代器/数据集中获取)。
  • batch_size: Integer 或 None。计算批次的样本数。如果未指定,batch_size 将默认为 32。如果数据是数据集、生成器或 keras.utils.Sequence 实例的形式,请不要指定 batch_size(因为它们会生成批次)。
  • verbose: "auto", 0, 1, 或 2。详细程度模式。0 = 静默, 1 = 进度条, 2 = 单行。"auto" 在大多数情况下变为 1,而在与 ParameterServerStrategy 一起使用时变为 2。请注意,当输出到文件时,进度条不太有用,因此在非交互式运行时(例如,在生产环境中)建议 verbose=2。默认为 'auto'。
  • sample_weight: 可选的 Numpy 数组,包含测试样本的权重,用于加权损失函数。您可以传递一个与输入样本长度相同的扁平(1D)Numpy 数组(权重与样本的 1:1 映射),或者在时间数据的情况下,您可以传递一个形状为 (samples, sequence_length) 的 2D 数组,以对每个样本的每个时间步应用不同的权重。当 x 是数据集时,不支持此参数,而是将样本权重作为 x 的第三个元素提供。
  • steps: Integer 或 None。在宣布评估轮次完成之前的总步数(样本批次)。默认值 None 时忽略。如果 x 是 tf.data 数据集且 steps 为 None,则 'evaluate' 将运行直到数据集耗尽。此参数不支持数组输入。
  • callbacks: keras.callbacks.Callback 实例列表。在评估期间应用的 callback 列表。请参阅 callbacks
  • max_queue_size: Integer。仅用于生成器或 keras.utils.Sequence 输入。生成器队列的最大大小。如果未指定,max_queue_size 将默认为 10。
  • workers: Integer。仅用于生成器或 keras.utils.Sequence 输入。使用基于进程的线程时要启动的最大进程数。如果未指定,workers 将默认为 1。
  • use_multiprocessing: Boolean。仅用于生成器或 keras.utils.Sequence 输入。如果为 True,则使用基于进程的线程。如果未指定,use_multiprocessing 将默认为 False。请注意,由于此实现依赖于多进程,因此不应将不可腌制(non-pickleable)的参数传递给生成器,因为它们不容易传递给子进程。
  • return_dict: 如果为 True,则损失和指标结果将作为字典返回,其中每个键是指标的名称。如果为 False,则它们将作为列表返回。
  • **kwargs: 目前未使用。

请参阅 Model.fitUnpacking behavior for iterator-like inputs 的讨论。

返回

标量测试损失(如果模型只有一个输出且没有指标)或标量列表(如果模型有多个输出和/或指标)。属性 model.metrics_names 将为您提供标量输出的显示标签。

引发

  • RuntimeError: 如果 model.evaluate 被包装在 tf.function 中。

[源代码]

predict 方法

Model.predict(
    x,
    batch_size=None,
    verbose="auto",
    steps=None,
    callbacks=None,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False,
)

为输入样本生成输出预测。

计算以批次进行。此方法旨在处理大量输入的批次处理。它不适用于在遍历数据并处理少量输入的循环中使用。

对于适合单个批次的小量输入,直接使用 __call__() 以获得更快的执行速度,例如,model(x),或者(如果您有像 tf.keras.layers.BatchNormalization 这样的在推理期间行为不同的层)使用 model(x, training=False)。您可以在内部循环中使用 tf.function 将单个模型调用与额外的性能配对。如果您在模型调用后需要访问 numpy 数组值而不是张量,可以使用 tensor.numpy() 来获取 eager 张量的 numpy 数组值。

此外,请注意,测试损失不受噪声和 dropout 等正则化层的影响。

注意:有关 Model 方法 predict__call__ 之间区别的更多详细信息,请参阅 此 FAQ 条目

参数

  • x: 输入样本。它可以是
    • 一个 Numpy 数组(或类数组),或一组数组(如果模型有多个输入)。
    • 一个 TensorFlow 张量,或一组张量(如果模型有多个输入)。
    • 一个 tf.data 数据集。
    • 一个生成器或 keras.utils.Sequence 实例。有关迭代器类型(Dataset、generator、Sequence)的解包行为的更详细描述,请参阅 Model.fitUnpacking behavior for iterator-like inputs 部分。
  • batch_size: Integer 或 None。每个批次的样本数。如果未指定,batch_size 将默认为 32。如果数据是数据集、生成器或 keras.utils.Sequence 实例的形式,请不要指定 batch_size(因为它们会生成批次)。
  • verbose: "auto", 0, 1, 或 2。详细程度模式。0 = 静默, 1 = 进度条, 2 = 单行。"auto" 在大多数情况下变为 1,而在与 ParameterServerStrategy 一起使用时变为 2。请注意,当输出到文件时,进度条不太有用,因此在非交互式运行时(例如,在生产环境中)建议 verbose=2。默认为 'auto'。
  • steps: 在宣布预测轮次完成之前的总步数(样本批次)。默认值 None 时忽略。如果 x 是 tf.data 数据集且 steps 为 None,则 predict() 将运行直到输入数据集耗尽。
  • callbacks: keras.callbacks.Callback 实例列表。在预测期间应用的 callback 列表。请参阅 callbacks
  • max_queue_size: Integer。仅用于生成器或 keras.utils.Sequence 输入。生成器队列的最大大小。如果未指定,max_queue_size 将默认为 10。
  • workers: Integer。仅用于生成器或 keras.utils.Sequence 输入。使用基于进程的线程时要启动的最大进程数。如果未指定,workers 将默认为 1。
  • use_multiprocessing: Boolean。仅用于生成器或 keras.utils.Sequence 输入。如果为 True,则使用基于进程的线程。如果未指定,use_multiprocessing 将默认为 False。请注意,由于此实现依赖于多进程,因此不应将不可腌制(non-pickleable)的参数传递给生成器,因为它们不容易传递给子进程。

请参阅 Model.fitUnpacking behavior for iterator-like inputs 的讨论。请注意,Model.predict 使用的解释规则与 Model.fitModel.evaluate 相同,因此对于所有这三种方法,输入都必须是明确的。

返回

预测的 Numpy 数组。

引发

  • RuntimeError: 如果 model.predict 被包装在 tf.function 中。
  • ValueError: 如果提供的输入数据与模型的期望不匹配,或者如果一个有状态的模型接收的样本数不是批次大小的倍数。

[源代码]

train_on_batch 方法

Model.train_on_batch(
    x,
    y=None,
    sample_weight=None,
    class_weight=None,
    reset_metrics=True,
    return_dict=False,
)

对单批数据执行一次梯度更新。

参数

  • x: 输入数据。它可以是
    • 一个 Numpy 数组(或类数组),或一组数组(如果模型有多个输入)。
    • 一个 TensorFlow 张量,或一组张量(如果模型有多个输入)。
    • 一个映射输入名称到相应数组/张量的字典,如果模型有命名输入。
  • y: 目标数据。与输入数据 x 类似,它可以是 Numpy 数组或 TensorFlow 张量。
  • sample_weight: 可选数组,长度与 x 相同,包含应用于模型损失的每个样本的权重。在时间数据的情况下,您可以传递一个形状为 (samples, sequence_length) 的 2D 数组,以对每个样本的每个时间步应用不同的权重。
  • class_weight: 可选字典,将类索引(整数)映射到权重(浮点数),用于在训练期间应用于该类样本的模型损失。这有助于让模型“更加关注”来自代表不足类的样本。当指定 class_weight 且目标秩为 2 或更高时,y 必须是 one-hot 编码的,或者必须包含一个显式的最终维度 1 来表示稀疏类标签。
  • reset_metrics: 如果为 True,则返回的指标仅适用于此批次。如果为 False,则指标将在批次之间保持状态累积。
  • return_dict: 如果为 True,则损失和指标结果将作为字典返回,其中每个键是指标的名称。如果为 False,则它们将作为列表返回。

返回

标量训练损失(如果模型只有一个输出且没有指标)或标量列表(如果模型有多个输出和/或指标)。属性 model.metrics_names 将为您提供标量输出的显示标签。

引发

  • RuntimeError: 如果 model.train_on_batch 被包装在 tf.function 中。

[源代码]

test_on_batch 方法

Model.test_on_batch(
    x, y=None, sample_weight=None, reset_metrics=True, return_dict=False
)

在单批样本上测试模型。

参数

  • x: 输入数据。它可以是
    • 一个 Numpy 数组(或类数组),或一组数组(如果模型有多个输入)。
    • 一个 TensorFlow 张量,或一组张量(如果模型有多个输入)。
    • 一个映射输入名称到相应数组/张量的字典,如果模型有命名输入。
  • y: 目标数据。与输入数据 x 类似,它可以是 Numpy 数组或 TensorFlow 张量。它应该与 x 一致(您不能有 Numpy 输入和张量目标,反之亦然)。
  • sample_weight: 可选数组,长度与 x 相同,包含应用于模型损失的每个样本的权重。在时间数据的情况下,您可以传递一个形状为 (samples, sequence_length) 的 2D 数组,以对每个样本的每个时间步应用不同的权重。
  • reset_metrics: 如果为 True,则返回的指标仅适用于此批次。如果为 False,则指标将在批次之间保持状态累积。
  • return_dict: 如果为 True,则损失和指标结果将作为字典返回,其中每个键是指标的名称。如果为 False,则它们将作为列表返回。

返回

标量测试损失(如果模型只有一个输出且没有指标)或标量列表(如果模型有多个输出和/或指标)。属性 model.metrics_names 将为您提供标量输出的显示标签。

引发

  • RuntimeError: 如果 model.test_on_batch 被包装在 tf.function 中。

[源代码]

predict_on_batch 方法

Model.predict_on_batch(x)

返回单批样本的预测。

参数

  • x: 输入数据。它可以是
    • 一个 Numpy 数组(或类数组),或一组数组(如果模型有多个输入)。
    • 一个 TensorFlow 张量,或一组张量(如果模型有多个输入)。

返回

预测的 Numpy 数组。

引发

  • RuntimeError: 如果 model.predict_on_batch 被包装在 tf.function 中。

run_eagerly 属性

tf_keras.Model.run_eagerly

可设置的属性,指示模型是否应以 eager 模式运行。

以 eager 模式运行意味着您的模型将逐个步骤运行,就像 Python 代码一样。您的模型运行速度可能会变慢,但通过逐个调用层来调试模型应该会更容易。

默认情况下,我们将尝试将模型编译为静态图以提供最佳执行性能。

返回

布尔值,表示模型是否应以 eager 模式运行。