代码示例 / Keras 快速技巧 / 创建 TFRecords

创建 TFRecords

作者: Dimitre Oliveira
创建日期 2021/02/27
最后修改日期 2023/12/20
描述: 将数据转换为 TFRecord 格式。

ⓘ 此示例使用 Keras 3

在 Colab 中查看 GitHub 源代码


简介

TFRecord 格式是用于存储一系列二进制记录的简单格式。将您的数据转换为 TFRecord 有许多优点,例如

  • 更高效的存储:TFRecord 数据可以占用比原始数据更少的空间;它还可以被划分为多个文件。
  • 快速 I/O:可以使用并行 I/O 操作读取 TFRecord 格式,这对于 TPU 或多个主机很有用。
  • 自包含文件:TFRecord 数据可以从单个来源读取——例如,COCO2017 数据集最初将数据存储在两个文件夹中(“images”和“annotations”)。

TFRecord 数据格式的一个重要用例是在 TPU 上进行训练。首先,TPU 速度足够快,可以从优化的 I/O 操作中受益。此外,TPU 要求数据远程存储(例如,在 Google Cloud Storage 上),使用 TFRecord 格式可以更轻松地加载数据,而无需批量下载。

如果您还将 TFRecord 格式与 tf.data API 一起使用,则可以进一步提高性能。

在此示例中,您将学习如何将不同类型的数据(图像、文本和数字)转换为 TFRecord。

参考


依赖项

import os

os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import json
import pprint
import tensorflow as tf
import matplotlib.pyplot as plt

下载 COCO2017 数据集

我们将使用 COCO2017 数据集,因为它具有许多不同类型的特征,包括图像、浮点数据和列表。它将很好地演示如何将不同的特征编码为 TFRecord 格式。

此数据集有两组字段:图像和注释元数据。

图像是 JPG 文件的集合,元数据存储在 JSON 文件中,根据 官方网站,该文件包含以下属性

id: int,
image_id: int,
category_id: int,
segmentation: RLE or [polygon], object segmentation mask
bbox: [x,y,width,height], object bounding box coordinates
area: float, area of the bounding box
iscrowd: 0 or 1, is single object or a collection
root_dir = "datasets"
tfrecords_dir = "tfrecords"
images_dir = os.path.join(root_dir, "val2017")
annotations_dir = os.path.join(root_dir, "annotations")
annotation_file = os.path.join(annotations_dir, "instances_val2017.json")
images_url = "http://images.cocodataset.org/zips/val2017.zip"
annotations_url = (
    "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
)

# Download image files
if not os.path.exists(images_dir):
    image_zip = keras.utils.get_file(
        "images.zip",
        cache_dir=os.path.abspath("."),
        origin=images_url,
        extract=True,
    )
    os.remove(image_zip)

# Download caption annotation files
if not os.path.exists(annotations_dir):
    annotation_zip = keras.utils.get_file(
        "captions.zip",
        cache_dir=os.path.abspath("."),
        origin=annotations_url,
        extract=True,
    )
    os.remove(annotation_zip)

print("The COCO dataset has been downloaded and extracted successfully.")

with open(annotation_file, "r") as f:
    annotations = json.load(f)["annotations"]

print(f"Number of images: {len(annotations)}")
Downloading data from http://images.cocodataset.org/zips/val2017.zip
 815585330/815585330 ━━━━━━━━━━━━━━━━━━━━ 79s 0us/step
Downloading data from http://images.cocodataset.org/annotations/annotations_trainval2017.zip
 252907541/252907541 ━━━━━━━━━━━━━━━━━━━━ 5s 0us/step
The COCO dataset has been downloaded and extracted successfully.
Number of images: 36781

COCO2017 数据集的内容

pprint.pprint(annotations[60])
{'area': 367.89710000000014,
 'bbox': [265.67, 222.31, 26.48, 14.71],
 'category_id': 72,
 'id': 34096,
 'image_id': 525083,
 'iscrowd': 0,
 'segmentation': [[267.51,
                   222.31,
                   292.15,
                   222.31,
                   291.05,
                   237.02,
                   265.67,
                   237.02]]}

参数

num_samples 是每个 TFRecord 文件上的数据样本数。

num_tfrecords 是我们将创建的 TFRecord 总数。

num_samples = 4096
num_tfrecords = len(annotations) // num_samples
if len(annotations) % num_samples:
    num_tfrecords += 1  # add one record if there are any remaining samples

if not os.path.exists(tfrecords_dir):
    os.makedirs(tfrecords_dir)  # creating TFRecords output folder

