compile
方法Model.compile(
optimizer="rmsprop",
loss=None,
metrics=None,
loss_weights=None,
weighted_metrics=None,
run_eagerly=None,
steps_per_execution=None,
jit_compile=None,
pss_evaluation_shards=0,
**kwargs
)
配置模型进行训练。
示例
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
loss=tf.keras.losses.BinaryCrossentropy(),
metrics=[tf.keras.metrics.BinaryAccuracy(),
tf.keras.metrics.FalseNegatives()])
参数
tf.keras.optimizers
。tf.keras.losses.Loss
实例。详见 tf.keras.losses
。损失函数是任何具有签名 loss = fn(y_true, y_pred)
的可调用对象,其中 y_true
是真实值,y_pred
是模型的预测值。y_true
应该具有形状 (batch_size, d0, .. dN)
(稀疏损失函数如 sparse categorical crossentropy 除外,它需要形状为 (batch_size, d0, .. dN-1)
的整数数组)。y_pred
应该具有形状 (batch_size, d0, .. dN)
。损失函数应返回一个浮点张量。如果使用了自定义的 Loss
实例并且 reduction 设置为 None
,则返回值的形状为 (batch_size, d0, .. dN-1)
,即每个样本或每个时间步的损失值;否则,它是一个标量。如果模型有多个输出,您可以通过传递字典或损失列表来对每个输出使用不同的损失函数。模型的最小化损失值将是所有单个损失的总和,除非指定了 loss_weights
。tf.keras.metrics.Metric
实例。详见 tf.keras.metrics
。通常您会使用 metrics=['accuracy']
。函数是任何具有签名 result = fn(y_true, y_pred)
的可调用对象。对于多输出模型,您可以通过传递字典来为不同的输出指定不同的指标,例如 metrics={'output_a':'accuracy', 'output_b':['accuracy', 'mse']}
。您也可以传递列表来为每个输出指定一个或多个指标,例如 metrics=[['accuracy'], ['accuracy', 'mse']]
或 metrics=['accuracy', ['accuracy', 'mse']]
。当您传递字符串 'accuracy' 或 'acc' 时,我们会根据目标和模型输出的形状将其转换为 tf.keras.metrics.BinaryAccuracy
、tf.keras.metrics.CategoricalAccuracy
或 tf.keras.metrics.SparseCategoricalAccuracy
中的一个。对于字符串 'crossentropy' 和 'ce',我们也进行类似的转换。此处传递的指标在评估时不考虑样本权重;如果您希望应用样本权重,可以改为通过 weighted_metrics
参数指定您的指标。loss_weights
系数指定。如果是列表,预计与模型的输出一一对应。如果是字典,预计将输出名称(字符串)映射到标量系数。sample_weight
或 class_weight
进行评估和加权的指标列表。True
,此 Model
的逻辑将不会被包装在 tf.function
中。建议保留此参数为 None
,除非您的 Model
无法在 tf.function
内运行。使用 tf.distribute.experimental.ParameterServerStrategy
时不支持 run_eagerly=True
。默认为 False
。'auto'
。每次调用 tf.function
时运行的批次数量。如果设置为 "auto",Keras 将在运行时自动调整 steps_per_execution
。在单个 tf.function
调用中运行多个批次可以显著提高 TPU 性能,尤其是在使用 ParameterServerStrategy
等分布式策略或具有大量 Python 开销的小型模型时。每次执行最多运行一个完整 epoch。如果传入的数字大于 epoch 大小,执行将被截断为 epoch 的大小。请注意,如果 steps_per_execution
设置为 N
,则 Callback.on_batch_begin
和 Callback.on_batch_end
方法将仅每 N
个批次被调用一次(即在每次 tf.function
执行之前/之后)。默认为 1
。True
,则使用 XLA 编译模型训练步骤。XLA 是一种用于机器学习的优化编译器。jit_compile
默认不启用。请注意,jit_compile=True
不一定适用于所有模型。有关支持的操作的更多信息,请参阅 XLA 文档。另请参阅已知 XLA 问题了解更多详情。tf.distribute.ParameterServerStrategy
训练。此参数设置将数据集分割成的分片数量,以实现评估的精确访问保证,这意味着模型将被应用于每个数据集元素恰好一次,即使有 worker 失败。数据集必须进行分片,以确保不同的 worker 不处理相同的数据。为获得良好性能,分片数量应至少等于 worker 数量。'auto' 值开启精确评估,并根据 worker 数量使用启发式方法确定分片数量。0 表示不提供访问保证。注意:进行精确评估时,Model.test_step
的自定义实现将被忽略。默认为 0
。fit
方法Model.fit(
x=None,
y=None,
batch_size=None,
epochs=1,
verbose="auto",
callbacks=None,
validation_split=0.0,
validation_data=None,
shuffle=True,
class_weight=None,
sample_weight=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
validation_batch_size=None,
validation_freq=1,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
)
以固定数量的 epoch(数据集迭代)训练模型。
参数
tf.data
数据集。应返回一个元组,可以是 (inputs, targets)
或 (inputs, targets, sample_weights)
。(inputs, targets)
或 (inputs, targets, sample_weights)
的生成器或 keras.utils.Sequence
实例。tf.keras.utils.experimental.DatasetCreator
,它包装了一个可调用对象,该对象接受一个类型为 tf.distribute.InputContext
的参数,并返回一个 tf.data.Dataset
。当用户希望为 Dataset
指定按副本的批处理和分片逻辑时,应使用 DatasetCreator
。有关更多信息,请参阅 tf.keras.utils.experimental.DatasetCreator
文档。下面详细描述了迭代器类型(Dataset、生成器、Sequence)的解包行为。如果这些包含 sample_weights
作为第三个组件,请注意样本加权适用于 compile()
中的 weighted_metrics
参数,但不适用于 metrics
参数。如果使用 tf.distribute.experimental.ParameterServerStrategy
,x
只支持 DatasetCreator
类型。x
类似,可以是 Numpy 数组或 TensorFlow 张量。它应与 x
一致(不能同时有 Numpy 输入和张量目标,或反之)。如果 x
是数据集、生成器或 keras.utils.Sequence
实例,则不应指定 y
(因为目标将从 x
中获取)。None
。每次梯度更新的样本数量。如果未指定,batch_size
将默认为 32。如果您的数据采用数据集、生成器或 keras.utils.Sequence
实例的形式(因为它们会生成批次),则不要指定 batch_size
。x
和 y
数据的一次迭代(除非 steps_per_epoch
参数被设置为 None 以外的值)。请注意,与 initial_epoch
结合使用时,epochs
应理解为“最终 epoch”。模型并非训练 epochs
给定的迭代次数,而只是训练直到达到索引为 epochs
的 epoch。ParameterServerStrategy
一起使用时变为 2。请注意,当日志记录到文件时,进度条不是特别有用,因此在非交互式运行时(例如,在生产环境中)建议使用 verbose=2。默认为 'auto'。keras.callbacks.Callback
实例列表。在训练期间应用的回调函数列表。详见 tf.keras.callbacks
。请注意,tf.keras.callbacks.ProgbarLogger
和 tf.keras.callbacks.History
回调函数会自动创建,无需传递给 model.fit
。是否创建 tf.keras.callbacks.ProgbarLogger
取决于传递给 model.fit
的 verbose
参数。目前不支持在 tf.distribute.experimental.ParameterServerStrategy
下使用带有批次级调用的回调函数,建议用户改为实现 epoch 级调用并设置合适的 steps_per_epoch
值。x
和 y
数据中靠后的样本,在洗牌之前进行。当 x
是数据集、生成器或 keras.utils.Sequence
实例时,不支持此参数。如果同时提供了 validation_data
和 validation_split
,则 validation_data
将覆盖 validation_split
。tf.distribute.experimental.ParameterServerStrategy
尚不支持 validation_split
。validation_split
或 validation_data
提供的数据的验证损失不受噪声和 Dropout 等正则化层的影响。validation_data
将覆盖 validation_split
。validation_data
可以是: - Numpy 数组或张量的元组 (x_val, y_val)
。 - NumPy 数组的元组 (x_val, y_val, val_sample_weights)
。 - 一个 tf.data.Dataset
。 - 返回 (inputs, targets)
或 (inputs, targets, sample_weights)
的 Python 生成器或 keras.utils.Sequence
。 tf.distribute.experimental.ParameterServerStrategy
尚不支持 validation_data
。x
是生成器或 tf.data.Dataset 对象时,此参数将被忽略。'batch' 是处理 HDF5 数据限制的特殊选项;它以批次大小的块进行洗牌。当 steps_per_epoch
不是 None
时,此参数无效。class_weight
且目标具有 2 或更高维度时,y
必须进行独热编码,或者对于稀疏类别标签必须包含一个明确的最终维度 1
。(samples, sequence_length)
的 2D 数组,以便对每个样本的每个时间步应用不同的权重。当 x
是数据集、生成器或 keras.utils.Sequence
实例时,不支持此参数,请改为将样本权重作为 x
的第三个元素提供。请注意,样本加权不适用于 compile()
中通过 metrics
参数指定的指标。要将样本加权应用于您的指标,您可以改为通过 compile()
中的 weighted_metrics
指定它们。None
。在声明一个 epoch 完成并开始下一个 epoch 之前所需的总步数(样本批次)。当使用 TensorFlow 数据张量等输入张量进行训练时,默认值 None
等于数据集中样本数量除以批次大小,如果无法确定则为 1。如果 x 是一个 tf.data
数据集,并且 'steps_per_epoch' 为 None,epoch 将一直运行直到输入数据集耗尽。传递无限重复的数据集时,必须指定 steps_per_epoch
参数。如果 steps_per_epoch=-1
,训练将使用无限重复的数据集无限期运行。数组输入不支持此参数。使用 tf.distribute.experimental.ParameterServerStrategy
时: * 不支持 steps_per_epoch=None
。validation_data
且它是 tf.data
数据集时相关。在每个 epoch 结束时执行验证时,在停止之前所需的总步数(样本批次)。如果 'validation_steps' 为 None,验证将一直运行直到 validation_data
数据集耗尽。对于无限重复的数据集,它将进入无限循环。如果指定了 'validation_steps' 并且只使用了数据集的一部分,则评估将在每个 epoch 从数据集的开头开始。这确保每次都使用相同的验证样本。None
。每次验证批次的样本数量。如果未指定,将默认为 batch_size
。如果您的数据采用数据集、生成器或 keras.utils.Sequence
实例的形式(因为它们会生成批次),则不要指定 validation_batch_size
。collections.abc.Container
实例(例如列表、元组等)。如果是整数,指定在执行新的验证运行之前运行多少个训练 epoch,例如 validation_freq=2
表示每 2 个 epoch 运行一次验证。如果是容器,指定在哪些 epoch 上运行验证,例如 validation_freq=[1, 2, 10]
表示在第 1、2 和 10 个 epoch 结束时运行验证。keras.utils.Sequence
输入。生成器队列的最大大小。如果未指定,max_queue_size
将默认为 10。keras.utils.Sequence
输入。使用基于进程的线程时,启动的最大进程数。如果未指定,workers
将默认为 1。keras.utils.Sequence
输入。如果为 True
,则使用基于进程的线程。如果未指定,use_multiprocessing
将默认为 False
。请注意,由于此实现依赖于多进程,因此不应将不可 pickle 的参数传递给生成器,因为它们无法轻松传递给子进程。类迭代器输入的解包行为:一种常见的模式是将 tf.data.Dataset、生成器或 tf.keras.utils.Sequence 传递给 fit 方法的 x
参数,它们实际上不仅会生成特征 (x),还可以选择生成目标 (y) 和样本权重。TF-Keras 要求此类迭代器输出是明确的。迭代器应返回一个长度为 1、2 或 3 的元组,其中可选的第二个和第三个元素将分别用于 y 和 sample_weight。提供的任何其他类型将被包装在一个长度为一的元组中,有效地将所有内容视为 'x'。当生成字典时,它们仍然应该遵循顶层元组结构。例如:({"x0": x0, "x1": x1}, y)
。TF-Keras 不会尝试从单个字典的键中分离特征、目标和权重。一个值得注意的不支持的数据类型是 namedtuple。原因在于它既表现得像有序数据类型(元组),又表现得像映射数据类型(字典)。因此,给定一个形式为 namedtuple("example_tuple", ["y", "x"])
的 namedtuple,解释值时是否需要颠倒元素的顺序是模糊的。更糟糕的是形式为 namedtuple("other_tuple", ["x", "y", "z"])
的元组,其中不清楚该元组是打算解包为 x、y 和 sample_weight,还是作为单个元素传递给 x
。因此,数据处理代码在遇到 namedtuple 时将直接引发 ValueError。(并提供解决此问题的说明。)
返回值
一个 History
对象。其 History.history
属性是训练损失值和指标值在连续 epoch 上的记录,以及验证损失值和验证指标值(如果适用)。
异常
model.fit
被包装在 tf.function
中。evaluate
方法Model.evaluate(
x=None,
y=None,
batch_size=None,
verbose="auto",
sample_weight=None,
steps=None,
callbacks=None,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
return_dict=False,
**kwargs
)
返回模型在测试模式下的损失值和指标值。
计算以批次形式进行(参见 batch_size
参数)。
参数
tf.data
数据集。应返回一个元组,可以是 (inputs, targets)
或 (inputs, targets, sample_weights)
。(inputs, targets)
或 (inputs, targets, sample_weights)
的生成器或 keras.utils.Sequence
实例。迭代器类型(Dataset、生成器、Sequence)的解包行为的更详细描述请参见 Model.fit
方法的“类迭代器输入的解包行为”部分。x
类似,可以是 Numpy 数组或 TensorFlow 张量。它应与 x
一致(不能同时有 Numpy 输入和张量目标,或反之)。如果 x
是数据集、生成器或 keras.utils.Sequence
实例,则不应指定 y
(因为目标将从迭代器/数据集中获取)。None
。每次计算批次的样本数量。如果未指定,batch_size
将默认为 32。如果您的数据采用数据集、生成器或 keras.utils.Sequence
实例的形式(因为它们会生成批次),则不要指定 batch_size
。"auto"
、0、1 或 2。详细程度模式。0 = 静默,1 = 进度条,2 = 单行。在大多数情况下,"auto"
变为 1,与 ParameterServerStrategy
一起使用时变为 2。请注意,当日志记录到文件时,进度条不是特别有用,因此在非交互式运行时(例如,在生产环境中)建议使用 verbose=2
。默认为 'auto'。(samples, sequence_length)
的 2D 数组,以便对每个样本的每个时间步应用不同的权重。当 x
是数据集时,不支持此参数,请改为将样本权重作为 x
的第三个元素传递。None
。在声明评估轮次完成之前所需的总步数(样本批次)。如果使用默认值 None
,则忽略此参数。如果 x 是一个 tf.data
数据集并且 steps
为 None,则 'evaluate' 将一直运行直到数据集耗尽。数组输入不支持此参数。keras.callbacks.Callback
实例列表。在评估期间应用的回调函数列表。详见 callbacks。keras.utils.Sequence
输入。生成器队列的最大大小。如果未指定,max_queue_size
将默认为 10。keras.utils.Sequence
输入。使用基于进程的线程时,启动的最大进程数。如果未指定,workers
将默认为 1。keras.utils.Sequence
输入。如果为 True
,则使用基于进程的线程。如果未指定,use_multiprocessing
将默认为 False
。请注意,由于此实现依赖于多进程,因此不应将不可 pickle 的参数传递给生成器,因为它们无法轻松传递给子进程。True
,损失和指标结果将以字典形式返回,每个键是指标名称。如果为 False
,它们将以列表形式返回。请参阅 Model.fit
方法中关于“类迭代器输入的解包行为”的讨论。
返回值
标量测试损失(如果模型只有一个输出且没有指标)或标量列表(如果模型有多个输出和/或指标)。model.metrics_names
属性将为您提供标量输出的显示标签。
异常
model.evaluate
被包装在 tf.function
中。predict
方法Model.predict(
x,
batch_size=None,
verbose="auto",
steps=None,
callbacks=None,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
)
为输入样本生成输出预测。
计算以批次形式进行。此方法旨在用于处理大量输入的批处理。不建议在遍历数据并每次处理少量输入的循环中使用此方法。
对于 fits 在一个批次中的少量输入,可以直接使用 __call__()
以获得更快的执行速度,例如 model(x)
,或者如果您有在推理期间行为不同的层(如 tf.keras.layers.BatchNormalization
),可以使用 model(x, training=False)
。您可以将单个模型调用与 tf.function
结合使用,以在内循环中获得额外的性能。如果在调用模型后需要访问 numpy 数组值而不是张量,可以使用 tensor.numpy()
获取 eager 张量的 numpy 数组值。
另外,请注意测试损失不受噪声和 dropout 等正则化层的影响。
注意:有关 Model
方法 predict()
和 __call__()
之间区别的更多详情,请参阅此常见问题解答条目。
参数
tf.data
数据集。keras.utils.Sequence
实例。迭代器类型(Dataset、生成器、Sequence)的解包行为的更详细描述请参见 Model.fit
方法的“类迭代器输入的解包行为”部分。None
。每次批次的样本数量。如果未指定,batch_size
将默认为 32。如果您的数据采用数据集、生成器或 keras.utils.Sequence
实例的形式(因为它们会生成批次),则不要指定 batch_size
。"auto"
、0、1 或 2。详细程度模式。0 = 静默,1 = 进度条,2 = 单行。在大多数情况下,"auto"
变为 1,与 ParameterServerStrategy
一起使用时变为 2。请注意,当日志记录到文件时,进度条不是特别有用,因此在非交互式运行时(例如,在生产环境中)建议使用 verbose=2
。默认为 'auto'。None
,则忽略此参数。如果 x 是一个 tf.data
数据集并且 steps
为 None,则 predict()
将一直运行直到输入数据集耗尽。keras.callbacks.Callback
实例列表。在预测期间应用的回调函数列表。详见 callbacks。keras.utils.Sequence
输入。生成器队列的最大大小。如果未指定,max_queue_size
将默认为 10。keras.utils.Sequence
输入。使用基于进程的线程时,启动的最大进程数。如果未指定,workers
将默认为 1。keras.utils.Sequence
输入。如果为 True
,则使用基于进程的线程。如果未指定,use_multiprocessing
将默认为 False
。请注意,由于此实现依赖于多进程,因此不应将不可 pickle 的参数传递给生成器,因为它们无法轻松传递给子进程。请参阅 Model.fit
方法中关于“类迭代器输入的解包行为”的讨论。请注意,Model.predict 使用与 Model.fit
和 Model.evaluate
相同的解释规则,因此对于这三个方法,输入必须是明确的。
返回值
Numpy 预测数组。
异常
model.predict
被包装在 tf.function
中。train_on_batch
方法Model.train_on_batch(
x,
y=None,
sample_weight=None,
class_weight=None,
reset_metrics=True,
return_dict=False,
)
对单个数据批次运行一次梯度更新。
参数
x
类似,可以是 Numpy 数组或 TensorFlow 张量。class_weight
且目标具有 2 或更高维度时,y
必须进行独热编码,或者对于稀疏类别标签必须包含一个明确的最终维度 1
。True
,返回的指标将仅针对此批次。如果为 False
,指标将在批次之间进行有状态累积。True
,损失和指标结果将以字典形式返回,每个键是指标名称。如果为 False
,它们将以列表形式返回。返回值
标量训练损失(如果模型只有一个输出且没有指标)或标量列表(如果模型有多个输出和/或指标)。model.metrics_names
属性将为您提供标量输出的显示标签。
异常
model.train_on_batch
被包装在 tf.function
中。test_on_batch
方法Model.test_on_batch(
x, y=None, sample_weight=None, reset_metrics=True, return_dict=False
)
在单个样本批次上测试模型。
参数
x
类似,可以是 Numpy 数组或 TensorFlow 张量。它应与 x
一致(不能同时有 Numpy 输入和张量目标,或反之)。True
,返回的指标将仅针对此批次。如果为 False
,指标将在批次之间进行有状态累积。True
,损失和指标结果将以字典形式返回,每个键是指标名称。如果为 False
,它们将以列表形式返回。返回值
标量测试损失(如果模型只有一个输出且没有指标)或标量列表(如果模型有多个输出和/或指标)。model.metrics_names
属性将为您提供标量输出的显示标签。
异常
model.test_on_batch
被包装在 tf.function
中。predict_on_batch
方法Model.predict_on_batch(x)
返回单个样本批次的预测值。
参数
返回值
Numpy 预测数组。
异常
model.predict_on_batch
被包装在 tf.function
中。run_eagerly
属性tf_keras.Model.run_eagerly
可设置属性,指示模型是否应以 eager 模式运行。
以 eager 模式运行意味着您的模型将像 Python 代码一样一步一步地运行。您的模型可能会运行得更慢,但通过逐个步入层调用来调试它应该会更容易。
默认情况下,我们将尝试将您的模型编译为静态图以提供最佳执行性能。
返回值
布尔值,指示模型是否应以 eager 模式运行。