Keras 3 API 文档 / 混合精度 / 混合精度策略 API

混合精度策略 API

[源代码]

DTypePolicy

keras.dtype_policies.DTypePolicy(name=None)

Keras 层的数据类型策略。

数据类型策略决定了层的计算和变量数据类型。每个层都有一个策略。策略可以传递给层构造函数的 dtype 参数,或者使用 keras.config.set_dtype_policy 设置全局策略。

参数

  • name: 策略名称,它决定了计算和变量的数据类型。可以是任何数据类型名称,例如 "float32""float64",这将导致计算和变量数据类型都为该数据类型。也可以是字符串 "mixed_float16""mixed_bfloat16",这将导致计算数据类型为 float16bfloat16,变量数据类型为 float32

通常,只有在使用混合精度时才需要与数据类型策略交互,混合精度是指使用 float16 或 bfloat16 进行计算,使用 float32 进行变量存储。这就是 API 名称中出现 mixed_precision 的原因。可以通过将 "mixed_float16""mixed_bfloat16" 传递给 keras.mixed_precision.set_dtype_policy() 来启用混合精度。

>>> keras.config.set_dtype_policy("mixed_float16")
>>> layer1 = keras.layers.Dense(10)
>>> layer1.dtype_policy  # layer1 will automatically use mixed precision
<DTypePolicy "mixed_float16">
>>> # Can optionally override layer to use float32
>>> # instead of mixed precision.
>>> layer2 = keras.layers.Dense(10, dtype="float32")
>>> layer2.dtype_policy
<DTypePolicy "float32">
>>> # Set policy back to initial float32.
>>> keras.config.set_dtype_policy('float32')

在上面的示例中,将 dtype="float32" 传递给层等效于传递 dtype=keras.config.DTypePolicy("float32")。通常,将数据类型策略名称传递给层等效于传递相应的策略,因此无需显式构造 DTypePolicy 对象。


[源代码]

DTypePolicyMap

keras.dtype_policies.DTypePolicyMap(default_policy=None, policy_map=None)

将层路径映射到 DTypePolicy 实例的类似字典的对象。

DTypePolicyMap 可用于层及其子类中的 get_config 以支持复杂的数据类型策略配置。

例如,我们可以修改 layers.MultiHeadAttention 中的 get_config 如下,以支持数据类型策略的混合,例如量化。

@keras.saving.register_keras_serializable("MyPackage")
class MyMultiHeadAttention(keras.layers.MultiHeadAttention):
    def get_config(self):
        config = super().get_config()
        dtype_policy_map = dtype_policies.DTypePolicyMap()
        for layer in self._flatten_layers():
            if layer.dtype_policy.quantization_mode is not None:
                dtype_policy_map[layer.path] = layer.dtype_policy
        if len(dtype_policy_map) > 0:
            config.update({"dtype": dtype_policy_map})
        return config

在内部,DTypePolicyMap 使用字符串作为键,使用 DTypePolicy 作为值。通常,用于查询的键是 Layer.path。但是,也可以将正则表达式设置为键。有关更多详细信息,请参阅 get 的文档字符串。

请参见下面的用法示例。您可以定义 DTypePolicy 的命名方案,然后检索相应 DTypePolicy 实例。

dtype_policy_map = DTypePolicyMap()
dtype_policy_map["layer/dense_0"] = DTypePolicy("bfloat16")
dtype_policy_map["layer/dense_1"] = QuantizedDTypePolicy("int8", "bfloat16")

policy_0 = dtype_policy_map["layer/dense_0"]
policy_1 = dtype_policy_map["layer/dense_1"]
policy_2 = dtype_policy_map["layer/dense_2"]  # No hit
assert policy_0 == DTypePolicy("bfloat16")
assert policy_1 == QuantizedDTypePolicy("int8", "bfloat16")
assert policy_2 == keras.config.dtype_policy()

参数

  • default_policy: 可选的 DTypePolicy 实例,指定默认数据类型策略。如果未指定,则该值将默认为 keras.config.dtype_policy()
  • policy_map: 可选的字典,将字符串映射到 DTypePolicy 实例。默认为 None

[源代码]

FloatDTypePolicy

keras.dtype_policies.FloatDTypePolicy(name=None)

Keras 层的数据类型策略。

数据类型策略决定了层的计算和变量数据类型。每个层都有一个策略。策略可以传递给层构造函数的 dtype 参数,或者使用 keras.config.set_dtype_policy 设置全局策略。

参数

  • name: 策略名称,它决定了计算和变量的数据类型。可以是任何数据类型名称,例如 "float32""float64",这将导致计算和变量数据类型都为该数据类型。也可以是字符串 "mixed_float16""mixed_bfloat16",这将导致计算数据类型为 float16bfloat16,变量数据类型为 float32

通常,只有在使用混合精度时才需要与数据类型策略交互,混合精度是指使用 float16 或 bfloat16 进行计算,使用 float32 进行变量存储。这就是 API 名称中出现 mixed_precision 的原因。可以通过将 "mixed_float16""mixed_bfloat16" 传递给 keras.mixed_precision.set_dtype_policy() 来启用混合精度。

