Keras 3 API 文档 / 模型 API / 模型训练 API

模型训练 API

[来源]

compile 方法

Model.compile(
    optimizer="rmsprop",
    loss=None,
    loss_weights=None,
    metrics=None,
    weighted_metrics=None,
    run_eagerly=False,
    steps_per_execution=1,
    jit_compile="auto",
    auto_scale_loss=True,
)

配置模型以进行训练。

示例

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.BinaryCrossentropy(),
    metrics=[
        keras.metrics.BinaryAccuracy(),
        keras.metrics.FalseNegatives(),
    ],
)

参数

  • optimizer: 字符串(优化器名称)或优化器实例。请参阅 keras.optimizers
  • loss: 损失函数。可以是字符串(损失函数名称),或 keras.losses.Loss 实例。请参阅 keras.losses。损失函数是任何带有签名 loss = fn(y_true, y_pred) 的可调用对象,其中 y_true 是真实值,y_pred 是模型的预测值。y_true 应具有形状 (batch_size, d0, .. dN)(除了稀疏损失函数(如稀疏分类交叉熵)的情况,它期望形状为 (batch_size, d0, .. dN-1) 的整数数组)。y_pred 应具有形状 (batch_size, d0, .. dN)。损失函数应返回一个浮点张量。
  • loss_weights: 可选列表或字典,指定标量系数(Python 浮点数)以加权不同模型输出的损失贡献。然后,模型将最小化的损失值为所有个体损失的加权总和,由 loss_weights 系数加权。如果为列表,则预计它与模型的输出具有 1:1 映射。如果为字典,则预计它将输出名称(字符串)映射到标量系数。
  • metrics: 模型在训练和测试期间评估的指标列表。每个指标可以是字符串(内置函数的名称)、函数或 keras.metrics.Metric 实例。请参阅 keras.metrics。通常,您将使用 metrics=['accuracy']。函数是任何带有签名 result = fn(y_true, _pred) 的可调用对象。要为多输出模型的不同输出指定不同的指标,您还可以传递字典,例如 metrics={'a':'accuracy', 'b':['accuracy', 'mse']}。您还可以传递列表以指定每个输出的指标或指标列表,例如 metrics=[['accuracy'], ['accuracy', 'mse']]metrics=['accuracy', ['accuracy', 'mse']]。当您传递字符串 'accuracy' 或 'acc' 时,我们会将其转换为 keras.metrics.BinaryAccuracykeras.metrics.CategoricalAccuracykeras.metrics.SparseCategoricalAccuracy 之一,具体取决于目标和模型输出的形状。对字符串 "crossentropy""ce" 也执行类似的转换。此处传递的指标在没有样本加权的情况下进行评估;如果您希望应用样本加权,则可以通过 weighted_metrics 参数指定您的指标。
  • weighted_metrics: 在训练和测试期间由 sample_weightclass_weight 评估和加权的指标列表。
  • run_eagerly: 布尔值。如果为 True,则此模型的前向传递永远不会被编译。建议在训练时将其保留为 False(以获得最佳性能),并在调试时将其设置为 True
  • steps_per_execution: 整数。在每次编译函数调用期间运行的批次数量。在单个编译函数调用中运行多个批次可以极大地提高 TPU 上的性能或具有大量 Python 开销的小型模型。最多,每次执行将运行一个完整的时期。如果传递的数字大于时期的尺寸,则执行将被截断为时期的尺寸。请注意,如果 steps_per_execution 设置为 N,则 Callback.on_batch_beginCallback.on_batch_end 方法将仅每 N 个批次调用一次(即在每次编译函数执行之前/之后)。不支持 PyTorch 后端。
  • jit_compile: 布尔值或 "auto"。是否在编译模型时使用 XLA 编译。对于 jaxtensorflow 后端,jit_compile="auto" 在模型支持它时启用 XLA 编译,否则禁用。对于 torch 后端,"auto" 将默认使用急切执行,jit_compile=True 将使用 "inductor" 后端运行 torch.compile
  • auto_scale_loss: 布尔值。如果为 True 且模型数据类型策略为 "mixed_float16",则传递的优化器将自动包装在 LossScaleOptimizer 中,这将动态缩放损失以防止下溢。

[来源]

fit 方法

Model.fit(
    x=None,
    y=None,
    batch_size=None,
    epochs=1,
    verbose="auto",
    callbacks=None,
    validation_split=0.0,
    validation_data=None,
    shuffle=True,
    class_weight=None,
    sample_weight=None,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None,
    validation_batch_size=None,
    validation_freq=1,
)

训练模型固定数量的时期(数据集迭代)。

