Authors: Mohamed Reda Bouadjenek and Ngoc Dung Huynh
Date created: 2021/09/26
Last modified: 2021/09/26
Description: Training a CTC-based model for automatic speech recognition.
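Speech recognition is an interdisciplinary subfield of computer science and computational linguistics that develops methods and technologies enabling computers to recognize spoken language and transcribe it into text. It is also known as automatic speech recognition (ASR), computer speech recognition, or speech-to-text (STT). It draws on knowledge and research from computer science, linguistics, and computer engineering.
This demonstration shows how to combine a 2D CNN, an RNN, and a Connectionist Temporal Classification (CTC) loss to build an ASR model. CTC is an algorithm used to train deep neural networks for speech recognition, handwriting recognition, and other sequence problems. It is used when we do not know how the input aligns with the output (that is, how the characters in the transcript align with the audio). The model we create is similar to DeepSpeech2.
We will use the LJSpeech dataset from the LibriVox project. It consists of short audio clips of a single speaker reading passages from 7 non-fiction books.
We will evaluate the quality of the model using Word Error Rate (WER). WER is obtained by adding up the substitutions, insertions, and deletions that occur in the recognized word sequence and dividing that number by the total number of words originally spoken. To compute the WER score you need to install the jiwer package, which you can do with the following command: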
pip install jiwer
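To make the WER definition concrete, here is a minimal pure-Python sketch that computes it as a word-level edit distance divided by the number of reference words. The function name `word_error_rate` is hypothetical and is not part of jiwer; jiwer's own `wer` may preprocess its inputs differently, so treat this only as an illustration of the formula.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level Levenshtein distance divided by the reference length."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] = edit distance between the reference words seen so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = prev[j - 1] + (ref_word != hyp_word)
            deletion = prev[j] + 1  # reference word missing from the hypothesis
            insertion = curr[j - 1] + 1  # extra word in the hypothesis
            curr[j] = min(substitution, deletion, insertion)
        prev = curr
    return prev[-1] / len(ref)


# One substitution ("sat" -> "sit") and one deletion ("the") over 6 reference words.
example = word_error_rate("the cat sat on the mat", "the cat sit on mat")
print(f"{example:.4f}")
```

Because insertions count as errors, the value can exceed 1 when the hypothesis is much longer than the reference.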
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from IPython import display
from jiwer import wer
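Let's download the LJSpeech dataset. It contains 13,100 audio files stored as wav files in the /wavs/ folder. The label (transcript) for each audio file is given in the metadata.csv file. For this demo we will use the "Normalized transcription" field.
Each audio file is a single-channel 16-bit PCM WAV with a sample rate of 22,050 Hz.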
data_url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
data_path = keras.utils.get_file("LJSpeech-1.1", data_url, untar=True)
wavs_path = data_path + "/wavs/"
metadata_path = data_path + "/metadata.csv"
# Read metadata file and parse it
metadata_df = pd.read_csv(metadata_path, sep="|", header=None, quoting=3)
metadata_df.columns = ["file_name", "transcription", "normalized_transcription"]
metadata_df = metadata_df[["file_name", "normalized_transcription"]]
metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)
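   file_name   normalized_transcription
0  LJ029-0199  On November eighteen the Dallas City Council ...
1  LJ028-0237  with orders to march into the town by the bed of the river ...
2  LJ009-0116  On the following day, ...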
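The `quoting=3` argument passed to `pd.read_csv` has the same integer value as `csv.QUOTE_NONE`, which tells the parser to treat quote characters as ordinary text. The sketch below illustrates that behaviour with the standard-library `csv` module on two made-up rows in the same three-column, pipe-delimited layout; the row contents are hypothetical and not taken from the dataset.

```python
import csv
import io

# Two made-up rows: file name | transcription | normalized transcription.
sample = (
    'LJ000-0001|He said "stop" at 5 p.m.|He said "stop" at five p m\n'
    "LJ000-0002|It's done.|It's done.\n"
)

# csv.QUOTE_NONE equals 3, so quotes inside a field are kept literally.
reader = csv.reader(io.StringIO(sample), delimiter="|", quoting=csv.QUOTE_NONE)
rows = [(file_name, normalized) for file_name, _, normalized in reader]

for file_name, normalized in rows:
    print(file_name, "->", normalized)
```

Without this setting, an unbalanced quote in a transcript could cause several lines to be merged into a single field.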
split = int(len(metadata_df) * 0.90)
df_train = metadata_df[:split]
df_val = metadata_df[split:]
print(f"Size of the training set: {len(df_train)}")
print(f"Size of the validation set: {len(df_val)}")
Size of the training set: 11790
Size of the validation set: 1310
# The set of characters accepted in the transcription.
characters = [x for x in "abcdefghijklmnopqrstuvwxyz'?! "]
# Mapping characters to integers
char_to_num = keras.layers.StringLookup(vocabulary=characters, oov_token="")
# Mapping integers back to original characters
num_to_char = keras.layers.StringLookup(
    vocabulary=char_to_num.get_vocabulary(), oov_token="", invert=True
)

print(
    f"The vocabulary is: {char_to_num.get_vocabulary()} "
    f"(size ={char_to_num.vocabulary_size()})"
)
The vocabulary is: ['', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', "'", '?', '!', ' '] (size =31)
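The printed vocabulary shows that index 0 is reserved for the out-of-vocabulary token `""` and that the 30 accepted characters occupy indices 1 to 30. As a rough illustration of what the two lookup layers do, here is a plain-Python sketch; the `encode` and `decode` helpers are hypothetical stand-ins, not Keras APIs.

```python
characters = list("abcdefghijklmnopqrstuvwxyz'?! ")

# Index 0 is reserved for the out-of-vocabulary token "", as in the printed vocabulary.
vocabulary = [""] + characters
char_to_index = {ch: i for i, ch in enumerate(vocabulary)}


def encode(text: str) -> list:
    """Map each character to its index; unknown characters map to 0."""
    return [char_to_index.get(ch, 0) for ch in text.lower()]


def decode(indices: list) -> str:
    """Map indices back to characters; index 0 decodes to the empty string."""
    return "".join(vocabulary[i] for i in indices)


encoded = encode("Hi, you!")
print(encoded)
print(decode(encoded))
```

Characters outside the accepted set, such as the comma here, collapse to index 0 and disappear when decoded.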
# An integer scalar Tensor. The window length in samples.
frame_length = 256
# An integer scalar Tensor. The number of samples to step.
frame_step = 160
# An integer scalar Tensor. The size of the FFT to apply.
# If not provided, uses the smallest power of 2 enclosing frame_length.
fft_length = 384
def encode_single_sample(wav_file, label):
    ## Process the Audio
    # 1. Read wav file
    file = tf.io.read_file(wavs_path + wav_file + ".wav")
    # 2. Decode the wav file
    audio, _ = tf.audio.decode_wav(file)
    audio = tf.squeeze(audio, axis=-1)
    # 3. Change type to float
    audio = tf.cast(audio, tf.float32)
    # 4. Get the spectrogram
    spectrogram = tf.signal.stft(
        audio, frame_length=frame_length, frame_step=frame_step, fft_length=fft_length
    )
    # 5. We only need the magnitude, which can be derived by applying tf.abs
    spectrogram = tf.abs(spectrogram)
    spectrogram = tf.math.pow(spectrogram, 0.5)
    # 6. Normalisation: standardise each time frame across its frequency bins
    means = tf.math.reduce_mean(spectrogram, 1, keepdims=True)
    stddevs = tf.math.reduce_std(spectrogram, 1, keepdims=True)
    spectrogram = (spectrogram - means) / (stddevs + 1e-10)
    ## Process the label
    # 7. Convert label to lower case
    label = tf.strings.lower(label)
    # 8. Split the label into characters
    label = tf.strings.unicode_split(label, input_encoding="UTF-8")
    # 9. Map the characters in label to numbers
    label = char_to_num(label)
    # 10. Return the spectrogram (model input) and the encoded label (target)
    return spectrogram, label
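The STFT settings determine the shape of each spectrogram. The sketch below works through the arithmetic for one second of audio, assuming that the STFT does not pad the end of the signal (my understanding of the `tf.signal.stft` default) and using the 22,050 Hz sample rate stated for the dataset. The helper name `spectrogram_shape` is hypothetical.

```python
SAMPLE_RATE = 22_050  # Hz, as stated for the LJSpeech audio files
frame_length = 256
frame_step = 160
fft_length = 384


def spectrogram_shape(num_samples: int) -> tuple:
    """(time frames, frequency bins) for an STFT without end padding."""
    num_bins = fft_length // 2 + 1  # a real-input FFT keeps the non-negative frequencies
    if num_samples < frame_length:
        return 0, num_bins
    num_frames = 1 + (num_samples - frame_length) // frame_step
    return num_frames, num_bins


frames, bins = spectrogram_shape(SAMPLE_RATE)  # one second of audio
window_ms = 1000 * frame_length / SAMPLE_RATE
hop_ms = 1000 * frame_step / SAMPLE_RATE
print(frames, bins)
print(f"window = {window_ms:.1f} ms, hop = {hop_ms:.1f} ms")
```

The 193 frequency bins computed here match the `fft_length // 2 + 1` value that is later passed to the model as its input dimension.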
We now create a tf.data.Dataset object for each of the training and validation sets.
batch_size = 32
# Define the training dataset
train_dataset = tf.data.Dataset.from_tensor_slices(
    (list(df_train["file_name"]), list(df_train["normalized_transcription"]))
)
train_dataset = (
    train_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
    .padded_batch(batch_size)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
# Define the validation dataset
validation_dataset = tf.data.Dataset.from_tensor_slices(
    (list(df_val["file_name"]), list(df_val["normalized_transcription"]))
)
validation_dataset = (
    validation_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE)
    .padded_batch(batch_size)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
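Utterances and transcripts differ in length, so they can only be stacked into a batch after padding. As far as I know, `tf.data`'s `padded_batch` pads every element of a batch to the length of the longest one, using zeros by default for numeric tensors. The plain-Python sketch below mimics that idea on toy label sequences; `padded_batches` is a hypothetical helper and the values are made up.

```python
def padded_batches(sequences, batch_size, pad_value=0):
    """Group sequences into batches, right-padding each batch to its longest item."""
    for start in range(0, len(sequences), batch_size):
        batch = sequences[start : start + batch_size]
        longest = max(len(seq) for seq in batch)
        yield [seq + [pad_value] * (longest - len(seq)) for seq in batch]


# Toy label sequences of different lengths (made-up values).
labels = [[8, 9], [25, 15, 21], [1], [2, 3, 4, 5]]
batches = list(padded_batches(labels, batch_size=2))
for batch in batches:
    print(batch)
```

Each batch is padded only to its own longest sequence, so different batches can have different widths.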
fig = plt.figure(figsize=(8, 5))
for batch in train_dataset.take(1):
    spectrogram = batch[0][0].numpy()
    spectrogram = np.array([np.trim_zeros(x) for x in np.transpose(spectrogram)])
    label = batch[1][0]
    # Spectrogram
    label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
    ax = plt.subplot(2, 1, 1)
    ax.imshow(spectrogram, vmax=1)
    ax.set_title(label)
    ax.axis("off")
    # Wav
    file = tf.io.read_file(wavs_path + list(df_train["file_name"])[0] + ".wav")
    audio, _ = tf.audio.decode_wav(file)
    audio = audio.numpy()
    ax = plt.subplot(2, 1, 2)
    plt.plot(audio)
    ax.set_title("Signal Wave")
    ax.set_xlim(0, len(audio))
    # Play back at the dataset's sample rate of 22,050 Hz
    display.display(display.Audio(np.transpose(audio), rate=22050))
plt.show()
We first define the CTC loss function.
def CTCLoss(y_true, y_pred):
    # Compute the training-time loss value
    batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
    input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
    label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

    # Every sample in the batch is given the same (padded) input and label length
    input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
    label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")

    loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)
    return loss
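To give some intuition for what a CTC loss measures, the sketch below computes it by brute force on a tiny made-up example: it enumerates every frame-level alignment, keeps those that collapse to the target text (merge repeats, then drop blanks), and returns the negative log of their total probability. This is only an illustration under those assumptions; the enumeration grows exponentially with the number of frames, and real implementations use a dynamic-programming pass instead. The names `collapse` and `ctc_neg_log_likelihood` are hypothetical.

```python
import itertools
import math

BLANK = "_"  # stand-in symbol for the CTC blank class


def collapse(alignment):
    """Merge consecutive repeats, then drop blanks."""
    merged = [symbol for symbol, _ in itertools.groupby(alignment)]
    return "".join(symbol for symbol in merged if symbol != BLANK)


def ctc_neg_log_likelihood(frame_probs, target):
    """Negative log of the summed probability of all alignments matching `target`."""
    symbols = list(frame_probs[0].keys())
    total = 0.0
    for alignment in itertools.product(symbols, repeat=len(frame_probs)):
        if collapse(alignment) == target:
            total += math.prod(p[s] for p, s in zip(frame_probs, alignment))
    return -math.log(total)


# Two frames, two classes, uniform probabilities (made-up numbers).
frame_probs = [{"a": 0.5, BLANK: 0.5}, {"a": 0.5, BLANK: 0.5}]
loss = ctc_neg_log_likelihood(frame_probs, "a")
print(f"{loss:.4f}")
```

Here three alignments ("aa", "a_", "_a") collapse to "a", so the loss is the negative log of 0.75. No explicit alignment between frames and characters is ever supplied, which is the point of CTC.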
We now define our model, which is similar to DeepSpeech2.
def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128):
    """Model similar to DeepSpeech2."""
    # Model's input
    input_spectrogram = layers.Input((None, input_dim), name="input")
    # Expand the dimension to use 2D CNN.
    x = layers.Reshape((-1, input_dim, 1), name="expand_dim")(input_spectrogram)
    # Convolution layer 1
    x = layers.Conv2D(
        filters=32,
        kernel_size=[11, 41],
        strides=[2, 2],
        padding="same",
        use_bias=False,
        name="conv_1",
    )(x)
    x = layers.BatchNormalization(name="conv_1_bn")(x)
    x = layers.ReLU(name="conv_1_relu")(x)
    # Convolution layer 2
    x = layers.Conv2D(
        filters=32,
        kernel_size=[11, 21],
        strides=[1, 2],
        padding="same",
        use_bias=False,
        name="conv_2",
    )(x)
    x = layers.BatchNormalization(name="conv_2_bn")(x)
    x = layers.ReLU(name="conv_2_relu")(x)
    # Reshape the resulting volume to feed the RNN layers
    x = layers.Reshape((-1, x.shape[-2] * x.shape[-1]))(x)
    # RNN layers
    for i in range(1, rnn_layers + 1):
        recurrent = layers.GRU(
            units=rnn_units,
            activation="tanh",
            recurrent_activation="sigmoid",
            use_bias=True,
            return_sequences=True,
            reset_after=True,
            name=f"gru_{i}",
        )
        x = layers.Bidirectional(
            recurrent, name=f"bidirectional_{i}", merge_mode="concat"
        )(x)
        if i < rnn_layers:
            x = layers.Dropout(rate=0.5)(x)
    # Dense layer
    x = layers.Dense(units=rnn_units * 2, name="dense_1")(x)
    x = layers.ReLU(name="dense_1_relu")(x)
    x = layers.Dropout(rate=0.5)(x)
    # Classification layer (one extra unit for the CTC blank)
    output = layers.Dense(units=output_dim + 1, activation="softmax")(x)
    # Model
    model = keras.Model(input_spectrogram, output, name="DeepSpeech_2")
    # Optimizer
    opt = keras.optimizers.Adam(learning_rate=1e-4)
    # Compile the model and return
    model.compile(optimizer=opt, loss=CTCLoss)
    return model


# Get the model
model = build_model(
    input_dim=fft_length // 2 + 1,
    output_dim=char_to_num.vocabulary_size(),
    rnn_units=512,
)
model.summary(line_length=110)
Model: "DeepSpeech_2"
_________________________________________________________________________
Layer (type)                      Output Shape            Param #
=========================================================================
input (InputLayer)                [(None, None, 193)]     0
expand_dim (Reshape)              (None, None, 193, 1)    0
conv_1 (Conv2D)                   (None, None, 97, 32)    14432
conv_1_bn (BatchNormalization)    (None, None, 97, 32)    128
conv_1_relu (ReLU)                (None, None, 97, 32)    0
conv_2 (Conv2D)                   (None, None, 49, 32)    236544
conv_2_bn (BatchNormalization)    (None, None, 49, 32)    128
conv_2_relu (ReLU)                (None, None, 49, 32)    0
reshape (Reshape)                 (None, None, 1568)      0
bidirectional_1 (Bidirectional)   (None, None, 1024)      6395904
dropout (Dropout)                 (None, None, 1024)      0
bidirectional_2 (Bidirectional)   (None, None, 1024)      4724736
dropout_1 (Dropout)               (None, None, 1024)      0
bidirectional_3 (Bidirectional)   (None, None, 1024)      4724736
dropout_2 (Dropout)               (None, None, 1024)      0
bidirectional_4 (Bidirectional)   (None, None, 1024)      4724736
dropout_3 (Dropout)               (None, None, 1024)      0
bidirectional_5 (Bidirectional)   (None, None, 1024)      4724736
dense_1 (Dense)                   (None, None, 1024)      1049600
dense_1_relu (ReLU)               (None, None, 1024)      0
dropout_4 (Dropout)               (None, None, 1024)      0
dense (Dense)                     (None, None, 32)        32800
=========================================================================
Total params: 26,628,480
Trainable params: 26,628,352
Non-trainable params: 128
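The parameter counts in the summary can be reproduced by hand, which is a useful sanity check on the architecture. The sketch below does the arithmetic under assumptions inferred from the output shapes and counts: 32 filters per convolution with no bias term, 512 GRU units per direction with two bias vectors per gate, and four parameter vectors per batch-normalization layer. The helper functions are hypothetical.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters, use_bias):
    """Weights of a 2D convolution, plus one bias per filter if enabled."""
    return kernel_h * kernel_w * in_channels * filters + (filters if use_bias else 0)


def bidirectional_gru_params(input_dim, units):
    """Two directions, 3 gates each: input weights, recurrent weights, 2 bias vectors."""
    per_direction = 3 * (units * (input_dim + units) + 2 * units)
    return 2 * per_direction


def dense_params(input_dim, units):
    """Weight matrix plus bias vector."""
    return input_dim * units + units


counts = {
    "conv_1": conv2d_params(11, 41, 1, 32, use_bias=False),
    "conv_2": conv2d_params(11, 21, 32, 32, use_bias=False),
    "bidirectional_1": bidirectional_gru_params(1568, 512),
    "bidirectional_2": bidirectional_gru_params(1024, 512),
    "dense_1": dense_params(1024, 1024),
    "dense": dense_params(1024, 32),
}
batch_norm = 2 * (4 * 32)  # two BatchNormalization layers, 4 vectors of 32 values each
# bidirectional_3 to bidirectional_5 have the same size as bidirectional_2.
total = sum(counts.values()) + 3 * counts["bidirectional_2"] + batch_norm

for name, value in counts.items():
    print(f"{name:16s} {value:>10,d}")
print(f"{'total':16s} {total:>10,d}")
```

The first recurrent layer is larger than the others because it receives the 1568-dimensional flattened convolution output, whereas later layers receive the 1024-dimensional concatenated GRU states.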
# A utility function to decode the output of the network
def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0]
    # Iterate over the results and get back the text
    output_text = []
    for result in results:
        result = tf.strings.reduce_join(num_to_char(result)).numpy().decode("utf-8")
        output_text.append(result)
    return output_text
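Greedy CTC decoding can be sketched in a few lines: take the most likely class at every frame, merge consecutive repeats, and drop the blank. The example below is a simplified plain-Python illustration, not the Keras implementation; it assumes the blank is the last class, which is my understanding of the convention behind the extra output unit, and the probabilities are made up.

```python
import itertools


def greedy_ctc_decode(frame_probs, blank_index):
    """Pick the most likely class per frame, merge repeats, drop blanks."""
    best_path = [max(range(len(p)), key=p.__getitem__) for p in frame_probs]
    merged = [index for index, _ in itertools.groupby(best_path)]
    return [index for index in merged if index != blank_index]


# Made-up probabilities for 6 frames over 3 classes; class 2 plays the blank.
frame_probs = [
    [0.7, 0.2, 0.1],
    [0.6, 0.3, 0.1],
    [0.1, 0.1, 0.8],
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.2, 0.7],
]
decoded = greedy_ctc_decode(frame_probs, blank_index=2)
print(decoded)
```

The best path here is 0, 0, blank, 0, 1, blank. The blank between the two runs of class 0 is what allows the same character to appear twice in a row in the decoded output.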
# A callback class to output a few transcriptions during training
class CallbackEval(keras.callbacks.Callback):
    """Displays a batch of outputs after every epoch."""

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def on_epoch_end(self, epoch: int, logs=None):
        predictions = []
        targets = []
        for batch in self.dataset:
            X, y = batch
            batch_predictions = self.model.predict(X)
            batch_predictions = decode_batch_predictions(batch_predictions)
            predictions.extend(batch_predictions)
            for label in y:
                label = (
                    tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
                )
                targets.append(label)
        wer_score = wer(targets, predictions)
        print("-" * 100)
        print(f"Word Error Rate: {wer_score:.4f}")
        print("-" * 100)
        for i in np.random.randint(0, len(predictions), 2):
            print(f"Target    : {targets[i]}")
            print(f"Prediction: {predictions[i]}")
            print("-" * 100)
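# Define the number of epochs.
epochs = 1
# Callback function to check transcription on the val set.
validation_callback = CallbackEval(validation_dataset)
# Train the model
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=epochs,
    callbacks=[validation_callback],
)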
369/369 [==============================] - ETA: 0s - loss: 302.4755
----------------------------------------------------------------------------------------------------
Word Error Rate: 1.0000
Target : special agent lyndal l shaneyfelt a photography expert with the fbi
Prediction: s
Target : dissolved in water the sugar is transported down delicate tubes chiefly in the growing bark region of the stem
Prediction: sss
369/369 [==============================] - 407s 1s/step - loss: 302.4755 - val_loss: 252.1534
# Let's check results on more validation samples
predictions = []
targets = []
for batch in validation_dataset:
    X, y = batch
    batch_predictions = model.predict(X)
    batch_predictions = decode_batch_predictions(batch_predictions)
    predictions.extend(batch_predictions)
    for label in y:
        label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
        targets.append(label)
wer_score = wer(targets, predictions)
print("-" * 100)
print(f"Word Error Rate: {wer_score:.4f}")
print("-" * 100)
for i in np.random.randint(0, len(predictions), 5):
    print(f"Target    : {targets[i]}")
    print(f"Prediction: {predictions[i]}")
    print("-" * 100)
Word Error Rate: 1.0000
Target : the owners of the latter would then issue a second set of warrants on these goods in total ignorance of the fact that they were already pledged
Prediction: ssnssss
Target : till the whole body of the slaves were manumitted in eighteen thirtythree
Prediction: sr
Target : the committee most of all insisted upon the entire individual separation of prisoners except during the hours of labor
Prediction: ssssss
Target : he made no attempt to help her and there are other indications that he did not want her to learn that language
Prediction: s
Target : the building of the babylon so famous in history began with nabopolassar
Prediction: sssrs
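In practice, you should train for around 50 epochs or more. Each epoch takes approximately 5-6 minutes on a GeForce RTX 2080 Ti GPU. The model we trained for 50 epochs has a Word Error Rate (WER) of approximately 16% to 17%.
Some of the transcriptions around epoch 50: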
- Target : sir thomas overbury was undoubtedly poisoned by lord rochester in the reign
of james the first
- Prediction: cer thomas overbery was undoubtedly poisoned by lordrochester in the reign
of james the first
- Target : the committee does not seem to have yet understood that newgate could be
only and properly replaced
- Prediction: the committee does not seem to have yet understood that newgate could be
only and proberly replace
- Target : still no sentence of death was carried out for the offense and in eighteen
- Prediction: still no sentence of death was carried out for the offense and in eighteen