开发者指南 / 在 Keras 中编写量化兼容层

在 Keras 中编写量化兼容层

作者: Jyotinder Singh
创建日期 2025/10/16
最后修改日期 2025/10/16
描述: 编写量化兼容 Keras 层的完整指南。

在 Colab 中查看 GitHub 源代码


什么是量化兼容层?

Keras 允许您通过调用 `layer.quantize(...)` 或 `model.quantize(...)` API 来通过训练后量化 (PTQ) 优化模型。Keras 提供了一个可扩展的框架来定义量化兼容层。这允许您创建自定义层,这些层可以集成到量化框架中,可以量化为 INT8 或 INT4,并且可以保存/加载量化元数据。

量化兼容层需要实现几个钩子,以便它可以

  • 将其变量切换到量化表示。
  • 在推理时使用量化感知的前向路径。
  • 将量化元数据与模型一起保存和加载。

在本指南中,我们将实现一个简单的层,它支持 INT8 PTQ。相同的模式可以推广到 INT4 量化和 FP8 混合精度训练。


您将实现的钩子

至少,您的层应该定义

  • quantize(mode, **kwargs): 将现有变量转换为量化形式并切换数据类型策略
  • _int8_build(...): 分配层所需的 INT8 变量
  • _int8_call(inputs, training=None): 最小的 INT8 前向路径

我们将为名为 `SimpleScale` 的非常小的层实现这些,它通过一个可训练的每特征向量(最后一个维度的元素级缩放)来乘以输入。相同的模式可以推广到更复杂的层。


编写一个简单的量化兼容层

我们从一个学习每特征乘法的微小层开始。全精度路径只是计算 `y = x * w`。我们将逐步添加量化钩子。

import numpy as np
import keras
from keras import ops, quantizers, dtype_policies
from keras.layers import Layer, Input