定义 TFRecords 辅助函数

def image_feature(value):
    """Returns a bytes_list from a string / byte."""
    return tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[tf.io.encode_jpeg(value).numpy()])
    )


def bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.encode()]))


def float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def float_feature_list(value):
    """Returns a list of float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def create_example(image, path, example):
    feature = {
        "image": image_feature(image),
        "path": bytes_feature(path),
        "area": float_feature(example["area"]),
        "bbox": float_feature_list(example["bbox"]),
        "category_id": int64_feature(example["category_id"]),
        "id": int64_feature(example["id"]),
        "image_id": int64_feature(example["image_id"]),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


def parse_tfrecord_fn(example):
    feature_description = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "path": tf.io.FixedLenFeature([], tf.string),
        "area": tf.io.FixedLenFeature([], tf.float32),
        "bbox": tf.io.VarLenFeature(tf.float32),
        "category_id": tf.io.FixedLenFeature([], tf.int64),
        "id": tf.io.FixedLenFeature([], tf.int64),
        "image_id": tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(example, feature_description)
    example["image"] = tf.io.decode_jpeg(example["image"], channels=3)
    example["bbox"] = tf.sparse.to_dense(example["bbox"])
    return example

以 TFRecord 格式生成数据

让我们以 TFRecord 格式生成 COCO2017 数据。格式将为 file_{number}.tfrec (这是可选的,但在文件名中包含数字序列可以使计数更容易)。

for tfrec_num in range(num_tfrecords):
    samples = annotations[(tfrec_num * num_samples) : ((tfrec_num + 1) * num_samples)]

    with tf.io.TFRecordWriter(
        tfrecords_dir + "/file_%.2i-%i.tfrec" % (tfrec_num, len(samples))
    ) as writer:
        for sample in samples:
            image_path = f"{images_dir}/{sample['image_id']:012d}.jpg"
            image = tf.io.decode_jpeg(tf.io.read_file(image_path))
            example = create_example(image, image_path, sample)
            writer.write(example.SerializeToString())

浏览生成的 TFRecord 中的一个样本

raw_dataset = tf.data.TFRecordDataset(f"{tfrecords_dir}/file_00-{num_samples}.tfrec")
parsed_dataset = raw_dataset.map(parse_tfrecord_fn)

for features in parsed_dataset.take(1):
    for key in features.keys():
        if key != "image":
            print(f"{key}: {features[key]}")

    print(f"Image shape: {features['image'].shape}")
    plt.figure(figsize=(7, 7))
    plt.imshow(features["image"].numpy())
    plt.show()
bbox: [473.07 395.93  38.65  28.67]
area: 702.1057739257812
category_id: 18
id: 1768
image_id: 289343
path: b'datasets/val2017/000000289343.jpg'
Image shape: (640, 529, 3)

png


使用生成的 TFRecords 训练一个简单的模型

TFRecord 的另一个优点是您可以在其中添加许多特征,稍后只使用其中的几个,在这种情况下,我们将只使用 imagecategory_id


定义数据集辅助函数

def prepare_sample(features):
    image = keras.ops.image.resize(features["image"], size=(224, 224))
    return image, features["category_id"]


def get_dataset(filenames, batch_size):
    dataset = (
        tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
        .map(parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
        .map(prepare_sample, num_parallel_calls=AUTOTUNE)
        .shuffle(batch_size * 10)
        .batch(batch_size)
        .prefetch(AUTOTUNE)
    )
    return dataset


train_filenames = tf.io.gfile.glob(f"{tfrecords_dir}/*.tfrec")
batch_size = 32
epochs = 1
steps_per_epoch = 50
AUTOTUNE = tf.data.AUTOTUNE

input_tensor = keras.layers.Input(shape=(224, 224, 3), name="image")
model = keras.applications.EfficientNetB0(
    input_tensor=input_tensor, weights=None, classes=91
)


model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)


model.fit(
    x=get_dataset(train_filenames, batch_size),
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    verbose=1,
)
 50/50 ━━━━━━━━━━━━━━━━━━━━ 146s 2s/step - loss: 3.9206 - sparse_categorical_accuracy: 0.1690

<keras.src.callbacks.history.History at 0x7f70684c27a0>

结论

此示例演示了,借助 TFRecord,您可以让数据来自单个来源,而不是从不同的来源读取图像和注释。此过程可以使存储和读取数据更简单、更有效。有关更多信息,您可以访问 TFRecord 和 tf.train.Example 教程。