DTypePolicy
类keras.dtype_policies.DTypePolicy(name=None)
Keras 层的数据类型策略。
数据类型策略决定了层的计算和变量数据类型。每个层都有一个策略。策略可以传递给层构造函数的 dtype
参数,或者使用 keras.config.set_dtype_policy
设置全局策略。
参数
"float32"
或 "float64"
,这将导致计算和变量数据类型都为该数据类型。也可以是字符串 "mixed_float16"
或 "mixed_bfloat16"
,这将导致计算数据类型为 float16
或 bfloat16
,变量数据类型为 float32
。通常,只有在使用混合精度时才需要与数据类型策略交互,混合精度是指使用 float16 或 bfloat16 进行计算,使用 float32 进行变量存储。这就是 API 名称中出现 mixed_precision
的原因。可以通过将 "mixed_float16"
或 "mixed_bfloat16"
传递给 keras.mixed_precision.set_dtype_policy()
来启用混合精度。
>>> keras.config.set_dtype_policy("mixed_float16")
>>> layer1 = keras.layers.Dense(10)
>>> layer1.dtype_policy # layer1 will automatically use mixed precision
<DTypePolicy "mixed_float16">
>>> # Can optionally override layer to use float32
>>> # instead of mixed precision.
>>> layer2 = keras.layers.Dense(10, dtype="float32")
>>> layer2.dtype_policy
<DTypePolicy "float32">
>>> # Set policy back to initial float32.
>>> keras.config.set_dtype_policy('float32')
在上面的示例中,将 dtype="float32"
传递给层等效于传递 dtype=keras.config.DTypePolicy("float32")
。通常,将数据类型策略名称传递给层等效于传递相应的策略,因此无需显式构造 DTypePolicy
对象。
DTypePolicyMap
类keras.dtype_policies.DTypePolicyMap(default_policy=None, policy_map=None)
将层路径映射到 DTypePolicy
实例的类似字典的对象。
DTypePolicyMap
可用于层及其子类中的 get_config
以支持复杂的数据类型策略配置。
例如,我们可以修改 layers.MultiHeadAttention
中的 get_config
如下,以支持数据类型策略的混合,例如量化。
@keras.saving.register_keras_serializable("MyPackage")
class MyMultiHeadAttention(keras.layers.MultiHeadAttention):
def get_config(self):
config = super().get_config()
dtype_policy_map = dtype_policies.DTypePolicyMap()
for layer in self._flatten_layers():
if layer.dtype_policy.quantization_mode is not None:
dtype_policy_map[layer.path] = layer.dtype_policy
if len(dtype_policy_map) > 0:
config.update({"dtype": dtype_policy_map})
return config
在内部,DTypePolicyMap
使用字符串作为键,使用 DTypePolicy
作为值。通常,用于查询的键是 Layer.path
。但是,也可以将正则表达式设置为键。有关更多详细信息,请参阅 get
的文档字符串。
请参见下面的用法示例。您可以定义 DTypePolicy
的命名方案,然后检索相应 DTypePolicy
实例。
dtype_policy_map = DTypePolicyMap()
dtype_policy_map["layer/dense_0"] = DTypePolicy("bfloat16")
dtype_policy_map["layer/dense_1"] = QuantizedDTypePolicy("int8", "bfloat16")
policy_0 = dtype_policy_map["layer/dense_0"]
policy_1 = dtype_policy_map["layer/dense_1"]
policy_2 = dtype_policy_map["layer/dense_2"] # No hit
assert policy_0 == DTypePolicy("bfloat16")
assert policy_1 == QuantizedDTypePolicy("int8", "bfloat16")
assert policy_2 == keras.config.dtype_policy()
参数
DTypePolicy
实例,指定默认数据类型策略。如果未指定,则该值将默认为 keras.config.dtype_policy()
。DTypePolicy
实例。默认为 None
FloatDTypePolicy
类keras.dtype_policies.FloatDTypePolicy(name=None)
Keras 层的数据类型策略。
数据类型策略决定了层的计算和变量数据类型。每个层都有一个策略。策略可以传递给层构造函数的 dtype
参数,或者使用 keras.config.set_dtype_policy
设置全局策略。
参数
"float32"
或 "float64"
,这将导致计算和变量数据类型都为该数据类型。也可以是字符串 "mixed_float16"
或 "mixed_bfloat16"
,这将导致计算数据类型为 float16
或 bfloat16
,变量数据类型为 float32
。通常,只有在使用混合精度时才需要与数据类型策略交互,混合精度是指使用 float16 或 bfloat16 进行计算,使用 float32 进行变量存储。这就是 API 名称中出现 mixed_precision
的原因。可以通过将 "mixed_float16"
或 "mixed_bfloat16"
传递给 keras.mixed_precision.set_dtype_policy()
来启用混合精度。
>>> keras.config.set_dtype_policy("mixed_float16")
>>> layer1 = keras.layers.Dense(10)
>>> layer1.dtype_policy # layer1 will automatically use mixed precision
<DTypePolicy "mixed_float16">
>>> # Can optionally override layer to use float32
>>> # instead of mixed precision.
>>> layer2 = keras.layers.Dense(10, dtype="float32")
>>> layer2.dtype_policy
<DTypePolicy "float32">
>>> # Set policy back to initial float32.
>>> keras.config.set_dtype_policy('float32')
在上面的示例中,将 dtype="float32"
传递给层等效于传递 dtype=keras.config.DTypePolicy("float32")
。通常,将数据类型策略名称传递给层等效于传递相应的策略,因此无需显式构造 DTypePolicy
对象。
QuantizedDTypePolicy
类keras.dtype_policies.QuantizedDTypePolicy(mode, source_name=None)
Keras 层的数据类型策略。
数据类型策略决定了层的计算和变量数据类型。每个层都有一个策略。策略可以传递给层构造函数的 dtype
参数,或者使用 keras.config.set_dtype_policy
设置全局策略。
参数
"float32"
或 "float64"
,这将导致计算和变量数据类型都为该数据类型。也可以是字符串 "mixed_float16"
或 "mixed_bfloat16"
,这将导致计算数据类型为 float16
或 bfloat16
,变量数据类型为 float32
。通常,只有在使用混合精度时才需要与数据类型策略交互,混合精度是指使用 float16 或 bfloat16 进行计算,使用 float32 进行变量存储。这就是 API 名称中出现 mixed_precision
的原因。可以通过将 "mixed_float16"
或 "mixed_bfloat16"
传递给 keras.mixed_precision.set_dtype_policy()
来启用混合精度。
>>> keras.config.set_dtype_policy("mixed_float16")
>>> layer1 = keras.layers.Dense(10)
>>> layer1.dtype_policy # layer1 will automatically use mixed precision
<DTypePolicy "mixed_float16">
>>> # Can optionally override layer to use float32
>>> # instead of mixed precision.
>>> layer2 = keras.layers.Dense(10, dtype="float32")
>>> layer2.dtype_policy
<DTypePolicy "float32">
>>> # Set policy back to initial float32.
>>> keras.config.set_dtype_policy('float32')
在上面的示例中,将 dtype="float32"
传递给层等效于传递 dtype=keras.config.DTypePolicy("float32")
。通常,将数据类型策略名称传递给层等效于传递相应的策略,因此无需显式构造 DTypePolicy
对象。
QuantizedFloat8DTypePolicy
类keras.dtype_policies.QuantizedFloat8DTypePolicy(
mode, source_name=None, amax_history_length=1024
)
Keras 层的数据类型策略。
数据类型策略决定了层的计算和变量数据类型。每个层都有一个策略。策略可以传递给层构造函数的 dtype
参数,或者使用 keras.config.set_dtype_policy
设置全局策略。
参数
"float32"
或 "float64"
,这将导致计算和变量数据类型都为该数据类型。也可以是字符串 "mixed_float16"
或 "mixed_bfloat16"
,这将导致计算数据类型为 float16
或 bfloat16
,变量数据类型为 float32
。通常,只有在使用混合精度时才需要与数据类型策略交互,混合精度是指使用 float16 或 bfloat16 进行计算,使用 float32 进行变量存储。这就是 API 名称中出现 mixed_precision
的原因。可以通过将 "mixed_float16"
或 "mixed_bfloat16"
传递给 keras.mixed_precision.set_dtype_policy()
来启用混合精度。
>>> keras.config.set_dtype_policy("mixed_float16")
>>> layer1 = keras.layers.Dense(10)
>>> layer1.dtype_policy # layer1 will automatically use mixed precision
<DTypePolicy "mixed_float16">
>>> # Can optionally override layer to use float32
>>> # instead of mixed precision.
>>> layer2 = keras.layers.Dense(10, dtype="float32")
>>> layer2.dtype_policy
<DTypePolicy "float32">
>>> # Set policy back to initial float32.
>>> keras.config.set_dtype_policy('float32')
在上面的示例中,将 dtype="float32"
传递给层等效于传递 dtype=keras.config.DTypePolicy("float32")
。通常,将数据类型策略名称传递给层等效于传递相应的策略,因此无需显式构造 DTypePolicy
对象。
dtype_policy
函数keras.config.dtype_policy()
返回当前默认数据类型策略对象。
set_dtype_policy
函数keras.config.set_dtype_policy(policy)
全局设置默认数据类型策略。
示例
>>> keras.config.set_dtype_policy("mixed_float16")