作者: Tirth Patel, Ian Stenbit, Divyashree Sreepathihalli
创建日期 2024/10/1
最后修改 2024/10/1
描述: 在 KerasHub 中使用文本、框和点提示进行 Segment Anything。
Segment Anything Model (SAM) 根据点或框等输入提示生成高质量的对象掩码,也可用于生成图像中所有对象的掩码。它在包含 1100 万张图像和 11 亿个掩码的数据集上进行了训练,在各种分割任务上具有强大的零样本性能。
在本指南中,我们将展示如何使用 KerasHub 对 Segment Anything Model 的实现,并展示 TensorFlow 和 JAX 在性能提升方面的强大之处。
首先,让我们获取所有依赖项和演示所需的图像。
!!pip install -Uq git+https://github.com/keras-team/keras-hub.git
!!pip install -Uq keras
!!wget -q https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg
使用 Keras 3,你可以选择你喜欢的后端!
import os
os.environ["KERAS_BACKEND"] = "jax"
import timeit
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras import ops
import keras_hub
让我们定义一些辅助函数,用于可视化图像、提示和分割结果。
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels == 1]
neg_points = coords[labels == 0]
ax.scatter(
pos_points[:, 0],
pos_points[:, 1],
color="green",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
ax.scatter(
neg_points[:, 0],
neg_points[:, 1],
color="red",
marker="*",
s=marker_size,
edgecolor="white",
linewidth=1.25,
)
def show_box(box, ax):
box = box.reshape(-1)
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(
plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
)
def inference_resizing(image, pad=True):
# Compute Preprocess Shape
image = ops.cast(image, dtype="float32")
old_h, old_w = image.shape[0], image.shape[1]
scale = 1024 * 1.0 / max(old_h, old_w)
new_h = old_h * scale
new_w = old_w * scale
preprocess_shape = int(new_h + 0.5), int(new_w + 0.5)
# Resize the image
image = ops.image.resize(image[None, ...], preprocess_shape)[0]
# Pad the shorter side
if pad:
pixel_mean = ops.array([123.675, 116.28, 103.53])
pixel_std = ops.array([58.395, 57.12, 57.375])
image = (image - pixel_mean) / pixel_std
h, w = image.shape[0], image.shape[1]
pad_h = 1024 - h
pad_w = 1024 - w
image = ops.pad(image, [(0, pad_h), (0, pad_w), (0, 0)])
# KerasHub now rescales the images and normalizes them.
# Just unnormalize such that when KerasHub normalizes them
# again, the padded values map to 0.
image = image * pixel_std + pixel_mean
return image
我们可以使用 KerasHub 的 from_preset
工厂方法初始化一个训练好的 SAM 模型。在这里,我们使用在 SA-1B 数据集上训练的巨大 ViT 主干 (`sam_huge_sa1b`) 来获得高质量的分割掩码。你也可以使用 `sam_large_sa1b` 或 `sam_base_sa1b` 中的一个来获得更好的性能(但代价是分割掩码质量会降低)。
model = keras_hub.models.SAMImageSegmenter.from_preset("sam_huge_sa1b")
Segment Anything 允许使用点、框和掩码来提示图像
使该模型极其强大的是能够组合上述提示的能力。点、框和掩码提示可以以多种不同方式组合使用,以达到最佳结果。
让我们看看在 KerasHub 中将这些提示传递给 Segment Anything 模型时的语义。SAM 模型的输入是一个字典,包含以下键:
"images"
: 待分割的图像批次。形状必须为 (B, 1024, 1024, 3)
。"points"
: 点提示批次。每个点是图像左上角开始的 (x, y)
坐标。换句话说,每个点都采用 (r, c)
形式,其中 r
和 c
是图像中像素的行和列。形状必须为 (B, N, 2)
。"labels"
: 给定点的标签批次。1
表示前景点,0
表示背景点。形状必须为 (B, N)
。"boxes"
: 框批次。请注意,模型每批次只接受一个框。因此,期望的形状是 (B, 1, 2, 2)
。每个框是两个点的集合:框的左上角和右下角。这里的点遵循与点提示相同的语义。这里的第二维中的 1
表示存在框提示。如果框提示缺失,必须传递形状为 (B, 0, 2, 2)
的占位符输入。"masks"
: 掩码批次。与框提示类似,每张图像只允许一个掩码提示。如果存在输入掩码,其形状必须为 (B, 1, 256, 256, 1)
,如果掩码提示缺失,则为 (B, 0, 256, 256, 1)
。占位符提示仅在直接调用模型(即 model(...)
)时需要。调用 predict
方法时,输入字典中可以省略缺失的提示。
首先,让我们使用点提示来分割图像。我们加载图像并将其大小调整为 (1024, 1024)
,这是预训练 SAM 模型期望的图像尺寸。
# Load our image
image = np.array(keras.utils.load_img("truck.jpg"))
image = inference_resizing(image)
plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
plt.axis("on")
plt.show()
接下来,我们将定义要分割对象上的点。让我们尝试分割坐标为 (284, 213)
的卡车车窗玻璃。
# Define the input point prompt
input_point = np.array([[284, 213.5]])
input_label = np.array([1])
plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_points(input_point, input_label, plt.gca())
plt.axis("on")
plt.show()
现在让我们调用模型的 predict
方法来获取分割掩码。
注意:我们不直接调用模型 (model(...)
),因为需要占位符提示。缺失的提示由 predict 方法自动处理,所以我们改为调用它。此外,当没有框提示时,点和标签需要分别用零点提示和 -1
标签提示进行填充。下面的单元格演示了这是如何工作的。
outputs = model.predict(
{
"images": image[np.newaxis, ...],
"points": np.concatenate(
[input_point[np.newaxis, ...], np.zeros((1, 1, 2))], axis=1
),
"labels": np.concatenate(
[input_label[np.newaxis, ...], np.full((1, 1), fill_value=-1)], axis=1
),
}
)
SegmentAnythingModel.predict
返回两个输出。第一个是 logits (分割掩码),形状为 (1, 4, 256, 256)
,另一个是每个预测掩码的 IoU 置信度分数(形状为 (1, 4)
)。预训练的 SAM 模型预测四个掩码:第一个是模型能为给定提示提供的最佳掩码,其他 3 个是备选掩码,用于当最佳预测不包含所需对象时。用户可以选择他们喜欢的任何一个掩码。
让我们可视化模型返回的掩码!
# Resize the mask to our image shape i.e. (1024, 1024)
mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]
# Convert the logits to a numpy array
# and convert the logits to a boolean mask
mask = ops.convert_to_numpy(mask) > 0.0
iou_score = ops.convert_to_numpy(outputs["iou_pred"][0][0])
plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_mask(mask, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.title(f"IoU Score: {iou_score:.3f}", fontsize=18)
plt.axis("off")
plt.show()
正如预期的那样,模型返回了卡车车窗玻璃的分割掩码。但是,我们的点提示也可能意味着一系列其他事物。例如,包含我们点的另一个可能掩码只是车窗玻璃的右侧或整个卡车。
让我们也可视化模型预测的其他掩码。
fig, ax = plt.subplots(1, 3, figsize=(20, 60))
masks, scores = outputs["masks"][0][1:], outputs["iou_pred"][0][1:]
for i, (mask, score) in enumerate(zip(masks, scores)):
mask = inference_resizing(mask[..., None], pad=False)[..., 0]
mask, score = map(ops.convert_to_numpy, (mask, score))
mask = 1 * (mask > 0.0)
ax[i].imshow(ops.convert_to_numpy(image) / 255.0)
show_mask(mask, ax[i])
show_points(input_point, input_label, ax[i])
ax[i].set_title(f"Mask {i+1}, Score: {score:.3f}", fontsize=12)
ax[i].axis("off")
plt.show()
很好!SAM 能够捕捉我们点提示的歧义性,并返回了其他可能的分割掩码。
现在,让我们看看如何使用框来提示模型。框是使用两个点指定的,即边界框的左上角和右下角,采用 xyxy 格式。让我们使用卡车左前轮胎周围的边界框来提示模型。
# Let's specify the box
input_box = np.array([[240, 340], [400, 500]])
outputs = model.predict(
{"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis, np.newaxis, ...]}
)
mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]
mask = ops.convert_to_numpy(mask) > 0.0
plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_mask(mask, plt.gca())
show_box(input_box, plt.gca())
plt.axis("off")
plt.show()
棒极了!模型完美地分割出了我们边界框中的左前轮胎。
为了充分发挥模型的潜力,让我们结合框和点提示,看看模型会如何表现。
# Let's specify the box
input_box = np.array([[240, 340], [400, 500]])
# Let's specify the point and mark it background
input_point = np.array([[325, 425]])
input_label = np.array([0])
outputs = model.predict(
{
"images": image[np.newaxis, ...],
"points": input_point[np.newaxis, ...],
"labels": input_label[np.newaxis, ...],
"boxes": input_box[np.newaxis, np.newaxis, ...],
}
)
mask = inference_resizing(outputs["masks"][0][0][..., None], pad=False)[..., 0]
mask = ops.convert_to_numpy(mask) > 0.0
plt.figure(figsize=(10, 10))
plt.imshow(ops.convert_to_numpy(image) / 255.0)
show_mask(mask, plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis("off")
plt.show()
瞧!模型理解了我们想要从掩码中排除的对象是轮胎的轮辋。
最后,让我们看看如何结合使用文本提示和 KerasHub 的 SegmentAnythingModel
。
对于本次演示,我们将使用官方的 Grounding DINO 模型。Grounding DINO 是一个模型,它接受 (图像, 文本)
对作为输入,并围绕 图像
中由 文本
描述的对象生成边界框。你可以参考论文了解该模型实现的更多细节。
对于本演示的这一部分,我们需要从源代码安装 groundingdino
包
pip install -U git+https://github.com/IDEA-Research/GroundingDINO.git
然后,我们可以安装预训练模型的权重和配置
!!pip install -U git+https://github.com/IDEA-Research/GroundingDINO.git
!!wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
!!wget -q https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/v0.1.0-alpha2/groundingdino/config/GroundingDINO_SwinT_OGC.py
from groundingdino.util.inference import Model as GroundingDINO
CONFIG_PATH = "GroundingDINO_SwinT_OGC.py"
WEIGHTS_PATH = "groundingdino_swint_ogc.pth"
grounding_dino = GroundingDINO(CONFIG_PATH, WEIGHTS_PATH)
让我们加载一张狗的图像来做这一部分!
filepath = keras.utils.get_file(
origin="https://storage.googleapis.com/keras-cv/test-images/mountain-dog.jpeg"
)
image = np.array(keras.utils.load_img(filepath))
image = ops.convert_to_numpy(inference_resizing(image))
plt.figure(figsize=(10, 10))
plt.imshow(image / 255.0)
plt.axis("on")
plt.show()
我们首先使用 Grounding DINO 模型预测我们想要分割的对象的边界框。然后,我们使用边界框提示 SAM 模型以获取分割掩码。
让我们尝试分割出狗的胸背带(harness)。修改下面的图像和文本,以便使用你图像中的文本来分割任何你想要的对象!
# Let's predict the bounding box for the harness of the dog
boxes = grounding_dino.predict_with_caption(image.astype(np.uint8), "harness")
boxes = np.array(boxes[0].xyxy)
outputs = model.predict(
{
"images": np.repeat(image[np.newaxis, ...], boxes.shape[0], axis=0),
"boxes": boxes.reshape(-1, 1, 2, 2),
},
batch_size=1,
)
就这样!通过结合 Grounding DINO + SAM,我们为文本提示获得了分割掩码!这是一种非常强大的技术,可以结合不同的模型来扩展应用范围!
让我们可视化结果。
plt.figure(figsize=(10, 10))
plt.imshow(image / 255.0)
for mask in outputs["masks"]:
mask = inference_resizing(mask[0][..., None], pad=False)[..., 0]
mask = ops.convert_to_numpy(mask) > 0.0
show_mask(mask, plt.gca())
show_box(boxes, plt.gca())
plt.axis("off")
plt.show()
你可以使用 mixed_float16
或 bfloat16
dtype 策略,以相对较低的精度损失获得巨大的速度提升和内存优化。
# Load our image
image = np.array(keras.utils.load_img("truck.jpg"))
image = inference_resizing(image)
# Specify the prompt
input_box = np.array([[240, 340], [400, 500]])
# Let's first see how fast the model is with float32 dtype
time_taken = timeit.repeat(
'model.predict({"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis, np.newaxis, ...]}, verbose=False)',
repeat=3,
number=3,
globals=globals(),
)
print(f"Time taken with float32 dtype: {min(time_taken) / 3:.10f}s")
# Set the dtype policy in Keras
keras.mixed_precision.set_global_policy("mixed_float16")
model = keras_hub.models.SAMImageSegmenter.from_preset("sam_huge_sa1b")
time_taken = timeit.repeat(
'model.predict({"images": image[np.newaxis, ...], "boxes": input_box[np.newaxis,np.newaxis, ...]}, verbose=False)',
repeat=3,
number=3,
globals=globals(),
)
print(f"Time taken with float16 dtype: {min(time_taken) / 3:.10f}s")
这里是 KerasHub 实现与原始 PyTorch 实现的比较!
用于生成基准测试的脚本在此处提供此处。
KerasHub 的 SegmentAnythingModel
支持多种应用,并且借助 Keras 3,可以在 TensorFlow、JAX 和 PyTorch 上运行该模型!借助 JAX 和 TensorFlow 中的 XLA,该模型的运行速度比原始实现快很多倍。此外,使用 Keras 的混合精度支持仅需一行代码即可帮助优化内存使用和计算时间!
了解更多高级用法,请查看Automatic Mask Generator 演示。