Segment Anything in KerasHub!

Authors: Tirth Patel, Ian Stenbit, Divyashree Sreepathihalli

Date created: 2024/10/1

Last modified: 2024/10/1

Description: Segment anything using text, box, and points prompts in KerasHub.

View in Colab • GitHub source


Overview

The Segment Anything Model (SAM) produces high-quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.

In this guide, we will show how to use KerasHub's implementation of the Segment Anything Model and show how powerful TensorFlow's and JAX's performance boost is.

First, let's get all our dependencies and the image for our demo.

!!pip install -Uq git+https://github.com/keras-team/keras-hub.git
!!pip install -Uq keras
!!wget -q https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg
[]

[]

Choose your backend

With Keras 3, you can choose to use your favorite backend!

import os

os.environ["KERAS_BACKEND"] = "jax"

import timeit
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras import ops
import keras_hub

Helper functions

Let's define some helper functions for visualizing the images, prompts, and the segmentation results.

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )


def show_box(box, ax):
    box = box.reshape(-1)
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
    )


def inference_resizing(image, pad=True):
    # Compute Preprocess Shape
    image = ops.cast(image, dtype="float32")
    old_h, old_w = image.shape[0], image.shape[1]
    scale = 1024 * 1.0 / max(old_h, old_w)
    new_h = old_h * scale
    new_w = old_w * scale
    preprocess_shape = int(new_h + 0.5), int(new_w + 0.5)

    # Resize the image
    image = ops.image.resize(image[None, ...], preprocess_shape)[0]

    # Pad the shorter side
    if pad:
        pixel_mean = ops.array([123.675, 116.28, 103.53])
        pixel_std = ops.array([58.395, 57.12, 57.375])
        image = (image - pixel_mean) / pixel_std
        h, w = image.shape[0], image.shape[1]
        pad_h = 1024 - h
        pad_w = 1024 - w
        image = ops.pad(image, [(0, pad_h), (0, pad_w), (0, 0)])
        # KerasHub now rescales the images and normalizes them.
        # Just unnormalize such that when KerasHub normalizes them
        # again, the padded values map to 0.
        image = image * pixel_std + pixel_mean
    return image

Get the pretrained SAM model

We can initialize a trained SAM model using KerasHub's from_preset factory method. Here, we use the huge ViT backbone trained on the SA-1B dataset (sam_huge_sa1b) for high-quality segmentation masks. You can also use one of the sam_large_sa1b or sam_base_sa1b presets for better performance (at the cost of lower-quality segmentation masks).

model = keras_hub.models.SAMImageSegmenter.from_preset("sam_huge_sa1b")
Downloading from https://www.kaggle.com/api/v1/models/kerashub/sam/keras/sam_huge_sa1b/2/download/config.json...

100%|████████████████████████████████████████████████████| 3.06k/3.06k [00:00<00:00, 6.08MB/s]

Downloading from https://www.kaggle.com/api/v1/models/kerashub/sam/keras/sam_huge_sa1b/2/download/task.json...

100%|████████████████████████████████████████████████████| 5.76k/5.76k [00:00<00:00, 11.0MB/s]

Downloading from https://www.kaggle.com/api/v1/models/kerashub/sam/keras/sam_huge_sa1b/2/download/task.weights.h5...

100%|████████████████████████████████████████████████████| 2.39G/2.39G [00:26<00:00, 95.7MB/s]

Downloading from https://www.kaggle.com/api/v1/models/kerashub/sam/keras/sam_huge_sa1b/2/download/model.weights.h5...

100%|████████████████████████████████████████████████████| 2.39G/2.39G [00:32<00:00, 79.7MB/s]
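
If you would rather trade some mask quality for a smaller download and faster inference, loading one of the lighter presets mentioned above is the same one-line call. Here is a quick sketch using the sam_base_sa1b preset (sam_large_sa1b works the same way); the rest of this guide keeps using the huge model loaded above.

# Sketch: the same factory call works for the smaller presets as well.
small_model = keras_hub.models.SAMImageSegmenter.from_preset("sam_base_sa1b")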

Understanding Prompts

Segment Anything allows prompting an image using points, boxes, and masks:

  1. Point prompts are the most basic of all: the model tries to guess the object given a point on an image. The point can either be a foreground point (i.e. the desired segmentation mask contains the point) or a background point (i.e. the point lies outside the desired mask).
  2. Another way to prompt the model is using boxes. Given a bounding box, the model tries to segment the object contained in it.
  3. Finally, the model can also be prompted using a mask itself. This is useful, for instance, to refine the borders of a previously predicted or known segmentation mask.

What makes the model incredibly powerful is the ability to combine the prompts above. Point, box, and mask prompts can be combined in several different ways to get the best result.

Let's see the semantics of passing these prompts to the Segment Anything model in KerasHub. Input to the SAM model is a dictionary with the following keys:

  1. "images": A batch of images to segment. Must be of shape (B, 1024, 1024, 3).
  2. "points": A batch of point prompts. Each point is an (x, y) coordinate originating from the top-left corner of the image. In other words, each point is of the form (r, c) where r and c are the row and column of the pixel in the image. Must be of shape (B, N, 2).
  3. "labels": A batch of labels for the given points. 1 represents a foreground point and 0 represents a background point. Must be of shape (B, N).
  4. "boxes": A batch of boxes. Note that the model only accepts one box per batch. Hence, the expected shape is (B, 1, 2, 2). Each box is a collection of two points: the top-left corner and the bottom-right corner of the box. The points here follow the same semantics as the point prompts. Here, the 1 in the second dimension represents the presence of a box prompt. If the box prompt is missing, a placeholder input of shape (B, 0, 2, 2) must be passed.
  5. "masks": A batch of masks. Just like box prompts, only one mask prompt is allowed per image. The shape of the input mask must be (B, 1, 256, 256, 1) if it is present and (B, 0, 256, 256, 1) for a missing mask prompt.

Placeholder prompts are only required while calling the model directly (i.e. model(...)). When calling the predict method, missing prompts can simply be omitted from the input dictionary.
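
To make the placeholder convention concrete, here is a minimal sketch (not part of the original notebook) of what the input dictionary for a direct model(...) call with a single foreground point and no box or mask prompt could look like. It assumes model is the SAMImageSegmenter loaded above and image is a preprocessed (1024, 1024, 3) array like the one built in the next section; the zero point / -1 label padding follows the convention described for missing box prompts below.

# All prompt keys must be present when calling the model directly;
# missing box and mask prompts are represented by zero-length placeholders.
direct_inputs = {
    "images": image[np.newaxis, ...],                    # (B, 1024, 1024, 3)
    "points": np.array([[[284.0, 213.5], [0.0, 0.0]]]),  # (B, N, 2): point prompt + zero-point pad
    "labels": np.array([[1.0, -1.0]]),                   # (B, N): 1 = foreground, -1 = padding
    "boxes": np.zeros((1, 0, 2, 2)),                     # placeholder: no box prompt
    "masks": np.zeros((1, 0, 256, 256, 1)),              # placeholder: no mask prompt
}
# outputs = model(direct_inputs)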


Point prompts

First, let's segment an image using point prompts. We load the image and resize it to shape (1024, 1024), which is the image size the pretrained SAM model expects.

# Load our image
image = np.array(keras.utils.load_img("truck.jpg"))
image = inference_resizing(image)

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
plt.axis("on")
plt.show()

png

Next, we will define the point on the object we want to segment. Let's try to segment the truck's window pane located at coordinates (284, 213.5).

# Define the input point prompt
input_point = np.array([[284, 213.5]])
input_label = np.array([1])

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_points(input_point, input_label, plt.gca())
plt.axis("on")
plt.show()

png

Now, let's call the predict method of our model to get the segmentation masks.

Note: We don't call the model directly (model(...)) since placeholder prompts are required to do that. Missing prompts are handled automatically by the predict method, so we call it instead. Also, when a box prompt is absent, the points and labels need to be padded with a zero point prompt and a -1 label prompt respectively. The cell below demonstrates how this works.

outputs = model.predict(
    {
        "images": image[np.newaxis, ...],
        "points": np.concatenate(
            [input_point[np.newaxis, ...], np.zeros((1, 1, 2))], axis=1
        ),
        "labels": np.concatenate(
            [input_label[np.newaxis, ...], np.full((1, 1), fill_value=-1)], axis=1
        ),
    }
)
Could not load symbol cuFuncGetName. Error: /usr/lib64-nvidia/libcuda.so.1: undefined symbol: cuFuncGetName

 1/1 ━━━━━━━━━━━━━━━━━━━━ 24s 24s/step

SegmentAnythingModel.predict returns two outputs. First are logits (the segmentation masks) of shape (1, 4, 256, 256), and the other are the IoU confidence scores (of shape (1, 4)) for each mask predicted. The pretrained SAM model predicts four masks: the first is the best mask the model could come up with for the given prompts, and the other three are alternative masks which can be used in case the best prediction doesn't contain the desired object. The user can choose whichever mask they prefer.
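
Since the model returns four candidate masks along with their predicted IoU scores, a simple way to pick one programmatically (a small sketch; the guide below simply uses the first mask) is to take the mask with the highest score:

# Pick the candidate mask with the highest predicted IoU confidence.
best_idx = int(np.argmax(ops.convert_to_numpy(outputs["iou_pred"][0])))
best_mask_logits = outputs["masks"][0][best_idx]           # (256, 256) logits
best_mask = ops.convert_to_numpy(best_mask_logits) > 0.0   # boolean mask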

Let's visualize the masks returned by the model!

# Resize the mask to our image shape i.e. (1024, 1024)
mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]
# Convert the logits to a numpy array
# and convert the logits to a boolean mask
mask = ops.convert_to_numpy(mask) > 0.0
iou_score = ops.convert_to_numpy(outputs["iou_pred"][0][0])

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_mask(mask, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.title(f"IoU Score: {iou_score:.3f}", fontsize=18)
plt.axis("off")
plt.show()

png

As expected, the model returns a segmentation mask for the truck's window pane. But our point prompt can also mean a range of other things. For example, another possible mask that contains our point is just the right side of the window pane, or the whole truck.

Let's also visualize the other masks the model has predicted.

fig, ax = plt.subplots(1, 3, figsize=(20, 60))
masks, scores = outputs["masks"][0][1:], outputs["iou_pred"][0][1:]
for i, (mask, score) in enumerate(zip(masks, scores)):
    mask = inference_resizing(mask[..., None], pad=False)[..., 0]
    mask, score = map(ops.convert_to_numpy, (mask, score))
    mask = 1 * (mask > 0.0)
    ax[i].imshow(ops.convert_to_numpy(image) / 255.0)
    show_mask(mask, ax[i])
    show_points(input_point, input_label, ax[i])
    ax[i].set_title(f"Mask {i+1}, Score: {score:.3f}", fontsize=12)
    ax[i].axis("off")
plt.show()

png

Nice! SAM was able to capture the ambiguity of our point prompt and also returned other possible segmentation masks.


Box prompts

Now, let's see how we can prompt the model using boxes. A box is specified using two points, the top-left corner and the bottom-right corner of the bounding box, in xyxy format. Let's prompt the model using a bounding box around the left front tire of the truck.

# Let's specify the box
input_box = np.array([[240, 340], [400, 500]])

outputs = model.predict(
    {"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis, np.newaxis, ...]}
)
mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]
mask = ops.convert_to_numpy(mask) > 0.0

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_mask(mask, plt.gca())
show_box(input_box, plt.gca())
plt.axis("off")
plt.show()
 1/1 ━━━━━━━━━━━━━━━━━━━━ 10s 10s/step

png

Boom! The model perfectly segmented the left front tire inside our bounding box.


Combining prompts

To get the true potential of the model out, let's combine box and point prompts and see what the model does.

# Let's specify the box
input_box = np.array([[240, 340], [400, 500]])
# Let's specify the point and mark it background
input_point = np.array([[325, 425]])
input_label = np.array([0])

outputs = model.predict(
    {
        "images": image[np.newaxis, ...],
        "points": input_point[np.newaxis, ...],
        "labels": input_label[np.newaxis, ...],
        "boxes": input_box[np.newaxis, np.newaxis, ...],
    }
)
mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]
mask = ops.convert_to_numpy(mask) > 0.0

plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_mask(mask, plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis("off")
plt.show()
 1/1 ━━━━━━━━━━━━━━━━━━━━ 14s 14s/step

png

Voila! The model understood that the object we wanted to exclude from our mask was the rim of the tire.


Text prompts

Finally, let's see how text prompts can be used along with KerasHub's SegmentAnythingModel.

For this demo, we will use the official Grounding DINO model. Grounding DINO is a model that takes an (image, text) pair as input and generates a bounding box around the object in the image described by the text. You can refer to the paper for more details on the implementation of the model.

For this part of the demo, we will need to install the groundingdino package from source:

pip install -U git+https://github.com/IDEA-Research/GroundingDINO.git

Then, we can download the pretrained model's weights and config:

!!pip install -U git+https://github.com/IDEA-Research/GroundingDINO.git
!!wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
!!wget -q https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/v0.1.0-alpha2/groundingdino/config/GroundingDINO_SwinT_OGC.py
from groundingdino.util.inference import Model as GroundingDINO

CONFIG_PATH = "GroundingDINO_SwinT_OGC.py"
WEIGHTS_PATH = "groundingdino_swint_ogc.pth"

grounding_dino = GroundingDINO(CONFIG_PATH, WEIGHTS_PATH)
['Collecting git+https://github.com/IDEA-Research/GroundingDINO.git',
 '  Cloning https://github.com/IDEA-Research/GroundingDINO.git to /tmp/pip-req-build-m_hhz04_',
 '  Running command git clone --filter=blob:none --quiet https://github.com/IDEA-Research/GroundingDINO.git /tmp/pip-req-build-m_hhz04_',
 '  Resolved https://github.com/IDEA-Research/GroundingDINO.git to commit 856dde20aee659246248e20734ef9ba5214f5e44',
 '  Preparing metadata (setup.py) ... \x1b[?25l\x1b[?25hdone',
 'Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from groundingdino==0.1.0) (2.4.1+cu121)',
 'Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (from groundingdino==0.1.0) (0.19.1+cu121)',
 'Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from groundingdino==0.1.0) (4.44.2)',
 'Collecting addict (from groundingdino==0.1.0)',
 '  Downloading addict-2.4.0-py3-none-any.whl.metadata (1.0 kB)',
 'Collecting yapf (from groundingdino==0.1.0)',
 '  Downloading yapf-0.40.2-py3-none-any.whl.metadata (45 kB)',
 '\x1b[?25l     \x1b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[32m0.0/45.4 kB\x1b[0m \x1b[31m?\x1b[0m eta \x1b[36m-:--:--\x1b[0m',
 '\x1b[2K     \x1b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[32m45.4/45.4 kB\x1b[0m \x1b[31m1.8 MB/s\x1b[0m eta \x1b[36m0:00:00\x1b[0m',
 '\x1b[?25hCollecting timm (from groundingdino==0.1.0)',
 '  Downloading timm-1.0.9-py3-none-any.whl.metadata (42 kB)',
 '\x1b[?25l     \x1b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[32m0.0/42.4 kB\x1b[0m \x1b[31m?\x1b[0m eta \x1b[36m-:--:--\x1b[0m',
 '\x1b[2K     \x1b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[32m42.4/42.4 kB\x1b[0m \x1b[31m1.8 MB/s\x1b[0m eta \x1b[36m0:00:00\x1b[0m',
 '\x1b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from groundingdino==0.1.0) (1.26.4)',
 'Requirement already satisfied: opencv-python in /usr/local/lib/python3.10/dist-packages (from groundingdino==0.1.0) (4.10.0.84)',
 'Collecting supervision>=0.22.0 (from groundingdino==0.1.0)',
 '  Downloading supervision-0.23.0-py3-none-any.whl.metadata (14 kB)',
 'Requirement already satisfied: pycocotools in /usr/local/lib/python3.10/dist-packages (from groundingdino==0.1.0) (2.0.8)',
 'Requirement already satisfied: defusedxml<0.8.0,>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from supervision>=0.22.0->groundingdino==0.1.0) (0.7.1)',
 'Requirement already satisfied: matplotlib>=3.6.0 in /usr/local/lib/python3.10/dist-packages (from supervision>=0.22.0->groundingdino==0.1.0) (3.7.1)',
 'Requirement already satisfied: opencv-python-headless>=4.5.5.64 in /usr/local/lib/python3.10/dist-packages (from supervision>=0.22.0->groundingdino==0.1.0) (4.10.0.84)',
 'Requirement already satisfied: pillow>=9.4 in /usr/local/lib/python3.10/dist-packages (from supervision>=0.22.0->groundingdino==0.1.0) (10.4.0)',
 'Requirement already satisfied: pyyaml>=5.3 in /usr/local/lib/python3.10/dist-packages (from supervision>=0.22.0->groundingdino==0.1.0) (6.0.2)',
 'Requirement already satisfied: scipy<2.0.0,>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from supervision>=0.22.0->groundingdino==0.1.0) (1.13.1)',
 'Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.10/dist-packages (from timm->groundingdino==0.1.0) (0.24.7)',
 'Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from timm->groundingdino==0.1.0) (0.4.5)',
 'Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->groundingdino==0.1.0) (3.16.1)',
 'Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch->groundingdino==0.1.0) (4.12.2)',
 'Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->groundingdino==0.1.0) (1.13.3)',
 'Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->groundingdino==0.1.0) (3.3)',
 'Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->groundingdino==0.1.0) (3.1.4)',
 'Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->groundingdino==0.1.0) (2024.6.1)',
 'Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers->groundingdino==0.1.0) (24.1)',
 'Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->groundingdino==0.1.0) (2024.9.11)',
 'Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers->groundingdino==0.1.0) (2.32.3)',
 'Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers->groundingdino==0.1.0) (0.19.1)',
 'Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers->groundingdino==0.1.0) (4.66.5)',
 'Requirement already satisfied: importlib-metadata>=6.6.0 in /usr/local/lib/python3.10/dist-packages (from yapf->groundingdino==0.1.0) (8.4.0)',
 'Requirement already satisfied: platformdirs>=3.5.1 in /usr/local/lib/python3.10/dist-packages (from yapf->groundingdino==0.1.0) (4.3.6)',
 'Requirement already satisfied: tomli>=2.0.1 in /usr/local/lib/python3.10/dist-packages (from yapf->groundingdino==0.1.0) (2.0.1)',
 'Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.10/dist-packages (from importlib-metadata>=6.6.0->yapf->groundingdino==0.1.0) (3.20.2)',
 'Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.6.0->supervision>=0.22.0->groundingdino==0.1.0) (1.3.0)',
 'Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.6.0->supervision>=0.22.0->groundingdino==0.1.0) (0.12.1)',
 'Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.6.0->supervision>=0.22.0->groundingdino==0.1.0) (4.54.1)',
 'Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.6.0->supervision>=0.22.0->groundingdino==0.1.0) (1.4.7)',
 'Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.6.0->supervision>=0.22.0->groundingdino==0.1.0) (3.1.4)',
 'Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib>=3.6.0->supervision>=0.22.0->groundingdino==0.1.0) (2.8.2)',
 'Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->groundingdino==0.1.0) (2.1.5)',
 'Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->groundingdino==0.1.0) (3.3.2)',
 'Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->groundingdino==0.1.0) (3.10)',
 'Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->groundingdino==0.1.0) (2.2.3)',
 'Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers->groundingdino==0.1.0) (2024.8.30)',
 'Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->groundingdino==0.1.0) (1.3.0)',
 'Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib>=3.6.0->supervision>=0.22.0->groundingdino==0.1.0) (1.16.0)',
 'Downloading supervision-0.23.0-py3-none-any.whl (151 kB)',
 '\x1b[?25l   \x1b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[32m0.0/151.5 kB\x1b[0m \x1b[31m?\x1b[0m eta \x1b[36m-:--:--\x1b[0m',
 '\x1b[2K   \x1b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[32m151.5/151.5 kB\x1b[0m \x1b[31m6.0 MB/s\x1b[0m eta \x1b[36m0:00:00\x1b[0m',
 '\x1b[?25hDownloading addict-2.4.0-py3-none-any.whl (3.8 kB)',
 'Downloading timm-1.0.9-py3-none-any.whl (2.3 MB)',
 '\x1b[?25l   \x1b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[32m0.0/2.3 MB\x1b[0m \x1b[31m?\x1b[0m eta \x1b[36m-:--:--\x1b[0m',
 '\x1b[2K   \x1b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m\x1b[90m╺\x1b[0m\x1b[90m━━━━━━━\x1b[0m \x1b[32m1.9/2.3 MB\x1b[0m \x1b[31m55.9 MB/s\x1b[0m eta \x1b[36m0:00:01\x1b[0m',
 '\x1b[2K   \x1b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[32m2.3/2.3 MB\x1b[0m \x1b[31m42.4 MB/s\x1b[0m eta \x1b[36m0:00:00\x1b[0m',
 '\x1b[?25hDownloading yapf-0.40.2-py3-none-any.whl (254 kB)',
 '\x1b[?25l   \x1b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[32m0.0/254.7 kB\x1b[0m \x1b[31m?\x1b[0m eta \x1b[36m-:--:--\x1b[0m',
 '\x1b[2K   \x1b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m \x1b[32m254.7/254.7 kB\x1b[0m \x1b[31m18.3 MB/s\x1b[0m eta \x1b[36m0:00:00\x1b[0m',
 '\x1b[?25hBuilding wheels for collected packages: groundingdino',
 '  Building wheel for groundingdino (setup.py) ... \x1b[?25l\x1b[?25hdone',
 '  Created wheel for groundingdino: filename=groundingdino-0.1.0-cp310-cp310-linux_x86_64.whl size=3038498 sha256=1e7306dfa5ebd4bebb340bfe814e13026800708bbc0223d37ae8963e90145fb2',
 '  Stored in directory: /tmp/pip-ephem-wheel-cache-multbs74/wheels/6b/06/d7/b57f601a4df56af41d262a5b1b496359b13c323bf5ef0434b2',
 'Successfully built groundingdino',
 'Installing collected packages: addict, yapf, supervision, timm, groundingdino',
 'Successfully installed addict-2.4.0 groundingdino-0.1.0 supervision-0.23.0 timm-1.0.9 yapf-0.40.2']

[]

UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3609.)

final text_encoder_type: bert-base-uncased

UserWarning: 
Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.

Let's load an image of a dog for this part!

filepath = keras.utils.get_file(
    origin="https://storage.googleapis.com/keras-cv/test-images/mountain-dog.jpeg"
)
image = np.array(keras.utils.load_img(filepath))
image = ops.convert_to_numpy(inference_resizing(image))

plt.figure(figsize=(10, 10))
plt.imshow(image / 255.0)
plt.axis("on")
plt.show()
Downloading data from https://storage.googleapis.com/keras-cv/test-images/mountain-dog.jpeg
 1236492/1236492 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step

WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

We first predict the bounding box of the object we want to segment using the Grounding DINO model. We then prompt the SAM model using the bounding box to get the segmentation mask.

Let's try to segment out the harness of the dog. Change the image and text below to segment whatever you want using text from your image!

# Let's predict the bounding box for the harness of the dog
boxes = grounding_dino.predict_with_caption(image.astype(np.uint8), "harness")
boxes = np.array(boxes[0].xyxy)

outputs = model.predict(
    {
        "images": np.repeat(image[np.newaxis, ...], boxes.shape[0], axis=0),
        "boxes": boxes.reshape(-1, 1, 2, 2),
    },
    batch_size=1,
)
FutureWarning: The `device` argument is deprecated and will be removed in v5 of Transformers.
UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.
UserWarning: None of the inputs have requires_grad=True. Gradients will be None
FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.

 1/1 ━━━━━━━━━━━━━━━━━━━━ 10s 10s/step

That's it! We got a segmentation mask for our text prompt using the combination of Grounding DINO + SAM! This is a very powerful technique for combining different models to expand the range of applications!

Let's visualize the results.

plt.figure(figsize=(10, 10))
plt.imshow(image / 255.0)

for mask in outputs["masks"]:
    mask = inference_resizing(mask[0][..., None], pad=False)[..., 0]
    mask = ops.convert_to_numpy(mask) > 0.0
    show_mask(mask, plt.gca())
    show_box(boxes, plt.gca())

plt.axis("off")
plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png


Optimizing SAM

You can use mixed_float16 or bfloat16 dtype policies to gain huge speedups and memory optimizations with relatively low drops in precision.

# Load our image
image = np.array(keras.utils.load_img("truck.jpg"))
image = inference_resizing(image)

# Specify the prompt
input_box = np.array([[240, 340], [400, 500]])

# Let's first see how fast the model is with float32 dtype
time_taken = timeit.repeat(
    'model.predict({"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis, np.newaxis, ...]}, verbose=False)',
    repeat=3,
    number=3,
    globals=globals(),
)
print(f"Time taken with float32 dtype: {min(time_taken) / 3:.10f}s")

# Set the dtype policy in Keras
keras.mixed_precision.set_global_policy("mixed_float16")

model = keras_hub.models.SAMImageSegmenter.from_preset("sam_huge_sa1b")

time_taken = timeit.repeat(
    'model.predict({"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis,np.newaxis, ...]}, verbose=False)',
    repeat=3,
    number=3,
    globals=globals(),
)
print(f"Time taken with float16 dtype: {min(time_taken) / 3:.10f}s")
Time taken with float32 dtype: 0.2298811787s

UserWarning: Skipping variable loading for optimizer 'loss_scale_optimizer', because it has 4 variables whereas the saved optimizer has 2 variables. 
UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 0 variables. 

Time taken with float16 dtype: 0.2068303013s

Here's a comparison of KerasHub's implementation with the original PyTorch implementation!

benchmark

The script used to generate the benchmarks is present here.


Conclusion

KerasHub's SegmentAnythingModel supports a variety of applications and, with the help of Keras 3, enables running the model on TensorFlow, JAX, and PyTorch! With the help of XLA in JAX and TensorFlow, the model runs several times faster than the original implementation. Moreover, using Keras's mixed precision support helps optimize memory use and computation time with just one line of code!

For more advanced uses, check out the Automatic Mask Generator demo.