Keras 3 API 文档 / 模型 API / 保存与序列化 / 模型的推理导出

模型的推理导出

[source]

export 方法

Model.export(
    filepath, format="tf_saved_model", verbose=None, input_signature=None, **kwargs
)

将模型导出为用于推理的 artifact(工件)。

参数

  • filepathstrpathlib.Path 对象。保存 artifact(工件)的路径。
  • formatstr。导出格式。支持的值:"tf_saved_model""onnx"。默认为 "tf_saved_model"
  • verbosebool。导出过程中是否打印消息。默认为 None,此时使用不同后端和格式设置的默认值。
  • input_signature:可选。指定模型输入的形状和 dtype。可以是 keras.InputSpectf.TensorSpecbackend.KerasTensor 或后端张量的结构。如果未提供,将自动计算。默认为 None
  • **kwargs:额外的关键字参数
    • 特别适用于 JAX 后端和 format="tf_saved_model": - is_static:可选的 bool 类型。指示 fn 是否是静态的。如果 fn 涉及状态更新(例如,RNG 种子和计数器),则设置为 False。 - jax2tf_kwargs:可选的 dict 类型。jax2tf.convert 的参数。请参阅 jax2tf.convert 的文档。如果未提供 native_serializationpolymorphic_shapes,将自动计算。

注意: 此功能目前仅支持 TensorFlow、JAX 和 Torch 后端。

注意: 请注意,当使用 format="onnx"verbose=True 和 Torch 后端时,导出的 artifact(工件)可能包含来自本地文件系统的信息。

示例

以下是如何导出用于推理的 TensorFlow SavedModel。

# Export the model as a TensorFlow SavedModel artifact
model.export("path/to/location", format="tf_saved_model")

# Load the artifact in a different process/environment
reloaded_artifact = tf.saved_model.load("path/to/location")
predictions = reloaded_artifact.serve(input_data)

以下是如何导出用于推理的 ONNX 模型。

# Export the model as a ONNX artifact
model.export("path/to/location", format="onnx")

# Load the artifact in a different process/environment
ort_session = onnxruntime.InferenceSession("path/to/location")
ort_inputs = {
    k.name: v for k, v in zip(ort_session.get_inputs(), input_data)
}
predictions = ort_session.run(None, ort_inputs)

[source]

ExportArchive

keras.export.ExportArchive()

ExportArchive 用于写入 SavedModel artifact(例如,用于推理)。

如果您有一个 Keras 模型或层,想要将其导出为 SavedModel 以进行 Serving(例如通过 TensorFlow-Serving),可以使用 ExportArchive 来配置需要提供的不同 Serving endpoints(端点)及其签名。只需实例化一个 ExportArchive,使用 track() 注册要使用的层或模型,然后使用 add_endpoint() 方法注册新的 Serving endpoint(端点)。完成后,使用 write_out() 方法保存 artifact(工件)。

生成的 artifact(工件)是 SavedModel,可以通过 tf.saved_model.load 重新加载。

示例

以下是如何导出用于推理的模型。

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
export_archive.write_out("path/to/location")

# Elsewhere, we can reload the artifact and serve it.
# The endpoint we added is available as a method:
serving_model = tf.saved_model.load("path/to/location")
outputs = serving_model.serve(inputs)

以下是如何导出具有一个用于推理的 endpoint(端点)和一个用于训练模式正向传播(例如启用 dropout)的 endpoint(端点)的模型。

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="call_inference",
    fn=lambda x: model.call(x, training=False),
    input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
export_archive.add_endpoint(
    name="call_training",
    fn=lambda x: model.call(x, training=True),
    input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
export_archive.write_out("path/to/location")

关于资源跟踪的说明

ExportArchive 能够自动跟踪其 endpoints(端点)使用的所有 keras.Variables,因此大多数时候并非严格需要调用 .track(model)。但是,如果您的模型使用了查找层(例如 IntegerLookupStringLookupTextVectorization),则需要通过 .track(model) 显式跟踪。

如果您需要在重新加载的 archive 上访问 variablestrainable_variablesnon_trainable_variables 属性,也需要显式跟踪。


[source]

add_endpoint 方法

ExportArchive.add_endpoint(name, fn, input_signature=None, **kwargs)

注册新的 Serving endpoint(端点)。

参数

  • namestr 类型。endpoint(端点)的名称。
  • fn:一个可调用对象。它应该只利用 ExportArchive 跟踪的模型/层上可用的资源(例如 keras.Variable 对象或 tf.lookup.StaticHashTable 对象)(您可以调用 .track(model) 来跟踪新模型)。必须知道函数输入的形状和 dtype。为此,您可以 1) 确保 fn 是一个至少已调用过一次的 tf.function,或者 2) 提供一个 input_signature 参数来指定输入的形状和 dtype(见下文)。
  • input_signature:可选。指定 fn 的形状和 dtype。可以是 keras.InputSpectf.TensorSpecbackend.KerasTensor 或后端张量的结构(请参阅下面展示具有 2 个输入参数的 Functional 模型的示例)。如果未提供,则 fn 必须是至少已调用过一次的 tf.function。默认为 None
  • **kwargs:额外的关键字参数
    • 特别适用于 JAX 后端: - is_static:可选的 bool 类型。指示 fn 是否是静态的。如果 fn 涉及状态更新(例如,RNG 种子),则设置为 False。 - jax2tf_kwargs:可选的 dict 类型。jax2tf.convert 的参数。请参阅 jax2tf.convert。如果未提供 native_serializationpolymorphic_shapes,将自动计算。

