Author: Jyotinder Singh
Date created: 2025/10/14
Last modified: 2025/10/14
Description: Complete guide to INT8 quantization in Keras and KerasHub.
Quantization lowers the numerical precision of weights and activations to reduce memory usage and often speed up inference, at the cost of a small amount of accuracy. Switching from float32 to float16 halves memory requirements; going from float32 to INT8 shrinks the memory footprint roughly 4x (about 2x compared to float16), since float32 uses 4 bytes per value versus 1 byte for INT8. On hardware with low-precision kernels (e.g., NVIDIA Tensor Cores), it can also increase throughput and lower latency. The actual gains depend on your backend and device.
Quantization maps real values to 8-bit integers in [-128, 127] (256 levels) using a scale factor:

- Compute the absolute maximum of the tensor: a_max = max(abs(w)).
- Scale: s = (2 * a_max) / 256.
- Quantize: q = clip(round(w / s), -128, 127), stored as INT8, keeping s.
- At inference, q and s either reconstruct the effective weights on the fly (w ≈ s · q), or s is folded into the matmul/conv for efficiency.

Compared to float32, INT8 cuts memory roughly 4x, which improves cache behavior and reduces memory stalls; this often helps more than raw FLOPs. float16 typically stays close to full-precision accuracy, while INT8 may introduce a small drop (often ~1-5%, depending on task/model/data). Always validate on your own dataset. This guide covers post-training quantization (PTQ) to 8-bit integers in Keras.
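To make the mapping concrete, here is a minimal NumPy sketch of symmetric per-tensor INT8 quantization. It illustrates the math above and is not the internal Keras implementation:
import numpy as np
def quantize_int8(w):
    # Symmetric per-tensor quantization: return (q, s).
    a_max = float(np.max(np.abs(w)))
    s = max((2.0 * a_max) / 256.0, 1e-12)  # guard against all-zero tensors
    q = np.clip(np.round(w / s), -128, 127).astype(np.int8)
    return q, s
def dequantize_int8(q, s):
    # Reconstruct the effective weights: w ≈ s * q.
    return s * q.astype(np.float32)
w = np.random.default_rng(0).standard_normal((4, 4)).astype("float32")
q, s = quantize_int8(w)
print("max abs reconstruction error:", np.max(np.abs(w - dequantize_int8(q, s))))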
We build a small functional model, capture its baseline outputs, quantize it in place to INT8, and then compare outputs with an MSE metric.
import os
import numpy as np
import keras
from keras import layers
# Create a random number generator.
rng = np.random.default_rng()
# Create a simple functional model.
inputs = keras.Input(shape=(10,))
x = layers.Dense(32, activation="relu")(inputs)
outputs = layers.Dense(1, name="target")(x)
model = keras.Model(inputs, outputs)
# Compile and train briefly to materialize meaningful weights.
model.compile(optimizer="adam", loss="mse")
x_train = rng.random((256, 10)).astype("float32")
y_train = rng.random((256, 1)).astype("float32")
model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0)
# Sample inputs for evaluation.
x_eval = rng.random((32, 10)).astype("float32")
# Baseline (FP) outputs.
y_fp32 = model(x_eval)
# Quantize the model in-place to INT8.
model.quantize("int8")
# INT8 outputs after quantization.
y_int8 = model(x_eval)
# Compute a simple MSE between FP and INT8 outputs.
mse = keras.ops.mean(keras.ops.square(y_fp32 - y_int8))
print("Full-Precision vs INT8 MSE:", float(mse))
Full-Precision vs INT8 MSE: 4.982496648153756e-06
As the low MSE value shows, the INT8-quantized model produces outputs very close to the original FP32 model.
You can use the standard Keras saving and loading APIs with quantized models. Quantization is preserved when you save to .keras and reload.
# Save the quantized model and reload to verify round-trip.
model.save("int8.keras")
int8_reloaded = keras.saving.load_model("int8.keras")
y_int8_reloaded = int8_reloaded(x_eval)
roundtrip_mse = keras.ops.mean(keras.ops.square(y_int8 - y_int8_reloaded))
print("MSE (INT8 vs reloaded-INT8):", float(roundtrip_mse))
MSE (INT8 vs reloaded-INT8): 0.0
All KerasHub models support the .quantize(...) API for post-training quantization and follow the same flow as above.
In this example, we load a Gemma3 preset, generate text at full precision, quantize the model in place to INT8, generate again, and then compare outputs and on-disk file sizes.
from keras_hub.models import Gemma3CausalLM
# Load from Gemma3 preset
gemma3 = Gemma3CausalLM.from_preset("gemma3_1b")
# Generate text for a single prompt
output = gemma3.generate("Keras is a", max_length=50)
print("Full-precision output:", output)
# Save FP32 Gemma3 model for size comparison.
gemma3.save_to_preset("gemma3_fp32")
# Quantize in-place to INT8 and generate again
gemma3.quantize("int8")
output = gemma3.generate("Keras is a", max_length=50)
print("Quantized output:", output)
# Save INT8 Gemma3 model
gemma3.save_to_preset("gemma3_int8")
# Reload and compare outputs
gemma3_int8 = Gemma3CausalLM.from_preset("gemma3_int8")
output = gemma3_int8.generate("Keras is a", max_length=50)
print("Quantized reloaded output:", output)
# Compute storage savings
def bytes_to_mib(n):
    return n / (1024**2)
gemma_fp32_size = os.path.getsize("gemma3_fp32/model.weights.h5")
gemma_int8_size = os.path.getsize("gemma3_int8/model.weights.h5")
gemma_reduction = 100.0 * (1.0 - (gemma_int8_size / max(gemma_fp32_size, 1)))
print(f"Gemma3: FP32 file size: {bytes_to_mib(gemma_fp32_size):.2f} MiB")
print(f"Gemma3: INT8 file size: {bytes_to_mib(gemma_int8_size):.2f} MiB")
print(f"Gemma3: Size reduction: {gemma_reduction:.1f}%")
Full-precision output: Keras is a deep learning library for Python. It is a high-level API for neural networks. It is a Python library for deep learning. It is a library for deep learning. It is a library for deep learning. It is a
Quantized output: Keras is a Python library for deep learning. It is a high-level API for building deep learning models. It is designed to be easy
Quantized reloaded output: Keras is a Python library for deep learning. It is a high-level API for building deep learning models. It is designed to be easy
Gemma3: FP32 file size: 3815.32 MiB
Gemma3: INT8 file size: 957.81 MiB
Gemma3: Size reduction: 74.9%
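The measured reduction lands just under the theoretical 75%, likely because the quantization scales (and any parameters left unquantized, such as normalization weights) are still stored in floating point.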
Note: quantization requires materialized weights, so call quantize() only after the model has been built (e.g., via build() or a forward pass).
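As a minimal sketch of this requirement (assuming a recent Keras 3 release where Dense supports INT8 quantization):
import keras
layer = keras.layers.Dense(4)
# Calling layer.quantize("int8") here would fail: the kernel does not exist yet.
layer.build((None, 8))  # materialize the kernel and bias
layer.quantize("int8")  # now succeeds; the kernel is stored as INT8 plus a scale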