作者: Laxma Reddy Patlolla, Divyashree Sreepathihalli
创建日期 2025/06/17
最后修改日期 2025/06/23
描述: 如何加载和运行 HuggingFace Hub 上托管的 KerasHub 模型检查点进行推理。
KerasHub 内置了 HuggingFace 的 .safetensors 模型的转换器。因此,从 HuggingFace 加载模型权重与使用 KerasHub 自有预设模型一样简单。
KerasHub 通过其内置的转换器简化了 HuggingFace Transformers 模型的使用。这些转换器会自动处理将 HuggingFace 模型检查点转换为与 Keras 生态系统兼容的格式的过程。这意味着您只需几行代码,即可将 HuggingFace Hub 上的各种预训练模型无缝加载到 KerasHub 中。
使用 KerasHub 转换器的主要优势
幸运的是,所有这些都发生在后台,因此您可以专注于使用模型,而无需管理转换过程!
在开始之前,请确保您已安装必要的库。您主要需要 keras 和 keras_hub。
注意: 在导入 Keras 后更改后端可能无法按预期工作。确保在脚本开头设置了 KERAS_BACKEND。同样,在 Colab 之外工作时,您可以使用 os.environ["HF_TOKEN"] = "<YOUR_HF_TOKEN>" 来向 HuggingFace 进行身份验证。在使用 Google Colab 时,将您的 HF_TOKEN 设置为“Colab 秘密”。
import os
os.environ["KERAS_BACKEND"] = "jax" # "tensorflow" or "torch"
import keras
import keras_hub
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1750715194.841608 7034 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750715194.846143 7034 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1750715194.857298 7034 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1750715194.857310 7034 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1750715194.857312 7034 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1750715194.857313 7034 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
为了在经济实惠的硬件上进行推理和训练,您可以通过 keras.config 按如下方式配置模型来调整模型的精度:
import keras
keras.config.set_dtype_policy("bfloat16")
KerasHub 允许您轻松加载 HuggingFace Transformers 的模型。以下是如何加载 Gemma 因果语言模型的示例。在这种特定情况下,您需要同意 HuggingFace 上的 Google 许可才能下载模型权重。
# not a keras checkpoint, it is a HF transformer checkpoint
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("hf://google/gemma-2b")
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
让我们尝试运行一些推理
gemma_lm.generate("I want to say", max_length=30)
'I want to say thank you to the staff at the <strong><em><u><strong><em><u><strong><em><u><strong><em><u><strong><em><u><strong><em>'
model.fit(...) API 微调 Gemma Transformer 检查点加载 HuggingFace 权重后,您可以像使用任何其他 KerasHub 模型一样使用实例化的模型。例如,您可以使用以下方法在自己的数据上微调模型:
features = ["The quick brown fox jumped.", "I forgot my homework."]
gemma_lm.fit(x=features, batch_size=2)
1/1 ━━━━━━━━━━━━━━━━━━━━ 50s 50s/step - loss: 0.0342 - sparse_categorical_accuracy: 0.1538
<keras.src.callbacks.history.History at 0x7435981f3800>
要存储和共享您的微调模型,KerasHub 可以轻松地使用标准方法进行保存或上传。您可以通过熟悉的命令来完成此操作,例如:
HF_USERNAME = "<YOUR_HF_USERNAME>" # provide your hf username
gemma_lm.save_to_preset("./gemma-2b-finetuned")
keras_hub.upload_preset(f"hf://{HF_USERNAME}/gemma-2b-finetune", "./gemma-2b-finetuned")
通过上传您的预设,您就可以从任何地方使用它进行加载:loaded_model = keras_hub.models.GemmaCausalLM.from_preset("hf://YOUR_HF_USERNAME/gemma-2b-finetuned")
有关上传模型的全面、分步指南,请参阅 KerasHub 官方上传文档。您可以在此处找到所有详细信息:KerasHub 上传指南
通过集成 HuggingFace Transformers,KerasHub 大大扩展了您对预训练模型的访问。Hugging Face Hub 现在托管着超过 750,000 个模型检查点,涵盖 NLP、计算机视觉、音频等各种领域。其中,目前约有 400,000 个模型与 KerasHub 兼容,让您能够为您的项目访问海量且多样化的最先进架构。
使用 KerasHub,您可以:
这种无缝访问使您能够使用 Keras 构建更强大、更复杂的 AI 应用程序。
Keras 3,以及 KerasHub,专为多框架兼容性而设计。这意味着您可以使用 JAX、TensorFlow 和 PyTorch 等不同的后端框架运行您的模型。这种灵活性使您能够:
要使用 JAX 试验模型,您可以通过将 Keras 的后端设置为 JAX 来实现。在模型构建之前切换 Keras 的后端,并确保您的环境已连接到 TPU 运行时。然后,Keras 将自动利用 JAX 的 TPU 支持,使您的模型能够在 TPU 硬件上高效训练,而无需更改其他代码。
import os
os.environ["KERAS_BACKEND"] = "jax"
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("hf://google/gemma-2b")
以下是使用 Llama 的示例:将 PyTorch Hugging Face Transformer 检查点加载到 KerasHub,并在 JAX 后端上运行。
import os
os.environ["KERAS_BACKEND"] = "jax"
from keras_hub.models import Llama3CausalLM
# Get the model
causal_lm = Llama3CausalLM.from_preset("hf://NousResearch/Hermes-2-Pro-Llama-3-8B")
prompts = [
"""<|im_start|>system
You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.<|im_end|>
<|im_start|>user
Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.<|im_end|>
<|im_start|>assistant""",
]
# Generate from the model
causal_lm.generate(prompts, max_length=30)[0]
'system\nYou are a sentient, superintelligent artificial general intelligence, here to teach and assist me.\nuser\nWrite a'
下表总结了 HuggingFace 的 Transformers 库与 KerasHub 的详细比较
| 特征 | HF Transformers | KerasHub |
|---|---|---|
| 支持的框架 | PyTorch | JAX、PyTorch、TensorFlow |
| 训练器 | HF Trainer | Keras model.fit(...) — 支持几乎所有功能,如分布式训练、学习率调度、优化器选择等。 |
| 分词器 | AutoTokenizer |
KerasHub Tokenizers |
| 自动类 | auto 关键字 |
KerasHub 自动检测特定任务的类 |
| 模型加载 | AutoModel.from_pretrained() |
keras_hub.models.<Task>.from_preset()KerasHub 使用特定任务的类(例如 CausalLM、Classifier、Backbone)和一个 from_preset() 方法来加载预训练模型,这类似于 HuggingFace 的方法。支持 HF URL、Kaggle URL 和本地目录 |
| 模型保存 | model.save_pretrained()tokenizer.save_pretrained() |
model.save_to_preset() — 将模型(包括分词器/预处理器)保存到本地目录(预设)。保存了重新加载或上传所需的所有组件。 |
| 模型上传 | 将权重上传到 HF 平台 | KerasHub 上传指南 Hugging Face 上的 Keras |
| 权重文件分片 | 权重文件分片 | 大型模型权重被分片以实现高效上传/下载 |
| PEFT | 使用HuggingFace PEFT | 内置 LoRA 支持backbone.enable_lora(rank=n)backbone.save_lora_weights(filepath)backbone.load_lora_weights(filepath) |
| 核心模型抽象 | PreTrainedModel、AutoModel、特定任务的模型 |
Backbone、Preprocessor、Task |
| 模型配置 | PretrainedConfig:模型配置的基类 |
配置存储在预设目录中的多个 JSON 文件中:config.json、preprocessor.json、task.json、tokenizer.json 等。 |
| 预处理 | 分词器/预处理器通常单独处理,然后传递给模型 | 内置于特定任务的模型中 |
| 混合精度训练 | 通过训练参数 | Keras 全局策略设置 |
| 与 SafeTensors 的兼容性 | 默认权重格式 | 在 HF 上 770,000 多个 SafeTensors 模型中,那些在 KerasHub 中具有匹配架构的模型可以使用 keras_hub.models.X.from_preset() 加载。 |
快去尝试加载其他模型权重吧!您可以在 HuggingFace 上找到更多选项,并使用 from_preset("hf://<namespace>/<model-name>") 来使用它们。
尽情探索!