参数

  • x: 输入数据。它可以是
    • NumPy 数组(或类数组),或数组列表(如果模型具有多个输入)。
    • 张量,或张量列表(如果模型具有多个输入)。
    • 如果模型具有命名输入,则将输入名称映射到对应数组/张量的字典。
    • tf.data.Dataset。应返回 (inputs, targets)(inputs, targets, sample_weights) 的元组。
    • keras.utils.PyDataset 返回 (inputs, targets)(inputs, targets, sample_weights)
  • y: 目标数据。与输入数据 x 相似,它可以是 NumPy 数组或后端本机张量。如果 x 是数据集、生成器或 keras.utils.PyDataset 实例,则不应指定 y(因为目标将从 x 中获取)。
  • batch_size: 整数或 None。每次梯度更新的样本数量。如果未指定,batch_size 将默认为 32。如果您的数据采用数据集、生成器或 keras.utils.PyDataset 实例的形式,请不要指定 batch_size(因为它们会生成批次)。
  • epochs: 整数。训练模型的时期数量。时期是对提供的整个 xy 数据的迭代(除非 steps_per_epoch 标志设置为非 None)。请注意,与 initial_epoch 结合使用时,epochs 应理解为“最终时期”。模型不会训练由 epochs 给出的迭代次数,而只是训练到索引为 epochs 的时期为止。
  • verbose: "auto"、0、1 或 2。详细模式。0 = 静默,1 = 进度条,2 = 每个时期一行。"auto" 在大多数情况下变为 1。请注意,进度条在记录到文件时不是特别有用,因此在非交互式运行时(例如,在生产环境中)推荐使用 verbose=2。默认为 "auto"
  • callbacks: keras.callbacks.Callback 实例列表。在训练期间应用的回调列表。请参阅 keras.callbacks。注意 keras.callbacks.ProgbarLoggerkeras.callbacks.History 回调是自动创建的,不需要传递给 model.fit()。根据 model.fit() 中的 verbose 参数,创建或不创建 keras.callbacks.ProgbarLogger
  • validation_split: 0 到 1 之间的浮点数。用作验证数据的训练数据比例。模型将分离这部分训练数据,不会在上面进行训练,并且将在每个时期结束时评估该数据上的损失和任何模型指标。验证数据是从提供的 xy 数据中的最后样本中选择的,在洗牌之前。当 x 是数据集、生成器或 keras.utils.PyDataset 实例时,不支持此参数。如果同时提供 validation_datavalidation_split,则 validation_data 将覆盖 validation_split
  • validation_data: 用于在每个时期结束时评估损失和任何模型指标的数据。模型不会在此数据上进行训练。因此,请注意,使用 validation_splitvalidation_data 提供的数据的验证损失不受噪声和丢弃等正则化层的影響。validation_data 将覆盖 validation_split。它可以是
    • NumPy 数组或张量的元组 (x_val, y_val)
    • NumPy 数组的元组 (x_val, y_val, val_sample_weights)
    • tf.data.Dataset.
    • 返回 (inputs, targets)(inputs, targets, sample_weights) 的 Python 生成器或 keras.utils.PyDataset
  • shuffle: 布尔值,指示是否在每个时期之前洗牌训练数据。当 x 是生成器或 tf.data.Dataset 时,此参数将被忽略。
  • class_weight: 可选字典,将类索引(整数)映射到权重(浮点数)值,用于加权损失函数(仅在训练期间)。这对于告诉模型“更多关注”来自代表性不足的类的样本很有用。当指定 class_weight 且目标的秩为 2 或更大时,y 必须是独热编码,或者对于稀疏类标签,必须包含 1 的显式最终维度。
  • sample_weight: 训练样本的可选 NumPy 权重数组,用于加权损失函数(仅在训练期间)。您可以传递与输入样本长度相同的扁平(一维)NumPy 数组(权重和样本之间一对一映射),或者在时间数据的情况下,您可以传递形状为 (samples, sequence_length) 的二维数组,以对每个样本的每个时间步应用不同的权重。当 x 是数据集、生成器或 keras.utils.PyDataset 实例时,不支持此参数,而是将样本权重作为 x 的第三个元素提供。请注意,样本加权不适用于通过 compile() 中的 metrics 参数指定的指标。要将样本加权应用于您的指标,您可以通过 compile() 中的 weighted_metrics 指定它们。
  • initial_epoch: 整数。开始训练的时期(用于恢复以前的训练运行)。
  • steps_per_epoch: 整数或 None。声明一个 epoch 完成并开始下一个 epoch 之前样本批次的总步数。当使用后端原生张量等输入张量进行训练时,默认的 None 等于数据集中的样本数量除以批次大小,或者如果无法确定则为 1。如果 x 是一个 tf.data.Dataset,并且 steps_per_epochNone,则 epoch 将一直运行到输入数据集耗尽。当传递无限循环的数据集时,必须指定 steps_per_epoch 参数。如果 steps_per_epoch=-1,则训练将使用无限循环的数据集无限期运行。
  • validation_steps: 仅在提供 validation_data 时才相关。在每个 epoch 结束时执行验证时,在停止之前要提取的样本批次的总步数。如果 validation_stepsNone,则验证将一直运行到 validation_data 数据集耗尽。对于无限循环的数据集,它将进入无限循环。如果指定了 validation_steps 并且只消耗了数据集的一部分,则评估将在每个 epoch 从数据集的开头开始。这确保每次使用相同的验证样本。
  • validation_batch_size: 整数或 None。每个验证批次的样本数量。如果未指定,将默认为 batch_size。如果您的数据是数据集或 keras.utils.PyDataset 实例的形式,请勿指定 validation_batch_size(因为它们会生成批次)。
  • validation_freq: 仅在提供验证数据时才相关。指定在执行新的验证运行之前要运行多少个训练 epoch,例如 validation_freq=2 每 2 个 epoch 运行一次验证。

