KerasCV

KerasCV

Star

KerasCV 是一个包含模块化计算机视觉组件的库,可以与 TensorFlow、JAX 或 PyTorch 原生配合使用。基于 Keras 3 构建,这些模型、层、指标、回调等可以在任何框架中进行训练和序列化,并在另一个框架中重复使用,无需代价高昂的迁移。

可以将 KerasCV 理解为 Keras API 的水平扩展:这些组件是新的 Keras 一级对象,其专业性太强,无法添加到核心 Keras 中。它们与核心 Keras API 具有相同的抛光度和向后兼容性保证,并由 Keras 团队维护。

我们的 API 帮助完成常见计算机视觉任务,例如数据增强、分类、目标检测、分割、图像生成等。应用计算机视觉工程师可以利用 KerasCV 快速构建适用于所有这些常见任务的、生产级的、最先进的训练和推理管道。

安装

KerasCV 支持 Keras 2 和 Keras 3。我们建议所有新用户使用 Keras 3,因为它允许将 KerasCV 模型和层与 JAX、TensorFlow 和 PyTorch 一起使用。

Keras 2 安装

要安装带 Keras 2 的最新 KerasCV 版本,只需运行

pip install --upgrade keras-cv tensorflow

Keras 3 安装

目前有两种方法可以安装带 KerasCV 的 Keras 3。要安装 KerasCV 和 Keras 3 的稳定版本,您应该在安装 KerasCV **之后** 安装 Keras 3。这是 TensorFlow 固定到 Keras 2 期间的一个临时步骤,在 TensorFlow 2.16 之后将不再需要。

pip install --upgrade keras-cv tensorflow
pip install --upgrade keras

要安装 KerasCV 和 Keras 的最新 nightly 版本更改,可以使用我们的 nightly 包。

pip install --upgrade keras-cv-nightly tf-nightly

注意:Keras 3 无法与 TensorFlow 2.14 或更早版本一起使用。

有关一般安装 Keras 以及与不同框架的兼容性的更多信息,请参阅Keras 入门指南

快速入门

import os
os.environ["KERAS_BACKEND"] = "tensorflow"  # Or "jax" or "torch"!

import tensorflow as tf
import keras_cv
import tensorflow_datasets as tfds
import keras

# Create a preprocessing pipeline with augmentations
BATCH_SIZE = 16
NUM_CLASSES = 3
augmenter = keras_cv.layers.Augmenter(
    [
        keras_cv.layers.RandomFlip(),
        keras_cv.layers.RandAugment(value_range=(0, 255)),
        keras_cv.layers.CutMix(),
    ],
)

def preprocess_data(images, labels, augment=False):
    labels = tf.one_hot(labels, NUM_CLASSES)
    inputs = {"images": images, "labels": labels}
    outputs = inputs
    if augment:
        outputs = augmenter(outputs)
    return outputs['images'], outputs['labels']

train_dataset, test_dataset = tfds.load(
    'rock_paper_scissors',
    as_supervised=True,
    split=['train', 'test'],
)
train_dataset = train_dataset.batch(BATCH_SIZE).map(
    lambda x, y: preprocess_data(x, y, augment=True),
        num_parallel_calls=tf.data.AUTOTUNE).prefetch(
            tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE).map(
    preprocess_data, num_parallel_calls=tf.data.AUTOTUNE).prefetch(
        tf.data.AUTOTUNE)

# Create a model using a pretrained backbone
backbone = keras_cv.models.EfficientNetV2Backbone.from_preset(
    "efficientnetv2_b0_imagenet"
)
model = keras_cv.models.ImageClassifier(
    backbone=backbone,
    num_classes=NUM_CLASSES,
    activation="softmax",
)
model.compile(
    loss='categorical_crossentropy',
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    metrics=['accuracy']
)

# Train your model
model.fit(
    train_dataset,
    validation_data=test_dataset,
    epochs=8,
)

免责声明

KerasCV 通过 keras_cv.models API 提供对预训练模型的访问。这些预训练模型按“现状”提供,不提供任何形式的保证或条件。以下基础模型由第三方提供,并受单独许可证的约束:StableDiffusion、Vision Transfomer

引用 KerasCV

如果 KerasCV 有助于您的研究,我们感谢您的引用。以下是 BibTeX 条目

@misc{wood2022kerascv,
  title={KerasCV},
  author={Wood, Luke and Tan, Zhenyu and Stenbit, Ian and Bischof, Jonathan and Zhu, Scott and Chollet, Fran\c{c}ois and Sreepathihalli, Divyashree and Sampath, Ramesh and others},
  year={2022},
  howpublished={\url{https://github.com/keras-team/keras-cv}},
}