Author: Jyotinder Singh
Date created: 2025/10/16
Last modified: 2025/10/16
Description: Complete guide to writing quantization-compatible Keras layers.
Keras lets you optimize a model with post-training quantization (PTQ) by calling the `layer.quantize(...)` or `model.quantize(...)` APIs. Keras provides an extensible framework for defining quantization-compatible layers. This lets you create custom layers that integrate with the quantization framework, can be quantized to INT8 or INT4, and can save/load their quantization metadata.
A quantization-compatible layer implements a small set of hooks so that Keras can quantize its variables, run a quantized forward path, and serialize the result.
In this guide, we will implement a simple layer that supports INT8 PTQ. The same pattern generalizes to INT4 quantization and FP8 mixed-precision training.
At a minimum, your layer should define:

- `quantize(mode, **kwargs)`: converts the existing variables to their quantized form and switches the dtype policy
- `_int8_build(...)`: allocates the INT8 variables the layer needs
- `_int8_call(inputs, training=None)`: the minimal INT8 forward path

We will implement these for a very small layer called `SimpleScale`, which multiplies its input by a trainable per-feature vector (an element-wise scale over the last dimension); a skeleton of the hooks follows below. The same pattern generalizes to more complex layers.
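A minimal skeleton of these hooks, with signatures only (the class name here is illustrative; each body is implemented step by step in the rest of this guide):

from keras.layers import Layer


class QuantizationCompatibleLayer(Layer):
    def quantize(self, mode, **kwargs):
        # Convert the float variables to quantized ones and switch
        # the dtype policy.
        ...

    def _int8_build(self, kernel_shape):
        # Allocate the INT8 variables (quantized kernel plus scale).
        ...

    def _int8_call(self, inputs, training=None):
        # Quantized forward path; dequantizes the output back to float.
        ...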
We start with a tiny layer that learns a per-feature multiplier. The full-precision path simply computes `y = x * w`. We will add the quantization hooks step by step.
import numpy as np
import keras
from keras import ops, quantizers, dtype_policies
from keras.layers import Layer, Input
class SimpleScale(Layer):
"""A layer that learns a per-feature scaling factor."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def build(self, input_shape):
input_dim = input_shape[-1]
self._kernel = self.add_weight(
name="kernel",
shape=(input_dim,),
initializer="random_uniform",
)
def call(self, inputs, training=None):
return ops.multiply(inputs, self._kernel)
PTQ is a one-shot rewrite. After training or loading the FP32 layer, call `layer.quantize("int8")`. The layer should quantize its float kernel, replace it with INT8 variables, and switch its dtype policy:
def quantize(self, mode, **kwargs):
if mode != "int8":
raise NotImplementedError(f"Unsupported quantization mode: {mode}")
quantized_kernel, scale = quantizers.abs_max_quantize(
self._kernel, axis=0, dtype="int8", to_numpy=True
)
scale = ops.squeeze(scale, axis=0)
kernel_shape = self._kernel.shape
del self._kernel
# Allocate INT8 variables. Discussed in the next section.
self._int8_build(kernel_shape)
self._kernel.assign(quantized_kernel)
self.scale.assign(scale)
# `_is_quantized` should be set before changing dtype policy to inform
# the setter that quantized variables are initialized.
self._is_quantized = True
if self.dtype_policy.quantization_mode is None:
policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
self.dtype_policy = policy
Make sure your `quantize(...)` sets a quantized dtype policy derived from the previous policy, e.g. `int8_from_float32` or `int8_from_bfloat16`. This ensures the layer's `quantization_mode` is set correctly.
The `_is_quantized` flag should be set before changing the dtype policy, to inform the setter that the quantized variables have been initialized.
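As a quick illustration, the policy string built in `quantize(...)` resolves to a quantized dtype policy whose `quantization_mode` is set (a standalone check, assuming the default `float32` policy):

# Resolve a quantized policy the same way `quantize(...)` does.
policy = dtype_policies.get("int8_from_float32")
print(policy.name)               # int8_from_float32
print(policy.quantization_mode)  # int8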
The `_int8_build(...)` method is called from `quantize(...)` to initialize the INT8 variables. It should allocate an INT8 kernel with the same shape as the float kernel, plus a scalar `scale` used to dequantize the outputs:
def _int8_build(self, kernel_shape):
(input_dim,) = kernel_shape
self._kernel = self.add_weight(
name="kernel",
shape=(input_dim,),
initializer="zeros",
dtype="int8",
trainable=False,
)
self.scale = self.add_weight(
name="scale",
initializer="ones",
trainable=False,
)
The `_int8_call(...)` method implements the minimal INT8 forward path. It uses the quantized variables allocated in `_int8_build(...)` and dequantizes the output back to floating point.

Once the layer is quantized, the base `keras.Layer` class routes calls to this method automatically; no manual wiring is needed.
The INT8 path mirrors the float computation `y = x * w`, but multiplies by the quantized kernel and then divides by `scale` to dequantize:
def _int8_call(self, inputs, training=None):
x = ops.multiply(inputs, self._kernel)
x = ops.divide(x, self.scale)
return x
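Why dividing by `scale` works: abs-max quantization stores `w_q = round(w * scale)`, so `w ≈ w_q / scale`. A quick standalone round-trip check (not part of the layer):

# Round-trip sanity check of the dequantization identity w ≈ w_q / scale.
w = np.array([0.5, -1.0, 0.25], dtype="float32")
w_q, s = quantizers.abs_max_quantize(w, axis=0, dtype="int8", to_numpy=True)
s = np.squeeze(s, axis=0)
print(w_q)      # int8 values in [-127, 127]
print(w_q / s)  # approximately [0.5, -1.0, 0.25]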
Below is the full class definition, including all of the hooks discussed above (`quantize`, `_int8_build`, `_int8_call`):
class SimpleScale(Layer):
"""A layer that learns a per-feature scaling factor."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def build(self, input_shape):
input_dim = input_shape[-1]
self._kernel = self.add_weight(
name="kernel",
shape=(input_dim,),
initializer="random_uniform",
)
def call(self, inputs, training=None):
return ops.multiply(inputs, self._kernel)
def quantize(self, mode, **kwargs):
if mode != "int8":
raise NotImplementedError(f"Unsupported quantization mode: {mode}")
quantized_kernel, scale = quantizers.abs_max_quantize(
self._kernel, axis=0, dtype="int8", to_numpy=True
)
scale = ops.squeeze(scale, axis=0)
kernel_shape = self._kernel.shape
del self._kernel
self._int8_build(kernel_shape)
self._kernel.assign(quantized_kernel)
self.scale.assign(scale)
self._is_quantized = True
if self.dtype_policy.quantization_mode is None:
policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
self.dtype_policy = policy
def _int8_build(self, kernel_shape):
(input_dim,) = kernel_shape
self._kernel = self.add_weight(
name="kernel",
shape=(input_dim,),
initializer="zeros",
dtype="int8",
trainable=False,
)
self.scale = self.add_weight(
name="scale",
initializer="ones",
trainable=False,
)
def _int8_call(self, inputs, training=None):
x = ops.multiply(inputs, self._kernel)
x = ops.divide(x, self.scale)
return x
Below we build the layer, then quantize it to INT8 and call it again.
# Sample inputs
rng = np.random.default_rng()
x = rng.random((2, 4)).astype("float32")
layer = SimpleScale()
# Forward pass in float
y_fp = layer(x)
# Quantize to INT8 and run again
layer.quantize("int8")
y_int8 = layer(x)
print("SimpleScale FP32 sample:", y_fp[0].numpy())
print("SimpleScale INT8 sample:", y_int8[0].numpy())
SimpleScale FP32 sample: [-0.01259411 0.00385596 0.0053392 -0.00877095]
SimpleScale INT8 sample: [-0.01256325 0.0038252 0.00535317 -0.00877098]
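The INT8 outputs closely track the FP32 outputs. An optional way to quantify the error on this batch:

# Abs-max weight quantization keeps the per-weight relative error within
# about half a quantization step, so the outputs should agree closely.
print("max abs diff:", np.max(np.abs(y_fp.numpy() - y_int8.numpy())))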
If you want to support INT4 quantization, add:

- `_int4_build(...)`: allocates packed 4-bit storage (typically packed into int8) along with the per-feature scales
- `_int4_call(...)`: unpacks at runtime and follows the same dequantization pattern
- `quantize("int4")`: quantizes the weights with `value_range=(-8, 7)`, then packs them into the int4 storage

See the Dense reference for a complete packed int4 example, including how to track and use the original (unpacked) dimension in the call path. A hand-rolled sketch follows below.
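Here is a minimal sketch of what these hooks could look like for `SimpleScale`, assuming manual nibble packing (two int4 values per int8 byte) and a hypothetical `self._orig_input_dim` attribute that remembers the unpacked length; the built-in layers use Keras' packing utilities instead:

def _int4_build(self, kernel_shape):
    (input_dim,) = kernel_shape
    # Remember the unpacked length; two 4-bit values share one int8 byte.
    self._orig_input_dim = input_dim
    self._kernel = self.add_weight(
        name="kernel",
        shape=((input_dim + 1) // 2,),  # packed storage
        initializer="zeros",
        dtype="int8",
        trainable=False,
    )
    self.scale = self.add_weight(
        name="scale",
        initializer="ones",
        trainable=False,
    )

def _int4_call(self, inputs, training=None):
    # View each int8 byte as unsigned, split it into low/high nibbles,
    # then sign-extend each nibble back to a two's-complement int4 value.
    packed = ops.cast(self._kernel, "int32")
    packed = ops.where(packed < 0, packed + 256, packed)
    low = ops.mod(packed, 16)
    high = ops.floor_divide(packed, 16)
    low = ops.where(low >= 8, low - 16, low)
    high = ops.where(high >= 8, high - 16, high)
    # Interleave back to [low0, high0, low1, high1, ...] and drop padding.
    kernel = ops.reshape(ops.stack([low, high], axis=-1), (-1,))
    kernel = ops.cast(kernel[: self._orig_input_dim], inputs.dtype)
    x = ops.multiply(inputs, kernel)
    return ops.divide(x, self.scale)

Inside `quantize("int4")`, the packing itself could be done in NumPy after quantizing with `value_range=(-8, 7)`:

# Quantize to the 4-bit range, then pack pairs of values into bytes.
quantized_kernel, scale = quantizers.abs_max_quantize(
    self._kernel, axis=0, value_range=(-8, 7), dtype="int8", to_numpy=True
)
kernel = np.asarray(quantized_kernel).astype(np.uint8)
if kernel.size % 2:
    kernel = np.pad(kernel, (0, 1))  # pad to an even length
packed = (((kernel[1::2] & 0xF) << 4) | (kernel[0::2] & 0xF)).view(np.int8)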
Keras relies on a fixed serialization contract for saving and loading models. Quantization complicates this contract, because which variables need to be saved and loaded depends on the quantization mode.

The framework provides two hooks for layers to customize variable serialization:

- `save_own_variables(self, store)`: writes variables to `store` in a fixed order
- `load_own_variables(self, store)`: reads variables from `store` in the same order

In addition, the `build(...)` method should be modified to allocate the correct variables depending on the presence (or absence) of `self.quantization_mode`.
For this layer, we only intend to support two modes (unquantized and INT8), so the serialization contract is:

- `quantization_mode is None`: `_kernel` only
- `quantization_mode == "int8"`: `_kernel`, then `scale`
The following code implements the required hooks; Keras calls them during `model.save(...)` and `keras.saving.load_model(...)`.
def save_own_variables(self, store):
# Write variables to `store` in a fixed order based on quantization mode.
# `store` is a key-value mapping provided by Keras during model.save().
# Values are tensors.
if not self.built:
return
mode = self.quantization_mode
idx = 0
if mode is None:
# Order: _kernel
store[str(idx)] = self._kernel
elif mode == "int8":
# Order: _kernel, scale
store[str(idx)] = self._kernel
idx += 1
store[str(idx)] = self.scale
else:
raise ValueError(f"Unsupported quantization mode for save: {mode}")
def load_own_variables(self, store):
# Read variables from `store` in the same order used by
# `save_own_variables`. Keras calls this during
# `keras.saving.load_model(...)`.
if not self.built:
return
mode = self.quantization_mode
idx = 0
if mode is None:
self._kernel.assign(store[str(idx)])
elif mode == "int8":
self._kernel.assign(store[str(idx)])
idx += 1
self.scale.assign(store[str(idx)])
else:
raise ValueError(f"Unsupported quantization mode for load: {mode}")
The `build` method itself also needs to be aware of the quantization mode. When a saved, quantized layer is being loaded/deserialized, `self.quantization_mode` is set before `build(...)` is called. In that case, we need to allocate the quantized variables directly instead of the full-precision ones.
def build(self, input_shape):
input_dim = input_shape[-1]
# Quantized build path.
if self.quantization_mode:
if self.quantization_mode == "int8":
self._int8_build((input_dim,))
else:
# Regular FP32 build path.
self._kernel = self.add_weight(
name="kernel",
shape=(input_dim,),
initializer="random_uniform",
)
The full class with serialization support looks like this:
@keras.saving.register_keras_serializable()
class SimpleScale(Layer):
"""A layer that learns a per-feature scaling factor."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def build(self, input_shape):
input_dim = input_shape[-1]
if self.quantization_mode:
if self.quantization_mode == "int8":
self._int8_build((input_dim,))
else:
self._kernel = self.add_weight(
name="kernel",
shape=(input_dim,),
initializer="random_uniform",
)
def call(self, inputs, training=None):
return ops.multiply(inputs, self._kernel)
def quantize(self, mode, **kwargs):
if mode != "int8":
raise NotImplementedError(f"Unsupported quantization mode: {mode}")
quantized_kernel, scale = quantizers.abs_max_quantize(
self._kernel, axis=0, dtype="int8", to_numpy=True
)
scale = ops.squeeze(scale, axis=0)
kernel_shape = self._kernel.shape
del self._kernel
self._int8_build(kernel_shape)
self._kernel.assign(quantized_kernel)
self.scale.assign(scale)
self._is_quantized = True
if self.dtype_policy.quantization_mode is None:
policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
self.dtype_policy = policy
def _int8_build(self, kernel_shape):
(input_dim,) = kernel_shape
self._kernel = self.add_weight(
name="kernel",
shape=(input_dim,),
initializer="zeros",
dtype="int8",
trainable=False,
)
self.scale = self.add_weight(
name="scale",
initializer="ones",
trainable=False,
)
def _int8_call(self, inputs, training=None):
x = ops.multiply(inputs, self._kernel)
x = ops.divide(x, self.scale)
return x
def save_own_variables(self, store):
# Write variables to `store` in a fixed order based on quantization mode.
# `store` is a key-value mapping provided by Keras during model.save(); values are tensors.
if not self.built:
return
mode = self.quantization_mode
idx = 0
if mode is None:
# Order: _kernel
store[str(idx)] = self._kernel
elif mode == "int8":
# Order: _kernel, scale
store[str(idx)] = self._kernel
idx += 1
store[str(idx)] = self.scale
else:
raise ValueError(f"Unsupported quantization mode for save: {mode}")
def load_own_variables(self, store):
# Read variables from `store` in the same order used by `save_own_variables`.
# Keras calls this during `keras.saving.load_model(...)`.
if not self.built:
return
mode = self.quantization_mode
idx = 0
if mode is None:
self._kernel.assign(store[str(idx)])
elif mode == "int8":
self._kernel.assign(store[str(idx)])
idx += 1
self.scale.assign(store[str(idx)])
else:
raise ValueError(f"Unsupported quantization mode for load: {mode}")
The `@keras.saving.register_keras_serializable()` decorator is required to register the class for serialization, so that the saved model can be reloaded:
model = keras.Sequential([Input(shape=(4,)), SimpleScale()])
model.build((None, 4))
# Quantize to INT8.
model.quantize("int8")
y_int8 = model(x)
print("SimpleScale INT8 sample:", y_int8[0].numpy())
# Save and load the quantized model.
model.save("simplescale_int8.keras")
loaded = keras.saving.load_model("simplescale_int8.keras")
y_loaded = loaded(x)
print("Loaded INT8 sample:", y_loaded[0].numpy())
SimpleScale INT8 sample: [ 0.01568618 -0.00546078 0.00163636 0.00331613]
Loaded INT8 sample: [ 0.01568618 -0.00546078 0.00163636 0.00331613]
To recap, this is the concrete pattern you can reuse to make your own layers PTQ-friendly:

- `quantize(mode, **kwargs)`: convert the float variables to quantized ones and switch the dtype policy
- `_int8_build(...)` / `_int8_call(...)`: allocate the quantized variables and implement the quantized forward path
- `save_own_variables(...)` / `load_own_variables(...)`: read and write variables in a fixed, mode-dependent order
- a `build(...)` that checks `self.quantization_mode`, so that deserialized quantized layers allocate the right variables
For more advanced patterns, refer to the `Dense` and `EinsumDense` reference implementations.