KerasCV 是一个包含模块化计算机视觉组件的库,可以与 TensorFlow、JAX 或 PyTorch 原生配合使用。基于 Keras 3 构建,这些模型、层、指标、回调等可以在任何框架中进行训练和序列化,并在另一个框架中重复使用,无需代价高昂的迁移。
可以将 KerasCV 理解为 Keras API 的水平扩展:这些组件是新的 Keras 一级对象,其专业性太强,无法添加到核心 Keras 中。它们与核心 Keras API 具有相同的抛光度和向后兼容性保证,并由 Keras 团队维护。
我们的 API 帮助完成常见计算机视觉任务,例如数据增强、分类、目标检测、分割、图像生成等。应用计算机视觉工程师可以利用 KerasCV 快速构建适用于所有这些常见任务的、生产级的、最先进的训练和推理管道。
KerasCV 支持 Keras 2 和 Keras 3。我们建议所有新用户使用 Keras 3,因为它允许将 KerasCV 模型和层与 JAX、TensorFlow 和 PyTorch 一起使用。
要安装带 Keras 2 的最新 KerasCV 版本,只需运行
pip install --upgrade keras-cv tensorflow
目前有两种方法可以安装带 KerasCV 的 Keras 3。要安装 KerasCV 和 Keras 3 的稳定版本,您应该在安装 KerasCV **之后** 安装 Keras 3。这是 TensorFlow 固定到 Keras 2 期间的一个临时步骤,在 TensorFlow 2.16 之后将不再需要。
pip install --upgrade keras-cv tensorflow
pip install --upgrade keras
要安装 KerasCV 和 Keras 的最新 nightly 版本更改,可以使用我们的 nightly 包。
pip install --upgrade keras-cv-nightly tf-nightly
注意:Keras 3 无法与 TensorFlow 2.14 或更早版本一起使用。
有关一般安装 Keras 以及与不同框架的兼容性的更多信息,请参阅Keras 入门指南。
import os
os.environ["KERAS_BACKEND"] = "tensorflow" # Or "jax" or "torch"!
import tensorflow as tf
import keras_cv
import tensorflow_datasets as tfds
import keras
# Create a preprocessing pipeline with augmentations
BATCH_SIZE = 16
NUM_CLASSES = 3
augmenter = keras_cv.layers.Augmenter(
[
keras_cv.layers.RandomFlip(),
keras_cv.layers.RandAugment(value_range=(0, 255)),
keras_cv.layers.CutMix(),
],
)
def preprocess_data(images, labels, augment=False):
labels = tf.one_hot(labels, NUM_CLASSES)
inputs = {"images": images, "labels": labels}
outputs = inputs
if augment:
outputs = augmenter(outputs)
return outputs['images'], outputs['labels']
train_dataset, test_dataset = tfds.load(
'rock_paper_scissors',
as_supervised=True,
split=['train', 'test'],
)
train_dataset = train_dataset.batch(BATCH_SIZE).map(
lambda x, y: preprocess_data(x, y, augment=True),
num_parallel_calls=tf.data.AUTOTUNE).prefetch(
tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE).map(
preprocess_data, num_parallel_calls=tf.data.AUTOTUNE).prefetch(
tf.data.AUTOTUNE)
# Create a model using a pretrained backbone
backbone = keras_cv.models.EfficientNetV2Backbone.from_preset(
"efficientnetv2_b0_imagenet"
)
model = keras_cv.models.ImageClassifier(
backbone=backbone,
num_classes=NUM_CLASSES,
activation="softmax",
)
model.compile(
loss='categorical_crossentropy',
optimizer=keras.optimizers.Adam(learning_rate=1e-5),
metrics=['accuracy']
)
# Train your model
model.fit(
train_dataset,
validation_data=test_dataset,
epochs=8,
)
KerasCV 通过 keras_cv.models
API 提供对预训练模型的访问。这些预训练模型按“现状”提供,不提供任何形式的保证或条件。以下基础模型由第三方提供,并受单独许可证的约束:StableDiffusion、Vision Transfomer
如果 KerasCV 有助于您的研究,我们感谢您的引用。以下是 BibTeX 条目
@misc{wood2022kerascv,
title={KerasCV},
author={Wood, Luke and Tan, Zhenyu and Stenbit, Ian and Bischof, Jonathan and Zhu, Scott and Chollet, Fran\c{c}ois and Sreepathihalli, Divyashree and Sampath, Ramesh and others},
year={2022},
howpublished={\url{https://github.com/keras-team/keras-cv}},
}