class SimpleScale(Layer):
    """A layer that learns a per-feature scaling factor."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_dim,),
            initializer="random_uniform",
        )

    def call(self, inputs, training=None):
        return ops.multiply(inputs, self._kernel)

`quantize()` 方法

PTQ 是一次性重写。在训练或加载 FP32 层后,调用 `layer.quantize("int8")`。该层应该

  1. 读取其现有的全精度变量(例如 `self._kernel`)。
  2. 将它们量化为 INT8 值加上一个量化标度。
  3. 用 INT8 存储替换全精度变量并分配量化数据。
  4. 将 `dtype_policy` 切换到量化变体(例如 `int8_from_float32`)。
def quantize(self, mode, **kwargs):
    if mode != "int8":
        raise NotImplementedError(f"Unsupported quantization mode: {mode}")

    quantized_kernel, scale = quantizers.abs_max_quantize(
        self._kernel, axis=0, dtype="int8", to_numpy=True
    )
    scale = ops.squeeze(scale, axis=0)

    kernel_shape = self._kernel.shape

    del self._kernel

    # Allocate INT8 variables. Discussed in the next section.
    self._int8_build(kernel_shape)

    self._kernel.assign(quantized_kernel)
    self.scale.assign(scale)

    # `_is_quantized` should be set before changing dtype policy to inform
    # the setter that quantized variables are initialized.
    self._is_quantized = True

    if self.dtype_policy.quantization_mode is None:
        policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
        self.dtype_policy = policy

注意

  1. `quantize(...)` 方法应验证 `mode`,并在模式不受支持时引发 `NotImplementedError`。
  2. 确保您的 `quantize(...)` 根据先前的策略设置量化数据类型策略,例如 `int8_from_float32` 或 `int8_from_bfloat16`。这确保了层的 `quantization_mode` 被正确设置。

  3. 应在更改数据类型策略之前设置 `_is_quantized` 标志,以告知设置器已初始化量化变量。

`_int8_build(...)` 方法

此 `int8_build(...)` 方法从 `quantize(...)` 调用,用于初始化 INT8 变量。它应该分配

  • `self._kernel` 作为形状为 `(input_dim,)` 的 INT8 向量(与原始全精度核形状相同)。
  • `self.scale` 作为层变量数据类型中的标量量化标度,在这种情况下为 FP32。
def _int8_build(self, kernel_shape):
    (input_dim,) = kernel_shape
    self._kernel = self.add_weight(
        name="kernel",
        shape=(input_dim,),
        initializer="zeros",
        dtype="int8",
        trainable=False,
    )
    self.scale = self.add_weight(
        name="scale",
        initializer="ones",
        trainable=False,
    )

注意

  1. INT8 变量应创建为 `trainable=False`,因为量化参数不打算在训练期间更新。后续的微调不应更改这些量化变量。
  2. 如果您支持 INT4 量化,请实现类似的 `_int4_build(...)` 方法,该方法分配打包的 INT4 存储(通常打包到 INT8 中)以及每特征标度。原始解包的维度和打包轴应记录为实例变量,以供调用路径使用。Keras 的 Dense 层中提供了参考实现。

`_int8_call(...)` 方法

`_int8_call(...)` 方法实现了最小的 INT8 前向路径。它使用在 `_int8_build(...)` 中分配的量化变量,并将输出反量化回浮点数。

基础 `keras.Layer` 类会自动调用此方法,而无需手动连接。

INT8 路径镜像浮点计算 `y = x * w`,但执行

  1. 使用量化权重进行元素级乘法。
  2. 通过除以 `scale` 来反量化回浮点数。
def _int8_call(self, inputs, training=None):
    x = ops.multiply(inputs, self._kernel)
    x = ops.divide(x, self.scale)
    return x

包含钩子的完整 `SimpleScale` 类

下面是包含上述所有钩子(`quantize`、`_int8_build`、`_int8_call`)的完整类定义。

class SimpleScale(Layer):
    """A layer that learns a per-feature scaling factor."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_dim,),
            initializer="random_uniform",
        )

    def call(self, inputs, training=None):
        return ops.multiply(inputs, self._kernel)

    def quantize(self, mode, **kwargs):
        if mode != "int8":
            raise NotImplementedError(f"Unsupported quantization mode: {mode}")

        quantized_kernel, scale = quantizers.abs_max_quantize(
            self._kernel, axis=0, dtype="int8", to_numpy=True
        )
        scale = ops.squeeze(scale, axis=0)

        kernel_shape = self._kernel.shape

        del self._kernel

        self._int8_build(kernel_shape)

        self._kernel.assign(quantized_kernel)
        self.scale.assign(scale)

        self._is_quantized = True

        if self.dtype_policy.quantization_mode is None:
            policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
            self.dtype_policy = policy

    def _int8_build(self, kernel_shape):
        (input_dim,) = kernel_shape
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_dim,),
            initializer="zeros",
            dtype="int8",
            trainable=False,
        )
        self.scale = self.add_weight(
            name="scale",
            initializer="ones",
            trainable=False,
        )

    def _int8_call(self, inputs, training=None):
        x = ops.multiply(inputs, self._kernel)
        x = ops.divide(x, self.scale)
        return x

尝试一下:量化并运行前向传播

下面我们构建层,然后量化为 INT8 并再次调用它。

# Sample inputs
rng = np.random.default_rng()
x = rng.random((2, 4)).astype("float32")

layer = SimpleScale()

# Forward pass in float
y_fp = layer(x)

# Quantize to INT8 and run again
layer.quantize("int8")
y_int8 = layer(x)

print("SimpleScale FP32 sample:", y_fp[0].numpy())
print("SimpleScale INT8 sample:", y_int8[0].numpy())
SimpleScale FP32 sample: [-0.01259411  0.00385596  0.0053392  -0.00877095]
SimpleScale INT8 sample: [-0.01256325  0.0038252   0.00535317 -0.00877098]

扩展到 INT4

如果您想支持 INT4 量化,请添加

  • _int4_build(...): 分配一个打包的 4 位存储(通常打包到 int8 中)以及每特征标度
  • _int4_call(...): 在运行时解包并遵循相同的反量化模式
  • quantize("int4"): 使用 `value_range=(-8, 7)` 量化权重,然后打包到 int4 存储中

请参阅 Dense 参考,获取一个完整的打包 int4 示例,包括如何在调用路径中跟踪和使用原始(未打包)维度。


添加序列化支持

Keras 依赖于固定的序列化契约来保存和加载模型。量化使这个契约变得复杂,因为您需要保存和加载的变量取决于量化模式。