迭代器式输入的解包行为:一种常见的模式是将类似迭代器的对象(如 tf.data.Datasetkeras.utils.PyDataset)传递给 fit(),这实际上将不仅产生特征 (x) 而且还会可选地产生目标 (y) 和样本权重 (sample_weight)。Keras 要求此类迭代器式对象的输出是明确的。迭代器应返回长度为 1、2 或 3 的元组,其中可选的第二和第三个元素将分别用于 ysample_weight。任何其他类型都将包装在长度为一的元组中,有效地将所有内容视为 x。当产生字典时,它们仍然应该遵守顶层元组结构,例如 ({"x0": x0, "x1": x1}, y)。Keras 不会尝试从单个字典的键中分离特征、目标和权重。一个值得注意的不受支持的数据类型是 namedtuple。原因是它既像有序数据类型 (tuple) 又像映射数据类型 (dict)。因此,对于以下形式的命名元组:namedtuple("example_tuple", ["y", "x"]),在解释值时是否反转元素的顺序是不确定的。更糟糕的是,以下形式的元组:namedtuple("other_tuple", ["x", "y", "z"]),其中不清楚元组是否意图被解包到 xysample_weight 中,或者作为单个元素传递给 x

返回值

一个 History 对象。它的 History.history 属性记录了连续 epoch 的训练损失值和指标值,以及验证损失值和验证指标值(如果适用)。


[来源]

evaluate 方法

Model.evaluate(
    x=None,
    y=None,
    batch_size=None,
    verbose="auto",
    sample_weight=None,
    steps=None,
    callbacks=None,
    return_dict=False,
    **kwargs
)

返回模型在测试模式下的损失值和指标值。

计算是在批次中完成的(参见 batch_size 参数)。

参数

  • x: 输入数据。它可以是
    • NumPy 数组(或类数组),或数组列表(如果模型具有多个输入)。
    • 张量,或张量列表(如果模型具有多个输入)。
    • 如果模型具有命名输入,则将输入名称映射到对应数组/张量的字典。
    • tf.data.Dataset。应返回 (inputs, targets)(inputs, targets, sample_weights) 的元组。
    • 一个生成器或 keras.utils.PyDataset,返回 (inputs, targets)(inputs, targets, sample_weights)
  • y: 目标数据。与输入数据 x 一样,它可以是 NumPy 数组或后端原生张量。如果 x 是一个 tf.data.Datasetkeras.utils.PyDataset 实例,则不应指定 y(因为目标将从迭代器/数据集中获取)。
  • batch_size: 整数或 None。每个计算批次的样本数量。如果未指定,batch_size 将默认为 32。如果您的数据是数据集、生成器或 keras.utils.PyDataset 实例的形式,请勿指定 batch_size(因为它们会生成批次)。
  • verbose: "auto"、0、1 或 2。详细程度模式。0 = 静默,1 = 进度条,2 = 单行。"auto" 在大多数情况下变为 1。请注意,进度条在记录到文件时不是特别有用,因此在非交互式运行时(例如在生产环境中)建议使用 verbose=2。默认为 "auto"
  • sample_weight: 可选的 NumPy 权重数组,用于测试样本,用于对损失函数进行加权。您可以传递与输入样本长度相同的扁平 (1D) NumPy 数组(权重与样本之间一对一映射),或者在时间数据的情况下,您可以传递形状为 (samples, sequence_length) 的二维数组,以对每个样本的每个时间步应用不同的权重。当 x 是数据集时,不支持此参数,而是将样本权重作为 x 的第三个元素传递。
  • steps: 整数或 None。声明评估轮次完成之前样本批次的总步数。使用默认值 None 时忽略。如果 x 是一个 tf.data.Dataset 并且 stepsNone,则评估将一直运行到数据集耗尽。
  • callbacks: keras.callbacks.Callback 实例的列表。评估期间要应用的回调列表。
  • return_dict: 如果为 True,则损失和指标结果将作为字典返回,每个键都是指标的名称。如果为 False,则它们将作为列表返回。

