Keras 3 API 文档 / 模型 API / 保存与序列化 / 用于推理的模型导出

用于推理的模型导出

[源代码]

export 方法

Model.export(filepath, format="tf_saved_model", verbose=True)

创建用于推理的 TF SavedModel 工件。

注意:目前这只能与 TensorFlow 或 JAX 后端一起使用。

此方法允许您将模型导出到一个轻量级的 SavedModel 工件,该工件仅包含模型的前向传递(其 call() 方法),并且可以通过例如 TF-Serving 进行服务。前向传递在名称 serve() 下注册(请参见下面的示例)。

模型的原始代码(包括您可能使用过的任何自定义层)**不再**需要重新加载工件 - 它完全独立。

参数

  • filepathstrpathlib.Path 对象。保存工件的路径。
  • verbose:是否打印导出模型的所有变量。

示例

# Create the artifact
model.export("path/to/location")

# Later, in a different process/environment...
reloaded_artifact = tf.saved_model.load("path/to/location")
predictions = reloaded_artifact.serve(input_data)

如果您想自定义服务端点,可以使用更底层的 keras.export.ExportArchive 类。export() 方法在内部依赖于 ExportArchive


[源代码]

ExportArchive

keras.export.ExportArchive()

ExportArchive 用于写入 SavedModel 工件(例如,用于推理)。

如果您有一个要导出为 SavedModel 以供服务(例如,通过 TensorFlow-Serving)的 Keras 模型或层,您可以使用 ExportArchive 配置您需要提供的不同服务端点以及它们的签名。只需实例化一个 ExportArchive,使用 track() 注册要使用的层或模型,然后使用 add_endpoint() 方法注册一个新的服务端点。完成后,使用 write_out() 方法保存工件。

生成的工件是 SavedModel,可以通过 tf.saved_model.load 重新加载。

示例

以下是如何导出用于推理的模型。

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")

# Elsewhere, we can reload the artifact and serve it.
# The endpoint we added is available as a method:
serving_model = tf.saved_model.load("path/to/location")
outputs = serving_model.serve(inputs)

以下是如何导出一个模型,该模型有一个用于推理的端点和一个用于训练模式前向传递的端点(例如,带有 dropout)。

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="call_inference",
    fn=lambda x: model.call(x, training=False),
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.add_endpoint(
    name="call_training",
    fn=lambda x: model.call(x, training=True),
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")

关于资源跟踪的说明

ExportArchive 能够自动跟踪其端点使用的所有 tf.Variables,因此大多数情况下,调用 .track(model) 并不是严格必需的。但是,如果您的模型使用查找层,例如 IntegerLookupStringLookupTextVectorization,则需要通过 .track(model) 明确跟踪它们。

如果您需要能够在恢复的存档上访问属性 variablestrainable_variablesnon_trainable_variables,也需要进行显式跟踪。


[源代码]

add_endpoint 方法

ExportArchive.add_endpoint(name, fn, input_signature=None, jax2tf_kwargs=None)

注册一个新的服务端点。

参数

  • name:字符串,端点的名称。
  • fn:一个函数。它只能利用 ExportArchive 跟踪的模型/层上可用的资源(例如 tf.Variable 对象或 tf.lookup.StaticHashTable 对象)。函数的输入的形状和数据类型必须已知。为此,您可以:1)确保 fn 是一个至少调用过一次的 tf.function,或者 2)提供一个 input_signature 参数来指定输入的形状和数据类型(请参见下文)。
  • input_signature:用于指定 fn 输入的形状和数据类型。 tf.TensorSpec 对象的列表(每个 fn 的位置输入参数一个)。允许嵌套参数(请参见下面的示例,该示例显示了一个具有 2 个输入参数的功能模型)。
  • jax2tf_kwargs:可选。传递给 jax2tf 的参数的字典。仅当后端为 JAX 时才支持。请参阅 jax2tf.convert 的文档。如果未提供 native_serializationpolymorphic_shapes 的值,则会自动计算它们。

返回值

添加到存档中的包装 fntf.function

示例

当模型具有单个输入参数时,使用 input_signature 参数添加端点

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)

当模型具有两个位置输入参数时,使用 input_signature 参数添加端点

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[
        tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
    ],
)

当模型具有一个输入参数(该参数是 2 个张量的列表)(例如,具有 2 个输入的功能模型)时,使用 input_signature 参数添加端点

model = keras.Model(inputs=[x1, x2], outputs=outputs)

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[
        [
            tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
        ],
    ],
)

这也适用于字典输入

model = keras.Model(inputs={"x1": x1, "x2": x2}, outputs=outputs)

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[
        {
            "x1": tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
            "x2": tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
        },
    ],
)

添加一个 tf.function 端点

@tf.function()
def serving_fn(x):
    return model(x)

# The function must be traced, i.e. it must be called at least once.
serving_fn(tf.random.normal(shape=(2, 3)))

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(name="serve", fn=serving_fn)

[源代码]

add_variable_collection 方法

ExportArchive.add_variable_collection(name, variables)

注册一组变量,以便在重新加载后检索。

参数

  • name:集合的字符串名称。
  • variablestf.Variable 实例的元组/列表/集合。

示例

export_archive = ExportArchive()
export_archive.track(model)
# Register an endpoint
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
# Save a variable collection
export_archive.add_variable_collection(
    name="optimizer_variables", variables=model.optimizer.variables)
export_archive.write_out("path/to/location")

# Reload the object
revived_object = tf.saved_model.load("path/to/location")
# Retrieve the variables
optimizer_variables = revived_object.optimizer_variables

[源代码]

track 方法

ExportArchive.track(resource)

跟踪层或模型的变量(和其他资产)。

默认情况下,当您调用 add_endpoint() 时,端点函数使用的所有变量都会自动跟踪。但是,查找表等非变量资产需要手动跟踪。请注意,内置 Keras 层(TextVectorizationIntegerLookupStringLookup)使用的查找表在 add_endpoint() 中会自动跟踪。

参数

  • resource:可跟踪的 TensorFlow 资源。

[源代码]

write_out 方法

ExportArchive.write_out(filepath, options=None, verbose=True)

将相应的 SavedModel 写入磁盘。

参数

  • filepathstrpathlib.Path 对象。保存工件的路径。
  • optionstf.saved_model.SaveOptions 对象,指定 SavedModel 保存选项。
  • verbose:是否打印导出 SavedModel 的所有变量。

关于 TF-Serving 的说明:通过 add_endpoint() 注册的所有端点在 SavedModel 工件中对 TF-Serving 可见。此外,第一个注册的端点在别名 "serving_default" 下可见(除非手动注册了名为 "serving_default" 的端点),因为 TF-Serving 要求设置此端点。