该框架为层提供了两个钩子来自定义变量序列化

  • save_own_variables(self, store): 以固定顺序将变量写入 `store`。
  • load_own_variables(self, store): 以相同顺序从 `store` 读取变量。

此外,`build(...)` 方法也应该修改,以便根据 `self.quantization_mode` 的存在(或不存在)来分配正确的变量。

对于这个层,我们只打算支持两种模式(非量化和 INT8),所以序列化契约是

  • 无(无量化):`["kernel"]`
  • INT8:`["kernel", "scale"]`

以下代码实现了所需的钩子;Keras 将在 `model.save(...)` 和 `keras.saving.load_model(...)` 期间调用它们。

def save_own_variables(self, store):
    # Write variables to `store` in a fixed order based on quantization mode.
    # `store` is a key-value mapping provided by Keras during model.save().
    # Values are tensors.
    if not self.built:
        return
    mode = self.quantization_mode
    idx = 0
    if mode is None:
        # Order: _kernel
        store[str(idx)] = self._kernel
    elif mode == "int8":
        # Order: _kernel, scale
        store[str(idx)] = self._kernel
        idx += 1
        store[str(idx)] = self.scale
    else:
        raise ValueError(f"Unsupported quantization mode for save: {mode}")


def load_own_variables(self, store):
    # Read variables from `store` in the same order used by
    # `save_own_variables`. Keras calls this during
    # `keras.saving.load_model(...)`.
    if not self.built:
        return
    mode = self.quantization_mode
    idx = 0
    if mode is None:
        self._kernel.assign(store[str(idx)])
    elif mode == "int8":
        self._kernel.assign(store[str(idx)])
        idx += 1
        self.scale.assign(store[str(idx)])
    else:
        raise ValueError(f"Unsupported quantization mode for load: {mode}")

修改 `build(...)` 方法

build 方法本身也需要了解量化模式。如果正在加载/反序列化一个已保存的量化层,`self.quantization_mode` 将在调用 `build(...)` 之前设置。在这种情况下,我们需要直接分配量化变量而不是全精度变量。

def build(self, input_shape):
    input_dim = input_shape[-1]

    # Quantized build path.
    if self.quantization_mode:
        if self.quantization_mode == "int8":
            self._int8_build((input_dim,))
    else:
        # Regular FP32 build path.
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_dim,),
            initializer="random_uniform",
        )

带序列化的完整实现

具有序列化支持的完整类如下所示