返回值

标量测试损失(如果模型只有一个输出且没有指标)或标量列表(如果模型有多个输出和/或指标)。属性 model.metrics_names 将为您提供标量输出的显示标签。


[来源]

predict 方法

Model.predict(x, batch_size=None, verbose="auto", steps=None, callbacks=None)

为输入样本生成输出预测。

计算是在批次中完成的。此方法旨在用于对大量输入进行批处理。它不打算在循环内使用,这些循环遍历您的数据并一次处理少量输入。

对于适合在一个批次中的少量输入,直接使用 __call__() 进行更快的执行,例如 model(x),或者如果您的层(如 BatchNormalization)在推理期间的行为不同,则使用 model(x, training=False)

注意:有关 Model 方法 predict()__call__() 之间区别的更多详细信息,请参阅 此常见问题解答条目

参数

  • x: 输入样本。它可以是
    • NumPy 数组(或类数组),或数组列表(如果模型具有多个输入)。
    • 张量,或张量列表(如果模型具有多个输入)。
    • tf.data.Dataset.
    • 一个 keras.utils.PyDataset 实例。
  • batch_size: 整数或 None。每个批次的样本数量。如果未指定,batch_size 将默认为 32。如果您的数据是数据集、生成器或 keras.utils.PyDataset 实例的形式,请勿指定 batch_size(因为它们会生成批次)。
  • verbose: "auto"、0、1 或 2。详细程度模式。0 = 静默,1 = 进度条,2 = 单行。"auto" 在大多数情况下变为 1。请注意,进度条在记录到文件时不是特别有用,因此在非交互式运行时(例如在生产环境中)建议使用 verbose=2。默认为 "auto"
  • steps: 声明预测轮次完成之前样本批次的总步数。使用默认值 None 时忽略。如果 x 是一个 tf.data.Dataset 并且 stepsNone,则 predict() 将一直运行到输入数据集耗尽。
  • callbacks: keras.callbacks.Callback 实例的列表。预测期间要应用的回调列表。

返回值

预测的 NumPy 数组。


[来源]

train_on_batch 方法

Model.train_on_batch(
    x, y=None, sample_weight=None, class_weight=None, return_dict=False
)

对单个数据批次运行一次梯度更新。

参数

  • x: 输入数据。必须是类似数组的。
  • y: 目标数据。必须是类似数组的。
  • sample_weight: 可选的与 x 长度相同的数组,包含要应用于每个样本的模型损失的权重。在时间数据的情况下,您可以传递形状为 (samples, sequence_length) 的二维数组,以对每个样本的每个时间步应用不同的权重。
  • class_weight: 可选的字典,将类索引(整数)映射到权重(浮点数),以应用于训练期间来自此类的样本的模型损失。这对于告诉模型“更多关注”来自代表性不足类的样本很有用。当指定 class_weight 并且目标的等级为 2 或更高时,y 必须是独热编码的,或者必须为稀疏类标签包含显式的最终维度 1。
  • return_dict: 如果为 True,则损失和指标结果将作为字典返回,每个键都是指标的名称。如果为 False,则它们将作为列表返回。

返回值

一个标量损失值(当没有指标并且 return_dict=False 时),一个损失和指标值的列表(如果有指标并且 return_dict=False 时),或者一个指标和损失值的字典(如果 return_dict=True 时)。


[来源]

test_on_batch 方法

Model.test_on_batch(x, y=None, sample_weight=None, return_dict=False)

在一个样本批次上测试模型。

参数

  • x: 输入数据。必须是类似数组的。
  • y: 目标数据。必须是类似数组的。
  • sample_weight: 可选的与 x 长度相同的数组,包含要应用于每个样本的模型损失的权重。在时间数据的情况下,您可以传递形状为 (samples, sequence_length) 的二维数组,以对每个样本的每个时间步应用不同的权重。
  • return_dict: 如果为 True,则损失和指标结果将作为字典返回,每个键都是指标的名称。如果为 False,则它们将作为列表返回。

返回值

一个标量损失值(当没有指标并且 return_dict=False 时),一个损失和指标值的列表(如果有指标并且 return_dict=False 时),或者一个指标和损失值的字典(如果 return_dict=True 时)。


[来源]

predict_on_batch 方法

Model.predict_on_batch(x)

返回单个样本批次的预测。

参数

  • x: 输入数据。它必须是类似数组的。

返回值

预测的 NumPy 数组。