>>> keras.config.set_dtype_policy("mixed_float16")
>>> layer1 = keras.layers.Dense(10)
>>> layer1.dtype_policy  # layer1 will automatically use mixed precision
<DTypePolicy "mixed_float16">
>>> # Can optionally override layer to use float32
>>> # instead of mixed precision.
>>> layer2 = keras.layers.Dense(10, dtype="float32")
>>> layer2.dtype_policy
<DTypePolicy "float32">
>>> # Set policy back to initial float32.
>>> keras.config.set_dtype_policy('float32')

在上面的示例中,将 dtype="float32" 传递给层等效于传递 dtype=keras.config.DTypePolicy("float32")。通常,将数据类型策略名称传递给层等效于传递相应的策略,因此无需显式构造 DTypePolicy 对象。


[源代码]

QuantizedDTypePolicy

keras.dtype_policies.QuantizedDTypePolicy(mode, source_name=None)

Keras 层的数据类型策略。

数据类型策略决定了层的计算和变量数据类型。每个层都有一个策略。策略可以传递给层构造函数的 dtype 参数,或者使用 keras.config.set_dtype_policy 设置全局策略。

参数

  • name: 策略名称,它决定了计算和变量的数据类型。可以是任何数据类型名称,例如 "float32""float64",这将导致计算和变量数据类型都为该数据类型。也可以是字符串 "mixed_float16""mixed_bfloat16",这将导致计算数据类型为 float16bfloat16,变量数据类型为 float32

通常,只有在使用混合精度时才需要与数据类型策略交互,混合精度是指使用 float16 或 bfloat16 进行计算,使用 float32 进行变量存储。这就是 API 名称中出现 mixed_precision 的原因。可以通过将 "mixed_float16""mixed_bfloat16" 传递给 keras.mixed_precision.set_dtype_policy() 来启用混合精度。

>>> keras.config.set_dtype_policy("mixed_float16")
>>> layer1 = keras.layers.Dense(10)
>>> layer1.dtype_policy  # layer1 will automatically use mixed precision
<DTypePolicy "mixed_float16">
>>> # Can optionally override layer to use float32
>>> # instead of mixed precision.
>>> layer2 = keras.layers.Dense(10, dtype="float32")
>>> layer2.dtype_policy
<DTypePolicy "float32">
>>> # Set policy back to initial float32.
>>> keras.config.set_dtype_policy('float32')

在上面的示例中,将 dtype="float32" 传递给层等效于传递 dtype=keras.config.DTypePolicy("float32")。通常,将数据类型策略名称传递给层等效于传递相应的策略,因此无需显式构造 DTypePolicy 对象。


[源代码]

QuantizedFloat8DTypePolicy

keras.dtype_policies.QuantizedFloat8DTypePolicy(
    mode, source_name=None, amax_history_length=1024
)

Keras 层的数据类型策略。

数据类型策略决定了层的计算和变量数据类型。每个层都有一个策略。策略可以传递给层构造函数的 dtype 参数,或者使用 keras.config.set_dtype_policy 设置全局策略。

参数

  • name: 策略名称,它决定了计算和变量的数据类型。可以是任何数据类型名称,例如 "float32""float64",这将导致计算和变量数据类型都为该数据类型。也可以是字符串 "mixed_float16""mixed_bfloat16",这将导致计算数据类型为 float16bfloat16,变量数据类型为 float32

通常,只有在使用混合精度时才需要与数据类型策略交互,混合精度是指使用 float16 或 bfloat16 进行计算,使用 float32 进行变量存储。这就是 API 名称中出现 mixed_precision 的原因。可以通过将 "mixed_float16""mixed_bfloat16" 传递给 keras.mixed_precision.set_dtype_policy() 来启用混合精度。

>>> keras.config.set_dtype_policy("mixed_float16")
>>> layer1 = keras.layers.Dense(10)
>>> layer1.dtype_policy  # layer1 will automatically use mixed precision
<DTypePolicy "mixed_float16">
>>> # Can optionally override layer to use float32
>>> # instead of mixed precision.
>>> layer2 = keras.layers.Dense(10, dtype="float32")
>>> layer2.dtype_policy
<DTypePolicy "float32">
>>> # Set policy back to initial float32.
>>> keras.config.set_dtype_policy('float32')

在上面的示例中,将 dtype="float32" 传递给层等效于传递 dtype=keras.config.DTypePolicy("float32")。通常,将数据类型策略名称传递给层等效于传递相应的策略,因此无需显式构造 DTypePolicy 对象。


[源代码]

dtype_policy 函数

keras.config.dtype_policy()

返回当前默认数据类型策略对象。


[源代码]

set_dtype_policy 函数

keras.config.set_dtype_policy(policy)

全局设置默认数据类型策略。

示例

>>> keras.config.set_dtype_policy("mixed_float16")