@keras.saving.register_keras_serializable()
class SimpleScale(Layer):
    """A layer that learns a per-feature scaling factor."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        input_dim = input_shape[-1]

        if self.quantization_mode:
            if self.quantization_mode == "int8":
                self._int8_build((input_dim,))
        else:
            self._kernel = self.add_weight(
                name="kernel",
                shape=(input_dim,),
                initializer="random_uniform",
            )

    def call(self, inputs, training=None):
        return ops.multiply(inputs, self._kernel)

    def quantize(self, mode, **kwargs):
        if mode != "int8":
            raise NotImplementedError(f"Unsupported quantization mode: {mode}")

        quantized_kernel, scale = quantizers.abs_max_quantize(
            self._kernel, axis=0, dtype="int8", to_numpy=True
        )
        scale = ops.squeeze(scale, axis=0)

        kernel_shape = self._kernel.shape

        del self._kernel

        self._int8_build(kernel_shape)

        self._kernel.assign(quantized_kernel)
        self.scale.assign(scale)

        self._is_quantized = True

        if self.dtype_policy.quantization_mode is None:
            policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
            self.dtype_policy = policy

    def _int8_build(self, kernel_shape):
        (input_dim,) = kernel_shape
        self._kernel = self.add_weight(
            name="kernel",
            shape=(input_dim,),
            initializer="zeros",
            dtype="int8",
            trainable=False,
        )
        self.scale = self.add_weight(
            name="scale",
            initializer="ones",
            trainable=False,
        )

    def _int8_call(self, inputs, training=None):
        x = ops.multiply(inputs, self._kernel)
        x = ops.divide(x, self.scale)
        return x

    def save_own_variables(self, store):
        # Write variables to `store` in a fixed order based on quantization mode.
        # `store` is a key-value mapping provided by Keras during model.save(); values are tensors.
        if not self.built:
            return
        mode = self.quantization_mode
        idx = 0
        if mode is None:
            # Order: _kernel
            store[str(idx)] = self._kernel
        elif mode == "int8":
            # Order: _kernel, scale
            store[str(idx)] = self._kernel
            idx += 1
            store[str(idx)] = self.scale
        else:
            raise ValueError(f"Unsupported quantization mode for save: {mode}")

    def load_own_variables(self, store):
        # Read variables from `store` in the same order used by `save_own_variables`.
        # Keras calls this during `keras.saving.load_model(...)`.
        if not self.built:
            return
        mode = self.quantization_mode
        idx = 0
        if mode is None:
            self._kernel.assign(store[str(idx)])
        elif mode == "int8":
            self._kernel.assign(store[str(idx)])
            idx += 1
            self.scale.assign(store[str(idx)])
        else:
            raise ValueError(f"Unsupported quantization mode for load: {mode}")

注意

需要 `@keras.saving.register_keras_serializable()` 装饰器来注册类以进行序列化。


尝试一下:量化、保存和加载

model = keras.Sequential([Input(shape=(4,)), SimpleScale()])
model.build((None, 4))

# Quantize to INT8.
model.quantize("int8")
y_int8 = model(x)
print("SimpleScale INT8 sample:", y_int8[0].numpy())

# Save and load the quantized model.
model.save("simplescale_int8.keras")
loaded = keras.saving.load_model("simplescale_int8.keras")

y_loaded = loaded(x)
print("Loaded INT8 sample:", y_loaded[0].numpy())
SimpleScale INT8 sample: [ 0.01568618 -0.00546078  0.00163636  0.00331613]
Loaded INT8 sample: [ 0.01568618 -0.00546078  0.00163636  0.00331613]

实用技巧

以下是您可以重复使用的具体模式,用于使您自己的层对 PTQ 友好。

  • 构建时与调用时职责
    • 在 `build(...)` 中,如果设置了 `self.quantization_mode`:分配量化变量并跳过分配浮点核以避免重复。
  • 记录调用路径所需的任何元数据,例如 INT4
    • 您打包的轴(例如 `_int4_pack_axis`)。
    • 该轴上的原始(未打包)长度(例如 `_original_input_dim` 或 `_original_length_along_pack_axis`)。
  • 在量化调用钩子中,尽可能使用量化缓冲区计算,并在最后反量化回浮点数。这使您能够利用优化的低精度内核(例如 cuBLAS INT8 GEMM)。
  • INT4 特性(打包的 nibble)
    • 将 INT4 值量化到范围 [-8, 7](仍为 dtype int8),然后使用 `quantizers.pack_int4(..., axis=pack_axis)` 将每两个 4 位整数打包到每个字节。
    • 使用 `dtype="int8"` 存储打包的核。在调用路径中使用 `quantizers.unpack_int4(packed, orig_len, axis=pack_axis)` 进行即时解包。
    • 保留原始长度和打包轴,以便您可以为 LoRA、梯度和序列化进行解包。
  • 输入量化和广播
    • 在前向路径中,使用 `outputs /= (inputs_scale * kernel_scale)` 反量化输出;确保两个标度都广播到输出形状。
  • 数据类型策略生命周期
    • 在 `quantize(mode)` 期间:删除 FP32 变量,分配量化变量,分配值,然后设置 `self._is_quantized = True`,然后更改数据类型策略。
    • 仅当当前策略的 `quantization_mode is None` 时才更改策略,以避免无限循环。
  • 序列化契约
    • 提供固定的变量序列化逻辑,以便保存/加载是确定性的。
    • 为每种模式以固定顺序写入变量(例如,None:[kernel, bias],`"int8"`:[kernel, bias, kernel_scale],`"int4"`:[kernel, bias, kernel_scale])。
  • 验证和错误处理
    • 尽早验证 `mode`,并为不支持的模式引发 `NotImplementedError`。
    • 量化后,运行一个小的烟雾测试,并在反量化后断言输出与 FP32 路径匹配且值在合理容差范围内。
  • 性能卫生
    • 避免重复转换热路径;预先计算尽可能多的信息,并保持前向传递钩子轻量级。
    • 保持量化缓冲区 `trainable=False` 并优先使用向量化操作。

有关更高级的模式,请参考 DenseEinsumDense 参考实现。