Keras 3 API文档 / Mixed precision / Mixed precision policy API

混合精度策略 API

[源代码]

DTypePolicy

keras.dtype_policies.DTypePolicy(name=None)

Keras 层的 dtype policy。

dtype policy 决定了层的计算和变量 dtypes。每个层都有一个 policy。Policy 可以通过层的构造函数的 dtype 参数传递,也可以通过 keras.config.set_dtype_policy 函数设置全局 policy。

参数

  • name: policy 名称,它决定了计算和变量 dtypes。可以是任何 dtype 名称,例如 "float32""float64",这将导致计算和变量 dtypes 都为该 dtype。也可以是字符串 "mixed_float16""mixed_bfloat16",这将导致计算 dtype 为 float16bfloat16,变量 dtype 为 float32

通常,只有在使用混合精度时才需要与 dtype policies 进行交互,混合精度是指使用 float16 或 bfloat16 进行计算,而使用 float32 作为变量。这就是为什么 API 名称中会出现 mixed_precision 这个术语。可以通过将 "mixed_float16""mixed_bfloat16" 传递给 keras.mixed_precision.set_dtype_policy() 来启用混合精度。

>>> keras.config.set_dtype_policy("mixed_float16")
>>> layer1 = keras.layers.Dense(10)
>>> layer1.dtype_policy  # layer1 will automatically use mixed precision
<DTypePolicy "mixed_float16">
>>> # Can optionally override layer to use float32
>>> # instead of mixed precision.
>>> layer2 = keras.layers.Dense(10, dtype="float32")
>>> layer2.dtype_policy
<DTypePolicy "float32">
>>> # Set policy back to initial float32.
>>> keras.config.set_dtype_policy('float32')

在上面的示例中,将 dtype="float32" 传递给层等同于传递 dtype=keras.config.DTypePolicy("float32")。通常,将 dtype policy 名称传递给层等同于传递相应的 policy,因此没有必要显式构造 DTypePolicy 对象。


[源代码]

DTypePolicyMap

keras.dtype_policies.DTypePolicyMap(default_policy=None, policy_map=None)

类似字典的对象,将层路径映射到 DTypePolicy 实例。

DTypePolicyMap 可在层和子类的 get_config 中使用,以支持复杂的 dtype policies 配置。

例如,我们可以如下修改 layers.MultiHeadAttention 中的 get_config 来支持混合 dtype policies,例如量化。

@keras.saving.register_keras_serializable("MyPackage")
class MyMultiHeadAttention(keras.layers.MultiHeadAttention):
    def get_config(self):
        config = super().get_config()
        dtype_policy_map = dtype_policies.DTypePolicyMap()
        for layer in self._flatten_layers():
            if layer.dtype_policy.quantization_mode is not None:
                dtype_policy_map[layer.path] = layer.dtype_policy
        if len(dtype_policy_map) > 0:
            config.update({"dtype": dtype_policy_map})
        return config

内部,DTypePolicyMap 使用字符串作为键,使用 DTypePolicy 作为值。通常,用于查询的键是 Layer.path。但是,也可以将正则表达式设置为键。有关更多详细信息,请参阅 get 的 docstring。

参数

  • default_policy: 一个可选的 DTypePolicy 实例,指定默认的 dtype policy。如果未指定,则默认为 keras.config.dtype_policy()
  • policy_map: 一个可选的字典,将字符串映射到 DTypePolicy 实例。默认为 None

示例

```python
>>> from keras.src import dtype_policies
>>> bfloat16 = dtype_policies.DTypePolicy("bfloat16")
>>> float16 = dtype_policies.DTypePolicy("float16")
>>> float32 = dtype_policies.DTypePolicy("float32")
>>> policy_map = DTypePolicyMap(default_policy=float32)

使用精确路径和正则表达式模式设置 policies。

注意:“decoder”将仅匹配精确路径,而不匹配其子项。

>>> policy_map["encoder/layer_0/dense"] = bfloat16
>>> policy_map["encoder/.*"] = float16
>>> policy_map["decoder"] = bfloat16

1. 找到并直接返回精确匹配。

>>> policy_map["encoder/layer_0/dense"].name
'bfloat16'

2. 找到子层的正则表达式匹配。

它匹配“encoder/.*”模式。

>>> policy_map["encoder/attention/query"].name
'float16'

3. 不发生隐式前缀匹配。

“decoder/attention”不匹配键“decoder”。

返回默认 policy。

>>> policy_map["decoder/attention"].name
'float32'

4. 如果路径匹配多个模式,则会引发 ValueError。

>>> policy_map["encoder/attention/.*"] = bfloat16
__"encoder/attention/query" now matches two patterns__

:

__- "encoder/.*__

"

__- "encoder/attention/.*__

"

>>> try:
...     policy_map["encoder/attention/query"]
... except ValueError as e:
...     print(e)
Path 'encoder/attention/query' matches multiple dtype policy ..
----

<span style="float:right;">[[source]](https://github.com/keras-team/keras/tree/v3.12.0/keras/src/dtype_policies/dtype_policy.py#L207)</span>

### `FloatDTypePolicy` class


```python
keras.dtype_policies.FloatDTypePolicy(name=None)

Keras 层的 dtype policy。

dtype policy 决定了层的计算和变量 dtypes。每个层都有一个 policy。Policy 可以通过层的构造函数的 dtype 参数传递,也可以通过 keras.config.set_dtype_policy 函数设置全局 policy。

