export
方法Model.export(filepath, format="tf_saved_model", verbose=True)
创建用于推理的 TF SavedModel 工件。
注意:目前这只能与 TensorFlow 或 JAX 后端一起使用。
此方法允许您将模型导出到一个轻量级的 SavedModel 工件,该工件仅包含模型的前向传递(其 call()
方法),并且可以通过例如 TF-Serving 进行服务。前向传递在名称 serve()
下注册(请参见下面的示例)。
模型的原始代码(包括您可能使用过的任何自定义层)**不再**需要重新加载工件 - 它完全独立。
参数
str
或 pathlib.Path
对象。保存工件的路径。示例
# Create the artifact
model.export("path/to/location")
# Later, in a different process/environment...
reloaded_artifact = tf.saved_model.load("path/to/location")
predictions = reloaded_artifact.serve(input_data)
如果您想自定义服务端点,可以使用更底层的 keras.export.ExportArchive
类。export()
方法在内部依赖于 ExportArchive
。
ExportArchive
类keras.export.ExportArchive()
ExportArchive 用于写入 SavedModel 工件(例如,用于推理)。
如果您有一个要导出为 SavedModel 以供服务(例如,通过 TensorFlow-Serving)的 Keras 模型或层,您可以使用 ExportArchive
配置您需要提供的不同服务端点以及它们的签名。只需实例化一个 ExportArchive
,使用 track()
注册要使用的层或模型,然后使用 add_endpoint()
方法注册一个新的服务端点。完成后,使用 write_out()
方法保存工件。
生成的工件是 SavedModel,可以通过 tf.saved_model.load
重新加载。
示例
以下是如何导出用于推理的模型。
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")
# Elsewhere, we can reload the artifact and serve it.
# The endpoint we added is available as a method:
serving_model = tf.saved_model.load("path/to/location")
outputs = serving_model.serve(inputs)
以下是如何导出一个模型,该模型有一个用于推理的端点和一个用于训练模式前向传递的端点(例如,带有 dropout)。
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="call_inference",
fn=lambda x: model.call(x, training=False),
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.add_endpoint(
name="call_training",
fn=lambda x: model.call(x, training=True),
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")
关于资源跟踪的说明
ExportArchive
能够自动跟踪其端点使用的所有 tf.Variables
,因此大多数情况下,调用 .track(model)
并不是严格必需的。但是,如果您的模型使用查找层,例如 IntegerLookup
、StringLookup
或 TextVectorization
,则需要通过 .track(model)
明确跟踪它们。
如果您需要能够在恢复的存档上访问属性 variables
、trainable_variables
或 non_trainable_variables
,也需要进行显式跟踪。
add_endpoint
方法ExportArchive.add_endpoint(name, fn, input_signature=None, jax2tf_kwargs=None)
注册一个新的服务端点。
参数
ExportArchive
跟踪的模型/层上可用的资源(例如 tf.Variable
对象或 tf.lookup.StaticHashTable
对象)。函数的输入的形状和数据类型必须已知。为此,您可以:1)确保 fn
是一个至少调用过一次的 tf.function
,或者 2)提供一个 input_signature
参数来指定输入的形状和数据类型(请参见下文)。fn
输入的形状和数据类型。 tf.TensorSpec
对象的列表(每个 fn
的位置输入参数一个)。允许嵌套参数(请参见下面的示例,该示例显示了一个具有 2 个输入参数的功能模型)。jax2tf
的参数的字典。仅当后端为 JAX 时才支持。请参阅 jax2tf.convert
的文档。如果未提供 native_serialization
和 polymorphic_shapes
的值,则会自动计算它们。返回值
添加到存档中的包装 fn
的 tf.function
。
示例
当模型具有单个输入参数时,使用 input_signature
参数添加端点
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
当模型具有两个位置输入参数时,使用 input_signature
参数添加端点
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[
tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
],
)
当模型具有一个输入参数(该参数是 2 个张量的列表)(例如,具有 2 个输入的功能模型)时,使用 input_signature
参数添加端点
model = keras.Model(inputs=[x1, x2], outputs=outputs)
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[
[
tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
],
],
)
这也适用于字典输入
model = keras.Model(inputs={"x1": x1, "x2": x2}, outputs=outputs)
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[
{
"x1": tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
"x2": tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
},
],
)
添加一个 tf.function
端点
@tf.function()
def serving_fn(x):
return model(x)
# The function must be traced, i.e. it must be called at least once.
serving_fn(tf.random.normal(shape=(2, 3)))
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(name="serve", fn=serving_fn)
add_variable_collection
方法ExportArchive.add_variable_collection(name, variables)
注册一组变量,以便在重新加载后检索。
参数
tf.Variable
实例的元组/列表/集合。示例
export_archive = ExportArchive()
export_archive.track(model)
# Register an endpoint
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
# Save a variable collection
export_archive.add_variable_collection(
name="optimizer_variables", variables=model.optimizer.variables)
export_archive.write_out("path/to/location")
# Reload the object
revived_object = tf.saved_model.load("path/to/location")
# Retrieve the variables
optimizer_variables = revived_object.optimizer_variables
track
方法ExportArchive.track(resource)
跟踪层或模型的变量(和其他资产)。
默认情况下,当您调用 add_endpoint()
时,端点函数使用的所有变量都会自动跟踪。但是,查找表等非变量资产需要手动跟踪。请注意,内置 Keras 层(TextVectorization
、IntegerLookup
、StringLookup
)使用的查找表在 add_endpoint()
中会自动跟踪。
参数
write_out
方法ExportArchive.write_out(filepath, options=None, verbose=True)
将相应的 SavedModel 写入磁盘。
参数
str
或 pathlib.Path
对象。保存工件的路径。tf.saved_model.SaveOptions
对象,指定 SavedModel 保存选项。关于 TF-Serving 的说明:通过 add_endpoint()
注册的所有端点在 SavedModel 工件中对 TF-Serving 可见。此外,第一个注册的端点在别名 "serving_default"
下可见(除非手动注册了名为 "serving_default"
的端点),因为 TF-Serving 要求设置此端点。