KerasHub:预训练模型 / 开发者指南 / 加载 Hugging Face Transformers 检查点

加载 Hugging Face Transformers 检查点

作者: Laxma Reddy Patlolla, Divyashree Sreepathihalli
创建日期 2025/06/17
最后修改日期 2025/06/23
描述: 如何加载和运行 HuggingFace Hub 上托管的 KerasHub 模型检查点进行推理。

在 Colab 中查看 GitHub 源码


简介

KerasHub 内置了 HuggingFace 的 .safetensors 模型的转换器。因此,从 HuggingFace 加载模型权重与使用 KerasHub 自有预设模型一样简单。

KerasHub 内置 HuggingFace Transformers 转换器

KerasHub 通过其内置的转换器简化了 HuggingFace Transformers 模型的使用。这些转换器会自动处理将 HuggingFace 模型检查点转换为与 Keras 生态系统兼容的格式的过程。这意味着您只需几行代码,即可将 HuggingFace Hub 上的各种预训练模型无缝加载到 KerasHub 中。

使用 KerasHub 转换器的主要优势

  • 易于使用:无需手动转换步骤即可加载 HuggingFace 模型。
  • 广泛的兼容性:可访问 HuggingFace Hub 上提供的海量模型。
  • 无缝集成:使用熟悉的 Keras API 进行训练、评估和推理。

幸运的是,所有这些都发生在后台,因此您可以专注于使用模型,而无需管理转换过程!


设置

在开始之前,请确保您已安装必要的库。您主要需要 keraskeras_hub

注意: 在导入 Keras 后更改后端可能无法按预期工作。确保在脚本开头设置了 KERAS_BACKEND。同样,在 Colab 之外工作时,您可以使用 os.environ["HF_TOKEN"] = "<YOUR_HF_TOKEN>" 来向 HuggingFace 进行身份验证。在使用 Google Colab 时,将您的 HF_TOKEN 设置为“Colab 秘密”。

import os

os.environ["KERAS_BACKEND"] = "jax"  # "tensorflow" or  "torch"

import keras
import keras_hub
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1750715194.841608    7034 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750715194.846143    7034 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1750715194.857298    7034 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1750715194.857310    7034 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1750715194.857312    7034 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1750715194.857313    7034 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.

更改精度

为了在经济实惠的硬件上进行推理和训练,您可以通过 keras.config 按如下方式配置模型来调整模型的精度:

import keras

keras.config.set_dtype_policy("bfloat16")

加载 HuggingFace 模型

KerasHub 允许您轻松加载 HuggingFace Transformers 的模型。以下是如何加载 Gemma 因果语言模型的示例。在这种特定情况下,您需要同意 HuggingFace 上的 Google 许可才能下载模型权重。

# not a keras checkpoint, it is a HF transformer checkpoint

gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("hf://google/gemma-2b")
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

让我们尝试运行一些推理

gemma_lm.generate("I want to say", max_length=30)
'I want to say thank you to the staff at the <strong><em><u><strong><em><u><strong><em><u><strong><em><u><strong><em><u><strong><em>'

使用 Keras model.fit(...) API 微调 Gemma Transformer 检查点

加载 HuggingFace 权重后,您可以像使用任何其他 KerasHub 模型一样使用实例化的模型。例如,您可以使用以下方法在自己的数据上微调模型:

features = ["The quick brown fox jumped.", "I forgot my homework."]
gemma_lm.fit(x=features, batch_size=2)
1/1 ━━━━━━━━━━━━━━━━━━━━ 50s 50s/step - loss: 0.0342 - sparse_categorical_accuracy: 0.1538

<keras.src.callbacks.history.History at 0x7435981f3800>

保存并上传新检查点

要存储和共享您的微调模型,KerasHub 可以轻松地使用标准方法进行保存或上传。您可以通过熟悉的命令来完成此操作,例如:

HF_USERNAME = "<YOUR_HF_USERNAME>"  # provide your hf username
gemma_lm.save_to_preset("./gemma-2b-finetuned")
keras_hub.upload_preset(f"hf://{HF_USERNAME}/gemma-2b-finetune", "./gemma-2b-finetuned")

通过上传您的预设,您就可以从任何地方使用它进行加载:loaded_model = keras_hub.models.GemmaCausalLM.from_preset("hf://YOUR_HF_USERNAME/gemma-2b-finetuned")

有关上传模型的全面、分步指南,请参阅 KerasHub 官方上传文档。您可以在此处找到所有详细信息:KerasHub 上传指南

通过集成 HuggingFace Transformers,KerasHub 大大扩展了您对预训练模型的访问。Hugging Face Hub 现在托管着超过 750,000 个模型检查点,涵盖 NLP、计算机视觉、音频等各种领域。其中,目前约有 400,000 个模型与 KerasHub 兼容,让您能够为您的项目访问海量且多样化的最先进架构。