参数

  • name: policy 名称,它决定了计算和变量 dtypes。可以是任何 dtype 名称,例如 "float32""float64",这将导致计算和变量 dtypes 都为该 dtype。也可以是字符串 "mixed_float16""mixed_bfloat16",这将导致计算 dtype 为 float16bfloat16,变量 dtype 为 float32

通常,只有在使用混合精度时才需要与 dtype policies 进行交互,混合精度是指使用 float16 或 bfloat16 进行计算,而使用 float32 作为变量。这就是为什么 API 名称中会出现 mixed_precision 这个术语。可以通过将 "mixed_float16""mixed_bfloat16" 传递给 keras.mixed_precision.set_dtype_policy() 来启用混合精度。

>>> keras.config.set_dtype_policy("mixed_float16")
>>> layer1 = keras.layers.Dense(10)
>>> layer1.dtype_policy  # layer1 will automatically use mixed precision
<DTypePolicy "mixed_float16">
>>> # Can optionally override layer to use float32
>>> # instead of mixed precision.
>>> layer2 = keras.layers.Dense(10, dtype="float32")
>>> layer2.dtype_policy
<DTypePolicy "float32">
>>> # Set policy back to initial float32.
>>> keras.config.set_dtype_policy('float32')

在上面的示例中,将 dtype="float32" 传递给层等同于传递 dtype=keras.config.DTypePolicy("float32")。通常,将 dtype policy 名称传递给层等同于传递相应的 policy,因此没有必要显式构造 DTypePolicy 对象。


[源代码]

QuantizedDTypePolicy

keras.dtype_policies.QuantizedDTypePolicy(mode, source_name=None)

Keras 层的 dtype policy。

dtype policy 决定了层的计算和变量 dtypes。每个层都有一个 policy。Policy 可以通过层的构造函数的 dtype 参数传递,也可以通过 keras.config.set_dtype_policy 函数设置全局 policy。

参数

  • name: policy 名称,它决定了计算和变量 dtypes。可以是任何 dtype 名称,例如 "float32""float64",这将导致计算和变量 dtypes 都为该 dtype。也可以是字符串 "mixed_float16""mixed_bfloat16",这将导致计算 dtype 为 float16bfloat16,变量 dtype 为 float32

通常,只有在使用混合精度时才需要与 dtype policies 进行交互,混合精度是指使用 float16 或 bfloat16 进行计算,而使用 float32 作为变量。这就是为什么 API 名称中会出现 mixed_precision 这个术语。可以通过将 "mixed_float16""mixed_bfloat16" 传递给 keras.mixed_precision.set_dtype_policy() 来启用混合精度。

>>> keras.config.set_dtype_policy("mixed_float16")
>>> layer1 = keras.layers.Dense(10)
>>> layer1.dtype_policy  # layer1 will automatically use mixed precision
<DTypePolicy "mixed_float16">
>>> # Can optionally override layer to use float32
>>> # instead of mixed precision.
>>> layer2 = keras.layers.Dense(10, dtype="float32")
>>> layer2.dtype_policy
<DTypePolicy "float32">
>>> # Set policy back to initial float32.
>>> keras.config.set_dtype_policy('float32')

在上面的示例中,将 dtype="float32" 传递给层等同于传递 dtype=keras.config.DTypePolicy("float32")。通常,将 dtype policy 名称传递给层等同于传递相应的 policy,因此没有必要显式构造 DTypePolicy 对象。


[源代码]

QuantizedFloat8DTypePolicy

keras.dtype_policies.QuantizedFloat8DTypePolicy(
    mode, source_name=None, amax_history_length=1024
)

Keras 层的 dtype policy。

dtype policy 决定了层的计算和变量 dtypes。每个层都有一个 policy。Policy 可以通过层的构造函数的 dtype 参数传递,也可以通过 keras.config.set_dtype_policy 函数设置全局 policy。

参数

  • name: policy 名称,它决定了计算和变量 dtypes。可以是任何 dtype 名称,例如 "float32""float64",这将导致计算和变量 dtypes 都为该 dtype。也可以是字符串 "mixed_float16""mixed_bfloat16",这将导致计算 dtype 为 float16bfloat16,变量 dtype 为 float32

通常,只有在使用混合精度时才需要与 dtype policies 进行交互,混合精度是指使用 float16 或 bfloat16 进行计算,而使用 float32 作为变量。这就是为什么 API 名称中会出现 mixed_precision 这个术语。可以通过将 "mixed_float16""mixed_bfloat16" 传递给 keras.mixed_precision.set_dtype_policy() 来启用混合精度。

>>> keras.config.set_dtype_policy("mixed_float16")
>>> layer1 = keras.layers.Dense(10)
>>> layer1.dtype_policy  # layer1 will automatically use mixed precision
<DTypePolicy "mixed_float16">
>>> # Can optionally override layer to use float32
>>> # instead of mixed precision.
>>> layer2 = keras.layers.Dense(10, dtype="float32")
>>> layer2.dtype_policy
<DTypePolicy "float32">
>>> # Set policy back to initial float32.
>>> keras.config.set_dtype_policy('float32')

在上面的示例中,将 dtype="float32" 传递给层等同于传递 dtype=keras.config.DTypePolicy("float32")。通常,将 dtype policy 名称传递给层等同于传递相应的 policy,因此没有必要显式构造 DTypePolicy 对象。


[源代码]

dtype_policy 函数

keras.config.dtype_policy()

返回当前默认的 dtype policy 对象。


[源代码]

set_dtype_policy 函数

keras.config.set_dtype_policy(policy)

全局设置默认的 dtype policy。

示例

>>> keras.config.set_dtype_policy("mixed_float16")