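Author: A_K_Nain, Sayak Paul
Date created: 2021/08/16
Last modified: 2024/09/01
Description: Training a handwriting recognition model with variable-length sequences.

This example shows how the Captcha OCR example can be extended to the IAM Dataset, which has variable-length ground-truth targets. Each sample in the dataset is an image of some handwritten text, and its corresponding target is the string present in the image. The IAM Dataset is widely used across many OCR benchmarks, so we hope this example can serve as a good starting point for building OCR systems.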
```python
!wget -q https://github.com/sayakpaul/Handwriting-Recognizer-in-Keras/releases/download/v1.0.0/IAM_Words.zip
!unzip -qq IAM_Words.zip
!mkdir data
!mkdir data/words
!tar -xf IAM_Words/words.tgz -C data/words
!mv IAM_Words/words.txt data
```
Preview how the dataset is organized. Lines prepended by "#" are just metadata information.

```python
!head -20 data/words.txt
```

```
#--- words.txt ---------------------------------------------------------------#
#
# iam database word information
#
# format: a01-000u-00-00 ok 154 1 408 768 27 51 AT A
#
# a01-000u-00-00 -> word id for line 00 in form a01-000u
# ok -> result of word segmentation
# ok: word was correctly
# er: segmentation of word can be bad
#
# 154 -> graylevel to binarize the line containing this word
# 1 -> number of components for this word
# 408 768 27 51 -> bounding box around this word in x,y,w,h format
# AT -> the grammatical tag for this word, see the
# file tagset.txt for an explanation
# A -> the transcription for this word
#
a01-000u-00-00 ok 154 408 768 27 51 AT A
a01-000u-00-01 ok 154 507 766 213 48 NN MOVE
```
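The pure-Python sketch below splits one data line into the fields described in the header, to make the format concrete. The helper `parse_words_line` is hypothetical and not part of this example.

The data lines shown have one field fewer than the header's format line, because there is no component count. The sketch therefore reads the bounding box, tag and transcription from the end of the line. The label-cleaning code later in this example does the same thing: it takes the last space-separated token as the transcription. The relative image path follows the `part1/part1-part2/part1-part2-part3.png` layout that the example's path-building code assumes.

```python
import posixpath


def parse_words_line(line):
    """Split one non-comment line of IAM words.txt into named fields (sketch)."""
    parts = line.strip().split(" ")
    word_id, status = parts[0], parts[1]
    form_prefix, form_suffix = word_id.split("-")[:2]
    return {
        "word_id": word_id,
        "status": status,  # "ok" or "er", according to the header
        "graylevel": int(parts[2]),
        # Index from the end so that an optional component-count field does not matter.
        "bbox": tuple(int(v) for v in parts[-6:-2]),  # x, y, w, h
        "tag": parts[-2],
        "transcription": parts[-1],
        # part1/part1-part2/part1-part2-part3.png, relative to data/words
        "image_path": posixpath.join(
            form_prefix, f"{form_prefix}-{form_suffix}", word_id + ".png"
        ),
    }


fields = parse_words_line("a01-000u-00-01 ok 154 507 766 213 48 NN MOVE")
print(fields["transcription"], fields["bbox"], fields["image_path"])
```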
```python
import keras
from keras.layers import StringLookup
from keras import ops

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import os

np.random.seed(42)
keras.utils.set_random_seed(42)

base_path = "data"
words_list = []

words = open(f"{base_path}/words.txt", "r").readlines()
for line in words:
    if line[0] == "#":
        continue
    if line.split(" ")[1] != "err":  # We don't need to deal with errored entries.
        words_list.append(line)

len(words_list)

np.random.shuffle(words_list)
```
We will split the dataset into three subsets with a 90:5:5 ratio (train:validation:test).
```python
split_idx = int(0.9 * len(words_list))
train_samples = words_list[:split_idx]
test_samples = words_list[split_idx:]

val_split_idx = int(0.5 * len(test_samples))
validation_samples = test_samples[:val_split_idx]
test_samples = test_samples[val_split_idx:]

assert len(words_list) == len(train_samples) + len(validation_samples) + len(
    test_samples
)

print(f"Total training samples: {len(train_samples)}")
print(f"Total validation samples: {len(validation_samples)}")
print(f"Total test samples: {len(test_samples)}")
```

```
Total training samples: 86810
Total validation samples: 4823
Total test samples: 4823
```
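The printed totals can be checked by hand. This small sketch repeats the same `int()` truncations as the splitting code. The helper `split_sizes` is ours and is for illustration only.

The three subsets add up to 96,456 lines. 90% of that truncates to 86,810, and the remaining 9,646 are halved. Because of the truncation, the test subset gets the extra sample whenever the held-out part has an odd size.

```python
def split_sizes(n):
    """Reproduce the 90:5:5 split arithmetic used above for n samples."""
    split_idx = int(0.9 * n)  # first 90% goes to training
    held_out = n - split_idx  # remaining ~10%
    val_split_idx = int(0.5 * held_out)  # first half of the held-out part
    return split_idx, val_split_idx, held_out - val_split_idx


n_total = 86810 + 4823 + 4823  # totals printed above
print(n_total, split_sizes(n_total))
```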
We start building our data input pipeline by first preparing the image paths.
```python
base_image_path = os.path.join(base_path, "words")


def get_image_paths_and_labels(samples):
    paths = []
    corrected_samples = []
    for i, file_line in enumerate(samples):
        line_split = file_line.strip()
        line_split = line_split.split(" ")

        # Each line split will have this format for the corresponding image:
        # part1/part1-part2/part1-part2-part3.png
        image_name = line_split[0]
        partI = image_name.split("-")[0]
        partII = image_name.split("-")[1]
        img_path = os.path.join(
            base_image_path, partI, partI + "-" + partII, image_name + ".png"
        )
        # Keep only non-empty image files.
        if os.path.getsize(img_path):
            paths.append(img_path)
            corrected_samples.append(file_line.split("\n")[0])

    return paths, corrected_samples


train_img_paths, train_labels = get_image_paths_and_labels(train_samples)
validation_img_paths, validation_labels = get_image_paths_and_labels(validation_samples)
test_img_paths, test_labels = get_image_paths_and_labels(test_samples)
```
Then we prepare the ground-truth labels.
```python
# Find maximum length and the size of the vocabulary in the training data.
train_labels_cleaned = []
characters = set()
max_len = 0

for label in train_labels:
    label = label.split(" ")[-1].strip()
    for char in label:
        characters.add(char)

    max_len = max(max_len, len(label))
    train_labels_cleaned.append(label)

characters = sorted(list(characters))

print("Maximum length: ", max_len)
print("Vocab size: ", len(characters))

# Check some label samples.
train_labels_cleaned[:10]
```

```
Maximum length: 21
Vocab size: 78
['sure',
'he',
'during',
'of',
'booty',
'gastronomy',
'boy',
'The',
'and',
'in']
```
Now we clean the validation and the test labels as well.
```python
def clean_labels(labels):
    cleaned_labels = []
    for label in labels:
        label = label.split(" ")[-1].strip()
        cleaned_labels.append(label)
    return cleaned_labels


validation_labels_cleaned = clean_labels(validation_labels)
test_labels_cleaned = clean_labels(test_labels)
```
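Keras provides different preprocessing layers to deal with different modalities of data. This guide provides a comprehensive introduction. Our example involves preprocessing labels at the character level. This means that if there are two labels, e.g. "cat" and "dog", then our character vocabulary should be {a, c, d, g, o, t} (without any special tokens). We use the `StringLookup` layer for this purpose.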
```python
AUTOTUNE = tf.data.AUTOTUNE

# Mapping characters to integers.
char_to_num = StringLookup(vocabulary=list(characters), mask_token=None)

# Mapping integers back to original characters.
num_to_char = StringLookup(
    vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)
```
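The pure-Python sketch below gives a rough mental model of what the two lookup layers compute, using the "cat"/"dog" vocabulary from the paragraph above. It is not the layer's implementation. The helpers `build_char_lookup`, `encode_label` and `decode_ids` are hypothetical.

The sketch assumes the `StringLookup` settings used here. With `mask_token=None` and the default single OOV index, index 0 holds the OOV token `"[UNK]"`, and the sorted characters start at index 1. `get_vocabulary()` includes `"[UNK]"`, so the 78 training characters give 79 entries. Adding the `+ 2` in the final dense layer then yields the 81 output units shown in the model summary.

```python
def build_char_lookup(labels, oov_token="[UNK]"):
    """Mimic StringLookup(mask_token=None): index 0 is the OOV token."""
    characters = sorted({ch for label in labels for ch in label})
    vocabulary = [oov_token] + characters
    char_to_id = {ch: i for i, ch in enumerate(vocabulary)}
    return vocabulary, char_to_id


def encode_label(label, char_to_id):
    # Characters never seen in the training labels fall back to index 0 ([UNK]).
    return [char_to_id.get(ch, 0) for ch in label]


def decode_ids(ids, vocabulary):
    # The inverse lookup, analogous to num_to_char with invert=True.
    return "".join(vocabulary[i] for i in ids)


toy_vocabulary, toy_char_to_num = build_char_lookup(["cat", "dog"])
print(toy_vocabulary)
print(encode_label("cat", toy_char_to_num))
print(decode_ids(encode_label("dog", toy_char_to_num), toy_vocabulary))
```

An unseen character decodes back to the literal string `"[UNK]"`. This is why the inference code later strips that token from its predictions.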
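Many OCR models work with rectangular images instead of square ones. This will become clearer when we visualize a few samples from the dataset. Resizing square images without regard to the aspect ratio does not introduce much distortion, but the same is not true for rectangular images. However, resizing images to a uniform size is a requirement for mini-batching. So we need to perform our resizing such that the following criteria are met:

- The aspect ratio is preserved.
- The content of the images is not affected.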
```python
def distortion_free_resize(image, img_size):
    w, h = img_size
    image = tf.image.resize(image, size=(h, w), preserve_aspect_ratio=True)

    # Check the amount of padding needed to be done.
    pad_height = h - ops.shape(image)[0]
    pad_width = w - ops.shape(image)[1]

    # Only necessary if you want to do same amount of padding on both sides.
    if pad_height % 2 != 0:
        height = pad_height // 2
        pad_height_top = height + 1
        pad_height_bottom = height
    else:
        pad_height_top = pad_height_bottom = pad_height // 2

    if pad_width % 2 != 0:
        width = pad_width // 2
        pad_width_left = width + 1
        pad_width_right = width
    else:
        pad_width_left = pad_width_right = pad_width // 2

    image = tf.pad(
        image,
        paddings=[
            [pad_height_top, pad_height_bottom],
            [pad_width_left, pad_width_right],
            [0, 0],
        ],
    )

    # Swap to (width, height, channels) so that the width comes first,
    # then flip along the second axis.
    image = ops.transpose(image, (1, 0, 2))
    image = tf.image.flip_left_right(image)
    return image
```
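The padding branch is easier to follow with concrete numbers. The pure-Python sketch below mirrors the size and padding arithmetic for the 128×32 target used in this example. `padding_for` is a hypothetical helper.

The sketch assumes that the aspect-preserving resize scales by the smaller of the two ratios and rounds to the nearest pixel. This is an approximation of `tf.image.resize`, not a statement of its exact rounding.

```python
def padding_for(height, width, target_h=32, target_w=128):
    """Sketch of distortion_free_resize's size and padding arithmetic.

    Assumes the aspect-preserving resize scales by the smaller ratio and
    rounds to the nearest pixel (an approximation of tf.image.resize).
    """
    scale = min(target_h / height, target_w / width)
    new_h, new_w = round(height * scale), round(width * scale)
    pad_h, pad_w = target_h - new_h, target_w - new_w
    # When the padding is odd, the extra pixel goes to the top / left, as above.
    top, bottom = pad_h - pad_h // 2, pad_h // 2
    left, right = pad_w - pad_w // 2, pad_w // 2
    return (new_h, new_w), (top, bottom), (left, right)


print(padding_for(50, 100))  # wide image: fills the height, padded left/right
print(padding_for(32, 63))  # odd horizontal padding: one extra pixel on the left
```

In every case the resized word keeps its proportions, and only blank padding makes up the difference to 128×32.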
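A plain resize that ignores the aspect ratio would force every word image to 128×32. Notice how such a resize would introduce unnecessary stretching of the handwriting.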
```python
batch_size = 64
padding_token = 99
image_width = 128
image_height = 32


def preprocess_image(image_path, img_size=(image_width, image_height)):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, 1)
    image = distortion_free_resize(image, img_size)
    image = ops.cast(image, tf.float32) / 255.0
    return image


def vectorize_label(label):
    label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
    length = ops.shape(label)[0]
    pad_amount = max_len - length
    label = tf.pad(label, paddings=[[0, pad_amount]], constant_values=padding_token)
    return label


def process_images_labels(image_path, label):
    image = preprocess_image(image_path)
    label = vectorize_label(label)
    return {"image": image, "label": label}


def prepare_dataset(image_paths, labels):
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels)).map(
        process_images_labels, num_parallel_calls=AUTOTUNE
    )
    return dataset.batch(batch_size).cache().prefetch(AUTOTUNE)
```
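`vectorize_label` turns each label into a fixed-length vector of `max_len` ids so that labels can be batched. The pure-Python sketch below shows that padding and the inverse step that the visualization code performs with `tf.where`. `pad_label` and `strip_padding` are hypothetical helpers. Their defaults, 21 and 99, are the `max_len` printed earlier and the example's `padding_token`.

99 is larger than any index in the 79-entry character vocabulary, so it cannot be confused with a real character. A validation or test label longer than the training `max_len` would make `pad_amount` negative. The sketch raises an explicit error in that case.

```python
def pad_label(ids, max_len=21, padding_token=99):
    """Right-pad a list of character ids to max_len, like vectorize_label."""
    if len(ids) > max_len:
        raise ValueError("label is longer than max_len")
    return ids + [padding_token] * (max_len - len(ids))


def strip_padding(padded, padding_token=99):
    """Drop padding entries, as the visualization code does with tf.where."""
    return [i for i in padded if i != padding_token]


padded_example = pad_label([2, 1, 6])
print(len(padded_example), padded_example[:5])
print(strip_padding(padded_example))
```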
Next, we prepare the `tf.data.Dataset` objects.

```python
train_ds = prepare_dataset(train_img_paths, train_labels_cleaned)
validation_ds = prepare_dataset(validation_img_paths, validation_labels_cleaned)
test_ds = prepare_dataset(test_img_paths, test_labels_cleaned)
```
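With `batch_size = 64`, the number of training steps per epoch follows from the training-set size. `Dataset.batch` keeps the final partial batch by default (`drop_remainder=False`), so the count rounds up. `steps_and_last_batch` is a hypothetical helper.

Assume none of the 86,810 training images is dropped by the file-size check. That gives 1,357 steps, which matches the `1357/1357` progress bars in the training log.

```python
import math


def steps_and_last_batch(num_samples, batch_size):
    """Batches per epoch and size of the final batch (drop_remainder=False)."""
    steps = math.ceil(num_samples / batch_size)
    last = num_samples - (steps - 1) * batch_size
    return steps, last


print(steps_and_last_batch(86810, 64))
```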
```python
for data in train_ds.take(1):
    images, labels = data["image"], data["label"]

    _, ax = plt.subplots(4, 4, figsize=(15, 8))

    for i in range(16):
        # Undo the transpose and flip applied in distortion_free_resize for display.
        img = images[i]
        img = tf.image.flip_left_right(img)
        img = ops.transpose(img, (1, 0, 2))
        img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
        img = img[:, :, 0]

        # Gather indices where label != padding_token.
        label = labels[i]
        indices = tf.gather(label, tf.where(tf.math.not_equal(label, padding_token)))
        # Convert to string.
        label = tf.strings.reduce_join(num_to_char(indices))
        label = label.numpy().decode("utf-8")

        ax[i // 4, i % 4].imshow(img, cmap="gray")
        ax[i // 4, i % 4].set_title(label)
        ax[i // 4, i % 4].axis("off")

plt.show()
```
You will notice that the content of the original images is preserved as faithfully as possible and has been padded accordingly.
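The visualization code recovers upright images by applying the inverse operations in reverse order. `distortion_free_resize` transposes and then flips, so the plotting loop flips and then transposes. The NumPy sketch below checks this round trip on a tiny array.

The sketch assumes that `tf.image.flip_left_right` on a 3-D `(height, width, channels)` tensor reverses axis 1, which is what the slice `[:, ::-1, :]` does. It also shows that each entry along the first axis of the model input is one column of the original image, read bottom to top. In other words, that axis runs along the original image width.

```python
import numpy as np

# A tiny stand-in "image" in (height, width, channels) layout.
image = np.arange(2 * 3).reshape(2, 3, 1)

# Forward: what distortion_free_resize does after padding.
transposed = np.transpose(image, (1, 0, 2))  # -> (width, height, channels)
model_input = transposed[:, ::-1, :]  # reverse axis 1, like flip_left_right

# Inverse: what the visualization loop does before imshow.
restored = np.transpose(model_input[:, ::-1, :], (1, 0, 2))

print(model_input.shape, np.array_equal(restored, image))
print(model_input[0, :, 0], image[:, 0, 0])  # first column, reversed
```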
Our model will use the CTC loss as an endpoint layer. For a detailed understanding of the CTC loss, refer to this post.
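To see what the CTC loss actually sums over, the NumPy sketch below computes it for a single sequence with the standard forward (alpha) recursion. It then cross-checks the result against brute-force enumeration of every alignment on a tiny input.

This is an illustration in probability space, not what `ctc_batch_cost` runs internally. Practical implementations work in log space for numerical stability and operate on batches. The helper names are ours. Putting the blank at the last class index follows the convention of TensorFlow's original CTC ops.

```python
import itertools

import numpy as np


def ctc_neg_log_likelihood(probs, label, blank):
    """CTC loss for one sequence via the forward (alpha) recursion.

    probs: (T, C) per-frame class probabilities (softmax outputs).
    label: list of target class ids, without blanks.
    """
    # Extended label with blanks around every symbol: [b, l1, b, l2, ..., b].
    ext = [blank]
    for c in label:
        ext += [c, blank]
    num_frames, num_states = probs.shape[0], len(ext)

    alpha = np.zeros((num_frames, num_states))
    alpha[0, 0] = probs[0, blank]
    if num_states > 1:
        alpha[0, 1] = probs[0, ext[1]]
    for t in range(1, num_frames):
        for s in range(num_states):
            total = alpha[t - 1, s]  # stay in the same state
            if s >= 1:
                total += alpha[t - 1, s - 1]  # advance by one state
            # Skipping a blank is only allowed between two different symbols.
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                total += alpha[t - 1, s - 2]
            alpha[t, s] = total * probs[t, ext[s]]

    # Valid paths end on the last symbol or on the trailing blank.
    likelihood = alpha[-1, -1] + (alpha[-1, -2] if num_states > 1 else 0.0)
    return -np.log(likelihood)


def collapse(path, blank):
    """CTC collapsing rule: merge repeated ids, then drop blanks."""
    merged = [k for k, _ in itertools.groupby(path)]
    return [k for k in merged if k != blank]


def brute_force_nll(probs, label, blank):
    """Sum the probability of every alignment that collapses to label."""
    num_frames, num_classes = probs.shape
    total = 0.0
    for path in itertools.product(range(num_classes), repeat=num_frames):
        if collapse(path, blank) == list(label):
            total += np.prod([probs[t, k] for t, k in enumerate(path)])
    return -np.log(total)


rng = np.random.default_rng(0)
logits = rng.normal(size=(5, 3))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
blank = 2  # last class index
print(ctc_neg_log_likelihood(probs, [0, 1], blank), brute_force_nll(probs, [0, 1], blank))
```

The recursion visits each (frame, state) pair once, whereas brute force grows exponentially with the number of frames. That is what makes CTC practical for 32-frame outputs like the ones in this model.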
```python
class CTCLayer(keras.layers.Layer):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.loss_fn = tf.keras.backend.ctc_batch_cost

    def call(self, y_true, y_pred):
        batch_len = ops.cast(ops.shape(y_true)[0], dtype="int64")
        input_length = ops.cast(ops.shape(y_pred)[1], dtype="int64")
        label_length = ops.cast(ops.shape(y_true)[1], dtype="int64")

        input_length = input_length * ops.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_length * ops.ones(shape=(batch_len, 1), dtype="int64")
        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
        self.add_loss(loss)

        # At test time, just return the computed predictions.
        return y_pred


def build_model():
    # Inputs to the model
    input_img = keras.Input(shape=(image_width, image_height, 1), name="image")
    labels = keras.layers.Input(name="label", shape=(None,))

    # First conv block.
    x = keras.layers.Conv2D(
        32,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv1",
    )(input_img)
    x = keras.layers.MaxPooling2D((2, 2), name="pool1")(x)

    # Second conv block.
    x = keras.layers.Conv2D(
        64,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv2",
    )(x)
    x = keras.layers.MaxPooling2D((2, 2), name="pool2")(x)

    # We have used two max-pooling layers with pool size and strides of 2.
    # Hence, downsampled feature maps are 4x smaller. The number of
    # filters in the last layer is 64. Reshape accordingly before
    # passing the output to the RNN part of the model.
    new_shape = ((image_width // 4), (image_height // 4) * 64)
    x = keras.layers.Reshape(target_shape=new_shape, name="reshape")(x)
    x = keras.layers.Dense(64, activation="relu", name="dense1")(x)
    x = keras.layers.Dropout(0.2)(x)

    # RNNs.
    x = keras.layers.Bidirectional(
        keras.layers.LSTM(128, return_sequences=True, dropout=0.25)
    )(x)
    x = keras.layers.Bidirectional(
        keras.layers.LSTM(64, return_sequences=True, dropout=0.25)
    )(x)

    # +2 is to account for the two special tokens introduced by the CTC loss.
    # The recommendation comes from here: https://git.io/J0eXP.
    x = keras.layers.Dense(
        len(char_to_num.get_vocabulary()) + 2, activation="softmax", name="dense2"
    )(x)

    # Add CTC layer for calculating CTC loss at each step.
    output = CTCLayer(name="ctc_loss")(labels, x)

    # Define the model.
    model = keras.models.Model(
        inputs=[input_img, labels], outputs=output, name="handwriting_recognizer"
    )
    # Optimizer.
    opt = keras.optimizers.Adam()
    # Compile the model and return.
    model.compile(optimizer=opt)
    return model


# Get the model.
model = build_model()
model.summary()
```
```
Model: "handwriting_recognizer"

Layer (type)                      Output Shape          Param #   Connected to
image (InputLayer)                (None, 128, 32, 1)          0   -
Conv1 (Conv2D)                    (None, 128, 32, 32)       320   image[0][0]
pool1 (MaxPooling2D)              (None, 64, 16, 32)          0   Conv1[0][0]
Conv2 (Conv2D)                    (None, 64, 16, 64)     18,496   pool1[0][0]
pool2 (MaxPooling2D)              (None, 32, 8, 64)           0   Conv2[0][0]
reshape (Reshape)                 (None, 32, 512)             0   pool2[0][0]
dense1 (Dense)                    (None, 32, 64)         32,832   reshape[0][0]
dropout (Dropout)                 (None, 32, 64)              0   dense1[0][0]
bidirectional (Bidirectional)     (None, 32, 256)       197,632   dropout[0][0]
bidirectional_1 (Bidirectional)   (None, 32, 128)       164,352   bidirectional[0]…
label (InputLayer)                (None, None)                0   -
dense2 (Dense)                    (None, 32, 81)         10,449   bidirectional_1[…
ctc_loss (CTCLayer)               (None, 32, 81)              0   label[0][0], dense2[0][0]

Total params: 424,081 (1.62 MB)
Trainable params: 424,081 (1.62 MB)
Non-trainable params: 0 (0.00 B)
```
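The output shapes and parameter counts in the summary can be reproduced by hand, which is a useful sanity check when changing the architecture. The sketch below uses the standard formulas for layers with biases. The helper functions are ours and are for illustration only.

- **Conv2D:** `kh * kw * in * filters` weights plus one bias per filter.
- **Dense:** `in * units` weights plus biases.
- **Keras LSTM:** four gates, each with input weights, recurrent weights and a bias.
- **Bidirectional:** doubles the LSTM's parameters, and its concatenated output doubles the feature size seen by the next layer.

The two 2×2 poolings shrink the 128×32 input to 32×8, so the RNN sees 32 time steps of 8 * 64 = 512 features.

```python
def conv2d_params(kernel_h, kernel_w, in_ch, filters):
    return kernel_h * kernel_w * in_ch * filters + filters  # weights + biases


def dense_params(in_features, units):
    return in_features * units + units


def lstm_params(in_features, units):
    # Four gates, each with input weights, recurrent weights and a bias.
    return 4 * (units * (in_features + units) + units)


vocab_size = 78 + 1  # training characters plus the [UNK] token
time_steps, features = 128 // 4, (32 // 4) * 64  # two 2x2 poolings, 64 filters
counts = {
    "Conv1": conv2d_params(3, 3, 1, 32),
    "Conv2": conv2d_params(3, 3, 32, 64),
    "dense1": dense_params(features, 64),
    "bidirectional": 2 * lstm_params(64, 128),
    "bidirectional_1": 2 * lstm_params(2 * 128, 64),
    "dense2": dense_params(2 * 64, vocab_size + 2),
}
print((time_steps, features), counts, sum(counts.values()))
```

The 32 time steps also bound what CTC can emit. A label needs at least one frame per character, plus one extra frame for each pair of identical adjacent characters. So 32 frames are enough for the 21-character maximum unless a label contains many doubled letters.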
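Edit distance is the most widely used metric for evaluating OCR models. In this section, we will implement it and use it as a callback to monitor our model.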
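The edit (Levenshtein) distance between two sequences is the minimum number of single-element insertions, deletions and substitutions needed to turn one into the other. The callback in this section computes it on integer sequences with `tf.edit_distance(..., normalize=False)`, which returns this raw count. With `normalize=True`, TensorFlow would instead divide it by the length of the ground truth.

Below is a minimal pure-Python version of the classic dynamic program, working on strings for readability. `levenshtein` is our own helper.

```python
def levenshtein(a, b):
    """Minimum number of insertions, deletions and substitutions turning a into b."""
    prev = list(range(len(b) + 1))  # distances from the empty prefix of a
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(
                min(
                    prev[j] + 1,  # deletion
                    curr[j - 1] + 1,  # insertion
                    prev[j - 1] + (ca != cb),  # substitution (free if equal)
                )
            )
        prev = curr
    return prev[-1]


print(levenshtein("booty", "boty"), levenshtein("gastronomy", "gastronony"))
```

For example, predicting "boty" for the label "booty" costs one deletion, so a mean edit distance of 1.0 means one character is wrong on average per word.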
We first segregate the validation images and their labels for convenience.
```python
validation_images = []
validation_labels = []

for batch in validation_ds:
    validation_images.append(batch["image"])
    validation_labels.append(batch["label"])
```
Now, we create a callback to monitor the edit distances.
```python
def calculate_edit_distance(labels, predictions):
    # Get a single batch and convert its labels to sparse tensors.
    sparse_labels = ops.cast(tf.sparse.from_dense(labels), dtype=tf.int64)

    # Make predictions and convert them to sparse tensors.
    input_len = np.ones(predictions.shape[0]) * predictions.shape[1]
    predictions_decoded = keras.ops.nn.ctc_decode(
        predictions, sequence_lengths=input_len
    )[0][0][:, :max_len]
    sparse_predictions = ops.cast(
        tf.sparse.from_dense(predictions_decoded), dtype=tf.int64
    )

    # Compute individual edit distances and average them out.
    edit_distances = tf.edit_distance(
        sparse_predictions, sparse_labels, normalize=False
    )
    return tf.reduce_mean(edit_distances)


class EditDistanceCallback(keras.callbacks.Callback):
    def __init__(self, pred_model):
        super().__init__()
        self.prediction_model = pred_model

    def on_epoch_end(self, epoch, logs=None):
        edit_distances = []

        for i in range(len(validation_images)):
            labels = validation_labels[i]
            predictions = self.prediction_model.predict(validation_images[i])
            edit_distances.append(calculate_edit_distance(labels, predictions).numpy())

        print(
            f"Mean edit distance for epoch {epoch + 1}: {np.mean(edit_distances):.4f}"
        )
```
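`calculate_edit_distance`, and `decode_batch_predictions` in the inference section, turn the per-frame softmax outputs into label sequences with `keras.ops.nn.ctc_decode`. The inference code's comment notes that this uses greedy search. Conceptually, greedy (best-path) decoding takes the most likely class at each frame, merges consecutive repeats and then removes blanks.

The NumPy sketch below illustrates that rule only; it is not Keras's implementation. `greedy_ctc_decode` is a hypothetical helper, and the blank index is passed in explicitly rather than assumed.

```python
import numpy as np


def greedy_ctc_decode(probs, blank):
    """Best-path decoding: argmax per frame, merge repeats, drop blanks."""
    best_path = np.argmax(probs, axis=-1).tolist()
    decoded = []
    previous = None
    for k in best_path:
        # A symbol is emitted only when it differs from the previous frame;
        # a blank between two equal symbols therefore keeps both.
        if k != previous and k != blank:
            decoded.append(k)
        previous = k
    return decoded


# One-hot "probabilities" for the frame sequence 0 0 <blank> 0 1 1 (blank = 2).
frames = np.eye(3)[[0, 0, 2, 0, 1, 1]]
print(greedy_ctc_decode(frames, blank=2))
```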
Now we are ready to kick off model training.
```python
epochs = 10  # To get good results this should be at least 50.

model = build_model()
prediction_model = keras.models.Model(
    model.get_layer(name="image").output, model.get_layer(name="dense2").output
)
edit_distance_callback = EditDistanceCallback(prediction_model)

# Train the model.
history = model.fit(
    train_ds,
    validation_data=validation_ds,
    epochs=epochs,
    callbacks=[edit_distance_callback],
)
```
```
Epoch 1/10
1357/1357 ━━━━━━━━━━━━━━━━━━━━ 216s 157ms/step - loss: 1068.7206 - val_loss: 762.4462
Epoch 2/10
1357/1357 ━━━━━━━━━━━━━━━━━━━━ 215s 158ms/step - loss: 735.8929 - val_loss: 627.9722
Epoch 3/10
1357/1357 ━━━━━━━━━━━━━━━━━━━━ 211s 155ms/step - loss: 624.9929 - val_loss: 540.8905
Epoch 4/10
1357/1357 ━━━━━━━━━━━━━━━━━━━━ 208s 153ms/step - loss: 544.2097 - val_loss: 446.0919
Epoch 5/10
1357/1357 ━━━━━━━━━━━━━━━━━━━━ 213s 157ms/step - loss: 459.0329 - val_loss: 347.1689
Epoch 6/10
1357/1357 ━━━━━━━━━━━━━━━━━━━━ 210s 155ms/step - loss: 378.6367 - val_loss: 287.1726
Epoch 7/10
1357/1357 ━━━━━━━━━━━━━━━━━━━━ 211s 155ms/step - loss: 325.4126 - val_loss: 250.3677
Epoch 8/10
1357/1357 ━━━━━━━━━━━━━━━━━━━━ 209s 154ms/step - loss: 289.2796 - val_loss: 224.4595
Epoch 9/10
1357/1357 ━━━━━━━━━━━━━━━━━━━━ 209s 154ms/step - loss: 264.0461 - val_loss: 205.5910
Epoch 10/10
1357/1357 ━━━━━━━━━━━━━━━━━━━━ 208s 153ms/step - loss: 245.5216 - val_loss: 195.7952
```
</div>
---
## Inference
```python
# A utility function to decode the output of the network.
def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search.
    results = keras.ops.nn.ctc_decode(pred, sequence_lengths=input_len)[0][0][
        :, :max_len
    ]
    # Iterate over the results and get back the text.
    output_text = []
    for res in results:
        res = tf.gather(res, tf.where(tf.math.not_equal(res, -1)))
        res = (
            tf.strings.reduce_join(num_to_char(res))
            .numpy()
            .decode("utf-8")
            .replace("[UNK]", "")
        )
        output_text.append(res)
    return output_text


# Let's check results on some test samples.
for batch in test_ds.take(1):
    batch_images = batch["image"]
    _, ax = plt.subplots(4, 4, figsize=(15, 8))

    preds = prediction_model.predict(batch_images)
    pred_texts = decode_batch_predictions(preds)

    for i in range(16):
        img = batch_images[i]
        img = tf.image.flip_left_right(img)
        img = ops.transpose(img, (1, 0, 2))
        img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
        img = img[:, :, 0]

        title = f"Prediction: {pred_texts[i]}"
        ax[i // 4, i % 4].imshow(img, cmap="gray")
        ax[i // 4, i % 4].set_title(title)
        ax[i // 4, i % 4].axis("off")

plt.show()