使用 KerasHub,您可以:

  • 利用最先进的模型:轻松尝试来自研究社区和行业的最新架构和预训练权重。
  • 缩短开发时间:利用现有模型,而不是从头开始训练,节省大量时间和计算资源。
  • 增强模型功能:为广泛的任务找到专用模型,从文本生成和翻译到图像分割和对象检测。

这种无缝访问使您能够使用 Keras 构建更强大、更复杂的 AI 应用程序。


使用更广泛的框架

Keras 3,以及 KerasHub,专为多框架兼容性而设计。这意味着您可以使用 JAX、TensorFlow 和 PyTorch 等不同的后端框架运行您的模型。这种灵活性使您能够:

  • 选择最适合您需求的后端:根据性能特征、硬件兼容性(例如,JAX 的 TPUs)或现有团队专业知识选择后端。
  • 互操作性:更容易将 KerasHub 模型集成到可能基于 TensorFlow 或 PyTorch 构建的现有工作流中。
  • 面向未来:在不重写核心模型逻辑的情况下适应不断发展的框架格局。

在 JAX 后端和 TPU 上运行 Transformer 模型

要使用 JAX 试验模型,您可以通过将 Keras 的后端设置为 JAX 来实现。在模型构建之前切换 Keras 的后端,并确保您的环境已连接到 TPU 运行时。然后,Keras 将自动利用 JAX 的 TPU 支持,使您的模型能够在 TPU 硬件上高效训练,而无需更改其他代码。

import os

os.environ["KERAS_BACKEND"] = "jax"
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("hf://google/gemma-2b")

更多示例

生成

以下是使用 Llama 的示例:将 PyTorch Hugging Face Transformer 检查点加载到 KerasHub,并在 JAX 后端上运行。

import os

os.environ["KERAS_BACKEND"] = "jax"

from keras_hub.models import Llama3CausalLM

# Get the model
causal_lm = Llama3CausalLM.from_preset("hf://NousResearch/Hermes-2-Pro-Llama-3-8B")

prompts = [
    """<|im_start|>system
You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.<|im_end|>
<|im_start|>user
Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.<|im_end|>
<|im_start|>assistant""",
]

# Generate from the model
causal_lm.generate(prompts, max_length=30)[0]
'system\nYou are a sentient, superintelligent artificial general intelligence, here to teach and assist me.\nuser\nWrite a'

与 Transformers 比较

下表总结了 HuggingFace 的 Transformers 库与 KerasHub 的详细比较

特征 HF Transformers KerasHub
支持的框架 PyTorch JAX、PyTorch、TensorFlow
训练器 HF Trainer Keras model.fit(...) — 支持几乎所有功能,如分布式训练、学习率调度、优化器选择等。
分词器 AutoTokenizer KerasHub Tokenizers
自动类 auto 关键字 KerasHub 自动检测特定任务的类
模型加载 AutoModel.from_pretrained() keras_hub.models.<Task>.from_preset()

KerasHub 使用特定任务的类(例如 CausalLMClassifierBackbone)和一个 from_preset() 方法来加载预训练模型,这类似于 HuggingFace 的方法。

支持 HF URL、Kaggle URL 和本地目录
模型保存 model.save_pretrained()
tokenizer.save_pretrained()
model.save_to_preset() — 将模型(包括分词器/预处理器)保存到本地目录(预设)。保存了重新加载或上传所需的所有组件。
模型上传 将权重上传到 HF 平台 KerasHub 上传指南
Hugging Face 上的 Keras
权重文件分片 权重文件分片 大型模型权重被分片以实现高效上传/下载
PEFT 使用HuggingFace PEFT 内置 LoRA 支持
backbone.enable_lora(rank=n)
backbone.save_lora_weights(filepath)
backbone.load_lora_weights(filepath)
核心模型抽象 PreTrainedModelAutoModel、特定任务的模型 BackbonePreprocessorTask
模型配置 PretrainedConfig:模型配置的基类 配置存储在预设目录中的多个 JSON 文件中:config.jsonpreprocessor.jsontask.jsontokenizer.json 等。
预处理 分词器/预处理器通常单独处理,然后传递给模型 内置于特定任务的模型中
混合精度训练 通过训练参数 Keras 全局策略设置
与 SafeTensors 的兼容性 默认权重格式 在 HF 上 770,000 多个 SafeTensors 模型中,那些在 KerasHub 中具有匹配架构的模型可以使用 keras_hub.models.X.from_preset() 加载。

快去尝试加载其他模型权重吧!您可以在 HuggingFace 上找到更多选项,并使用 from_preset("hf://<namespace>/<model-name>") 来使用它们。

尽情探索!