返回值

添加到 archive 的、包装 fntf.function

示例

当模型具有单个输入参数时,使用 input_signature 参数添加 endpoint(端点)

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)

当模型具有两个位置输入参数时,使用 input_signature 参数添加 endpoint(端点)

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[
        keras.InputSpec(shape=(None, 3), dtype="float32"),
        keras.InputSpec(shape=(None, 4), dtype="float32"),
    ],
)

当模型具有一个输入参数,该参数是包含 2 个张量的列表时(例如具有 2 个输入的 Functional 模型),使用 input_signature 参数添加 endpoint(端点)

model = keras.Model(inputs=[x1, x2], outputs=outputs)

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[
        [
            keras.InputSpec(shape=(None, 3), dtype="float32"),
            keras.InputSpec(shape=(None, 4), dtype="float32"),
        ],
    ],
)

这也适用于字典输入

model = keras.Model(inputs={"x1": x1, "x2": x2}, outputs=outputs)

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[
        {
            "x1": keras.InputSpec(shape=(None, 3), dtype="float32"),
            "x2": keras.InputSpec(shape=(None, 4), dtype="float32"),
        },
    ],
)

添加一个作为 tf.function 的 endpoint(端点)

@tf.function()
def serving_fn(x):
    return model(x)

# The function must be traced, i.e. it must be called at least once.
serving_fn(tf.random.normal(shape=(2, 3)))

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(name="serve", fn=serving_fn)

将模型与可以使用 TensorFlow 资源的某些 TensorFlow 预处理相结合

lookup_table = tf.lookup.StaticHashTable(initializer, default_value=0.0)

export_archive = ExportArchive()
model_fn = export_archive.track_and_add_endpoint(
    "model_fn",
    model,
    input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)],
)
export_archive.track(lookup_table)

@tf.function()
def serving_fn(x):
    x = lookup_table.lookup(x)
    return model_fn(x)

export_archive.add_endpoint(name="serve", fn=serving_fn)

[source]

add_variable_collection 方法

ExportArchive.add_variable_collection(name, variables)

注册一组在重新加载后要检索的变量。

参数

  • name:集合的字符串名称。
  • variables:一个 keras.Variable 实例的 tuple/list/set。

示例

export_archive = ExportArchive()
export_archive.track(model)
# Register an endpoint
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[keras.InputSpec(shape=(None, 3), dtype="float32")],
)
# Save a variable collection
export_archive.add_variable_collection(
    name="optimizer_variables", variables=model.optimizer.variables)
export_archive.write_out("path/to/location")

# Reload the object
revived_object = tf.saved_model.load("path/to/location")
# Retrieve the variables
optimizer_variables = revived_object.optimizer_variables

[source]

track 方法

ExportArchive.track(resource)

跟踪变量(层或模型的)和其他资产。

默认情况下,当您调用 add_endpoint() 时,endpoint(端点)函数使用的所有变量都会自动跟踪。但是,非变量资产(例如查找表)需要手动跟踪。请注意,内置 Keras 层(TextVectorizationIntegerLookupStringLookup)使用的查找表由 add_endpoint() 自动跟踪。

参数

  • resource:一个层、模型或 TensorFlow 可跟踪资源。

[source]

write_out 方法

ExportArchive.write_out(filepath, options=None, verbose=True)

将相应的 SavedModel 写入磁盘。

参数

  • filepathstrpathlib.Path 对象。保存 artifact(工件)的路径。
  • optionstf.saved_model.SaveOptions 对象,指定 SavedModel 保存选项。
  • verbose:是否打印导出的 SavedModel 的所有变量。

关于 TF-Serving 的说明:通过 add_endpoint() 注册的所有 endpoint(端点)在 SavedModel artifact(工件)中对于 TF-Serving 都是可见的。此外,注册的第一个 endpoint(端点)将以别名 "serving_default" 显示(除非已手动注册了名为 "serving_default" 的 endpoint(端点)),因为 TF-Serving 需要设置此 endpoint(端点)。