
MelGAN-based spectrogram inversion using feature matching

Author: Darshan Deshpande
Date created: 02/09/2021
Last modified: 15/09/2021

ⓘ This example uses Keras 2

View in Colab • GitHub source

Description: Inversion of audio from mel spectrograms using the MelGAN architecture and feature matching.


Introduction

Autoregressive vocoders have been ubiquitous for a majority of the history of speech processing, but for a long time they lacked parallelism. MelGAN is a non-autoregressive, fully convolutional vocoder architecture used for purposes ranging from spectral inversion and speech enhancement to present-day state-of-the-art speech synthesis when used as a decoder with models like Tacotron2 or FastSpeech that convert text to mel spectrograms.

In this tutorial, we will have a look at the MelGAN architecture and how it can achieve fast spectral inversion, i.e. conversion of spectrograms to audio waves. The MelGAN implemented in this tutorial is similar to the original implementation, with the only difference being the padding method for the convolutions: we will use 'same' padding instead of reflect padding.
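As a minimal, self-contained illustration of that difference (not part of the model code that follows): reflect padding mirrors the signal at its edges before a 'valid' convolution, while 'same' padding pads with zeros.

import tensorflow as tf

x = tf.random.normal([1, 16, 1])

# 'same' padding, as used in this tutorial: zeros are added at the edges.
same_out = tf.keras.layers.Conv1D(4, 7, padding="same")(x)

# Reflect padding, as used in the original MelGAN: mirror the edges first,
# then apply a 'valid' convolution.
pad = (7 - 1) // 2
x_reflected = tf.pad(x, [[0, 0], [pad, pad], [0, 0]], mode="REFLECT")
reflect_out = tf.keras.layers.Conv1D(4, 7, padding="valid")(x_reflected)

print(same_out.shape, reflect_out.shape)  # (1, 16, 4) (1, 16, 4)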


Importing and defining hyperparameters

!pip install -qqq tensorflow_addons
!pip install -qqq tensorflow-io
import tensorflow as tf
import tensorflow_io as tfio
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow_addons import layers as addon_layers

# Setting logger level to avoid input shape warnings
tf.get_logger().setLevel("ERROR")

# Defining hyperparameters

DESIRED_SAMPLES = 8192
LEARNING_RATE_GEN = 1e-5
LEARNING_RATE_DISC = 1e-6
BATCH_SIZE = 16

mse = keras.losses.MeanSquaredError()
mae = keras.losses.MeanAbsoluteError()

Loading the dataset

This example uses the LJSpeech dataset.

The LJSpeech dataset is primarily used for text-to-speech and consists of 13,100 discrete speech samples taken from 7 non-fiction books, having a total length of approximately 24 hours. The MelGAN training is only concerned with the audio waves, so we process only the WAV files and ignore the audio annotations.

!wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
!tar -xf /content/LJSpeech-1.1.tar.bz2
--2021-09-16 11:45:24--  https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
Resolving data.keithito.com (data.keithito.com)... 174.138.79.61
Connecting to data.keithito.com (data.keithito.com)|174.138.79.61|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2748572632 (2.6G) [application/octet-stream]
Saving to: ‘LJSpeech-1.1.tar.bz2’
LJSpeech-1.1.tar.bz 100%[===================>]   2.56G  68.3MB/s    in 36s     
2021-09-16 11:46:01 (72.2 MB/s) - ‘LJSpeech-1.1.tar.bz2’ saved [2748572632/2748572632]

We create a tf.data.Dataset to load and process the audio files on the fly. The preprocess() function takes the file path as input and returns two instances of the wave: one as the input and one as the ground truth for comparison. The input wave will be mapped to a spectrogram using the custom MelSpec layer, as shown later in this example.

# Gathering the paths of all audio files in the dataset
wavs = tf.io.gfile.glob("LJSpeech-1.1/wavs/*.wav")
print(f"Number of audio files: {len(wavs)}")

# Mapper function for loading the audio. This function returns two instances of the wave
def preprocess(filename):
    audio = tf.audio.decode_wav(tf.io.read_file(filename), 1, DESIRED_SAMPLES).audio
    return audio, audio


# Create tf.data.Dataset objects and apply preprocessing
train_dataset = tf.data.Dataset.from_tensor_slices((wavs,))
train_dataset = train_dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
Number of audio files: 13100
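As a quick sanity check (illustrative; it assumes the clip is at least DESIRED_SAMPLES long), every element yields two identical mono waveforms, one as input and one as ground truth:

for input_wave, target_wave in train_dataset.take(1):
    print(input_wave.shape, target_wave.shape)  # (8192, 1) (8192, 1)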

Defining custom layers for MelGAN

The MelGAN architecture consists of 3 main modules:

  1. The residual block
  2. The dilated convolutional block
  3. The discriminator block

[Figure: MelGAN architecture]

Since the network takes a mel spectrogram as input, we will create an additional custom layer that can convert the raw audio wave to a spectrogram on the fly. We use the raw audio tensor from train_dataset and map it to a mel spectrogram using the MelSpec layer below.

# Custom keras layer for on-the-fly audio to spectrogram conversion


class MelSpec(layers.Layer):
    def __init__(
        self,
        frame_length=1024,
        frame_step=256,
        fft_length=None,
        sampling_rate=22050,
        num_mel_channels=80,
        freq_min=125,
        freq_max=7600,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.frame_length = frame_length
        self.frame_step = frame_step
        self.fft_length = fft_length
        self.sampling_rate = sampling_rate
        self.num_mel_channels = num_mel_channels
        self.freq_min = freq_min
        self.freq_max = freq_max
        # Defining mel filter. This filter will be multiplied with the STFT output
        self.mel_filterbank = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=self.num_mel_channels,
            num_spectrogram_bins=self.frame_length // 2 + 1,
            sample_rate=self.sampling_rate,
            lower_edge_hertz=self.freq_min,
            upper_edge_hertz=self.freq_max,
        )

    def call(self, audio, training=True):
        # We will only perform the transformation during training.
        if training:
            # Taking the Short Time Fourier Transform. Ensure that the audio is padded.
            # In the paper, the STFT output is padded using the 'REFLECT' strategy.
            stft = tf.signal.stft(
                tf.squeeze(audio, -1),
                self.frame_length,
                self.frame_step,
                self.fft_length,
                pad_end=True,
            )

            # Taking the magnitude of the STFT output
            magnitude = tf.abs(stft)

            # Multiplying the Mel-filterbank with the magnitude and scaling it using the db scale
            mel = tf.matmul(tf.square(magnitude), self.mel_filterbank)
            log_mel_spec = tfio.audio.dbscale(mel, top_db=80)
            return log_mel_spec
        else:
            return audio

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "frame_length": self.frame_length,
                "frame_step": self.frame_step,
                "fft_length": self.fft_length,
                "sampling_rate": self.sampling_rate,
                "num_mel_channels": self.num_mel_channels,
                "freq_min": self.freq_min,
                "freq_max": self.freq_max,
            }
        )
        return config
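As a quick, illustrative shape check: with pad_end=True, 8192 samples at a frame_step of 256 produce ceil(8192 / 256) = 32 frames of 80 mel channels.

# Illustrative only: verify the output shape of the MelSpec layer.
dummy_audio = tf.random.normal([1, DESIRED_SAMPLES, 1])
print(MelSpec()(dummy_audio, training=True).shape)  # (1, 32, 80)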

The residual convolutional block extensively uses dilations and has a total receptive field of 27 timesteps per block. The dilations must grow as a power of the kernel_size to ensure reduction of hissing noise in the output. The network proposed by the paper is as follows:

[Figure: ConvBlock]
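The 27-timestep figure follows from the three dilated convolutions alone; a small illustrative calculation:

# Receptive field of the dilated convolutions (kernel size 3, dilations
# 1, 3, 9) in one residual stack: 1 + (3 - 1) * (1 + 3 + 9) = 27 timesteps.
kernel_size = 3
dilations = [1, 3, 9]
print(1 + (kernel_size - 1) * sum(dilations))  # 27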

# Creating the residual stack block


def residual_stack(input, filters):
    """Convolutional residual stack with weight normalization.

    Args:
        input: tensor, input to the residual stack.
        filters: int, determines filter size for the residual stack.

    Returns:
        Residual stack output.
    """
    c1 = addon_layers.WeightNormalization(
        layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False
    )(input)
    lrelu1 = layers.LeakyReLU()(c1)
    c2 = addon_layers.WeightNormalization(
        layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False
    )(lrelu1)
    add1 = layers.Add()([c2, input])

    lrelu2 = layers.LeakyReLU()(add1)
    c3 = addon_layers.WeightNormalization(
        layers.Conv1D(filters, 3, dilation_rate=3, padding="same"), data_init=False
    )(lrelu2)
    lrelu3 = layers.LeakyReLU()(c3)
    c4 = addon_layers.WeightNormalization(
        layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False
    )(lrelu3)
    add2 = layers.Add()([add1, c4])

    lrelu4 = layers.LeakyReLU()(add2)
    c5 = addon_layers.WeightNormalization(
        layers.Conv1D(filters, 3, dilation_rate=9, padding="same"), data_init=False
    )(lrelu4)
    lrelu5 = layers.LeakyReLU()(c5)
    c6 = addon_layers.WeightNormalization(
        layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False
    )(lrelu5)
    add3 = layers.Add()([c6, add2])

    return add3

Each convolutional block uses the dilations offered by the residual stack and upsamples the input data by the upsampling_factor.

# Dilated convolutional block consisting of the Residual stack


def conv_block(input, conv_dim, upsampling_factor):
    """Dilated Convolutional Block with weight normalization.

    Args:
        input: tensor, input to the block.
        conv_dim: int, determines filter size for the block.
        upsampling_factor: int, scale for upsampling.

    Returns:
        Dilated convolution block.
    """
    conv_t = addon_layers.WeightNormalization(
        layers.Conv1DTranspose(conv_dim, 16, upsampling_factor, padding="same"),
        data_init=False,
    )(input)
    lrelu1 = layers.LeakyReLU()(conv_t)
    res_stack = residual_stack(lrelu1, conv_dim)
    lrelu2 = layers.LeakyReLU()(res_stack)
    return lrelu2
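Note that the upsampling factors used later in the generator (8, 8, 2, 2) multiply to 256, which matches the MelSpec frame_step, so each mel frame is expanded back into 256 audio samples. A quick illustrative shape check:

# Illustrative only: a conv_block with upsampling_factor=8 stretches the
# time axis by 8x (here: 32 frames -> 256 timesteps).
dummy_input = tf.random.normal([1, 32, 512])
print(conv_block(dummy_input, 256, 8).shape)  # (1, 256, 256)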

The discriminator block consists of convolutions and downsampling layers. This block is essential for the implementation of the feature matching technique.

Each discriminator outputs a list of feature maps that will be compared during training to compute the feature matching loss.

def discriminator_block(input):
    """Discriminator block with weight normalization.

    Args:
        input: tensor, input waveform.

    Returns:
        List of six intermediate feature maps followed by the final score map.
    """
    conv1 = addon_layers.WeightNormalization(
        layers.Conv1D(16, 15, 1, "same"), data_init=False
    )(input)
    lrelu1 = layers.LeakyReLU()(conv1)
    conv2 = addon_layers.WeightNormalization(
        layers.Conv1D(64, 41, 4, "same", groups=4), data_init=False
    )(lrelu1)
    lrelu2 = layers.LeakyReLU()(conv2)
    conv3 = addon_layers.WeightNormalization(
        layers.Conv1D(256, 41, 4, "same", groups=16), data_init=False
    )(lrelu2)
    lrelu3 = layers.LeakyReLU()(conv3)
    conv4 = addon_layers.WeightNormalization(
        layers.Conv1D(1024, 41, 4, "same", groups=64), data_init=False
    )(lrelu3)
    lrelu4 = layers.LeakyReLU()(conv4)
    conv5 = addon_layers.WeightNormalization(
        layers.Conv1D(1024, 41, 4, "same", groups=256), data_init=False
    )(lrelu4)
    lrelu5 = layers.LeakyReLU()(conv5)
    conv6 = addon_layers.WeightNormalization(
        layers.Conv1D(1024, 5, 1, "same"), data_init=False
    )(lrelu5)
    lrelu6 = layers.LeakyReLU()(conv6)
    conv7 = addon_layers.WeightNormalization(
        layers.Conv1D(1, 3, 1, "same"), data_init=False
    )(lrelu6)
    return [lrelu1, lrelu2, lrelu3, lrelu4, lrelu5, lrelu6, conv7]
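As an illustrative sanity check, one block maps a waveform to six intermediate feature maps followed by the final score map; the four stride-4 convolutions downsample the time axis by a factor of 256:

# Illustrative only: the block returns 7 tensors in total.
dummy_wave = tf.random.normal([1, DESIRED_SAMPLES, 1])
feature_maps = discriminator_block(dummy_wave)
print(len(feature_maps))       # 7
print(feature_maps[-1].shape)  # (1, 32, 1) -- 8192 / 4**4 = 32 timesteps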

Creating the generator

def create_generator(input_shape):
    inp = keras.Input(input_shape)
    x = MelSpec()(inp)
    x = layers.Conv1D(512, 7, padding="same")(x)
    x = layers.LeakyReLU()(x)
    x = conv_block(x, 256, 8)
    x = conv_block(x, 128, 8)
    x = conv_block(x, 64, 2)
    x = conv_block(x, 32, 2)
    x = addon_layers.WeightNormalization(
        layers.Conv1D(1, 7, padding="same", activation="tanh")
    )(x)
    return keras.Model(inp, x)


# We use a dynamic input shape for the generator since the model is fully convolutional
generator = create_generator((None, 1))
generator.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, None, 1)]    0                                            
__________________________________________________________________________________________________
mel_spec (MelSpec)              (None, None, 80)     0           input_1[0][0]                    
__________________________________________________________________________________________________
conv1d (Conv1D)                 (None, None, 512)    287232      mel_spec[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, None, 512)    0           conv1d[0][0]                     
__________________________________________________________________________________________________
weight_normalization (WeightNor (None, None, 256)    2097921     leaky_re_lu[0][0]                
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, None, 256)    0           weight_normalization[0][0]       
__________________________________________________________________________________________________
weight_normalization_1 (WeightN (None, None, 256)    197121      leaky_re_lu_1[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)       (None, None, 256)    0           weight_normalization_1[0][0]     
__________________________________________________________________________________________________
weight_normalization_2 (WeightN (None, None, 256)    197121      leaky_re_lu_2[0][0]              
__________________________________________________________________________________________________
add (Add)                       (None, None, 256)    0           weight_normalization_2[0][0]     
                                                                 leaky_re_lu_1[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU)       (None, None, 256)    0           add[0][0]                        
__________________________________________________________________________________________________
weight_normalization_3 (WeightN (None, None, 256)    197121      leaky_re_lu_3[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU)       (None, None, 256)    0           weight_normalization_3[0][0]     
__________________________________________________________________________________________________
weight_normalization_4 (WeightN (None, None, 256)    197121      leaky_re_lu_4[0][0]              
__________________________________________________________________________________________________
add_1 (Add)                     (None, None, 256)    0           add[0][0]                        
                                                                 weight_normalization_4[0][0]     
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)       (None, None, 256)    0           add_1[0][0]                      
__________________________________________________________________________________________________
weight_normalization_5 (WeightN (None, None, 256)    197121      leaky_re_lu_5[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU)       (None, None, 256)    0           weight_normalization_5[0][0]     
__________________________________________________________________________________________________
weight_normalization_6 (WeightN (None, None, 256)    197121      leaky_re_lu_6[0][0]              
__________________________________________________________________________________________________
add_2 (Add)                     (None, None, 256)    0           weight_normalization_6[0][0]     
                                                                 add_1[0][0]                      
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU)       (None, None, 256)    0           add_2[0][0]                      
__________________________________________________________________________________________________
weight_normalization_7 (WeightN (None, None, 128)    524673      leaky_re_lu_7[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU)       (None, None, 128)    0           weight_normalization_7[0][0]     
__________________________________________________________________________________________________
weight_normalization_8 (WeightN (None, None, 128)    49409       leaky_re_lu_8[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU)       (None, None, 128)    0           weight_normalization_8[0][0]     
__________________________________________________________________________________________________
weight_normalization_9 (WeightN (None, None, 128)    49409       leaky_re_lu_9[0][0]              
__________________________________________________________________________________________________
add_3 (Add)                     (None, None, 128)    0           weight_normalization_9[0][0]     
                                                                 leaky_re_lu_8[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_10 (LeakyReLU)      (None, None, 128)    0           add_3[0][0]                      
__________________________________________________________________________________________________
weight_normalization_10 (Weight (None, None, 128)    49409       leaky_re_lu_10[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_11 (LeakyReLU)      (None, None, 128)    0           weight_normalization_10[0][0]    
__________________________________________________________________________________________________
weight_normalization_11 (Weight (None, None, 128)    49409       leaky_re_lu_11[0][0]             
__________________________________________________________________________________________________
add_4 (Add)                     (None, None, 128)    0           add_3[0][0]                      
                                                                 weight_normalization_11[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_12 (LeakyReLU)      (None, None, 128)    0           add_4[0][0]                      
__________________________________________________________________________________________________
weight_normalization_12 (Weight (None, None, 128)    49409       leaky_re_lu_12[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_13 (LeakyReLU)      (None, None, 128)    0           weight_normalization_12[0][0]    
__________________________________________________________________________________________________
weight_normalization_13 (Weight (None, None, 128)    49409       leaky_re_lu_13[0][0]             
__________________________________________________________________________________________________
add_5 (Add)                     (None, None, 128)    0           weight_normalization_13[0][0]    
                                                                 add_4[0][0]                      
__________________________________________________________________________________________________
leaky_re_lu_14 (LeakyReLU)      (None, None, 128)    0           add_5[0][0]                      
__________________________________________________________________________________________________
weight_normalization_14 (Weight (None, None, 64)     131265      leaky_re_lu_14[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_15 (LeakyReLU)      (None, None, 64)     0           weight_normalization_14[0][0]    
__________________________________________________________________________________________________
weight_normalization_15 (Weight (None, None, 64)     12417       leaky_re_lu_15[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_16 (LeakyReLU)      (None, None, 64)     0           weight_normalization_15[0][0]    
__________________________________________________________________________________________________
weight_normalization_16 (Weight (None, None, 64)     12417       leaky_re_lu_16[0][0]             
__________________________________________________________________________________________________
add_6 (Add)                     (None, None, 64)     0           weight_normalization_16[0][0]    
                                                                 leaky_re_lu_15[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_17 (LeakyReLU)      (None, None, 64)     0           add_6[0][0]                      
__________________________________________________________________________________________________
weight_normalization_17 (Weight (None, None, 64)     12417       leaky_re_lu_17[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_18 (LeakyReLU)      (None, None, 64)     0           weight_normalization_17[0][0]    
__________________________________________________________________________________________________
weight_normalization_18 (Weight (None, None, 64)     12417       leaky_re_lu_18[0][0]             
__________________________________________________________________________________________________
add_7 (Add)                     (None, None, 64)     0           add_6[0][0]                      
                                                                 weight_normalization_18[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_19 (LeakyReLU)      (None, None, 64)     0           add_7[0][0]                      
__________________________________________________________________________________________________
weight_normalization_19 (Weight (None, None, 64)     12417       leaky_re_lu_19[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_20 (LeakyReLU)      (None, None, 64)     0           weight_normalization_19[0][0]    
__________________________________________________________________________________________________
weight_normalization_20 (Weight (None, None, 64)     12417       leaky_re_lu_20[0][0]             
__________________________________________________________________________________________________
add_8 (Add)                     (None, None, 64)     0           weight_normalization_20[0][0]    
                                                                 add_7[0][0]                      
__________________________________________________________________________________________________
leaky_re_lu_21 (LeakyReLU)      (None, None, 64)     0           add_8[0][0]                      
__________________________________________________________________________________________________
weight_normalization_21 (Weight (None, None, 32)     32865       leaky_re_lu_21[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_22 (LeakyReLU)      (None, None, 32)     0           weight_normalization_21[0][0]    
__________________________________________________________________________________________________
weight_normalization_22 (Weight (None, None, 32)     3137        leaky_re_lu_22[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_23 (LeakyReLU)      (None, None, 32)     0           weight_normalization_22[0][0]    
__________________________________________________________________________________________________
weight_normalization_23 (Weight (None, None, 32)     3137        leaky_re_lu_23[0][0]             
__________________________________________________________________________________________________
add_9 (Add)                     (None, None, 32)     0           weight_normalization_23[0][0]    
                                                                 leaky_re_lu_22[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_24 (LeakyReLU)      (None, None, 32)     0           add_9[0][0]                      
__________________________________________________________________________________________________
weight_normalization_24 (Weight (None, None, 32)     3137        leaky_re_lu_24[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_25 (LeakyReLU)      (None, None, 32)     0           weight_normalization_24[0][0]    
__________________________________________________________________________________________________
weight_normalization_25 (Weight (None, None, 32)     3137        leaky_re_lu_25[0][0]             
__________________________________________________________________________________________________
add_10 (Add)                    (None, None, 32)     0           add_9[0][0]                      
                                                                 weight_normalization_25[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_26 (LeakyReLU)      (None, None, 32)     0           add_10[0][0]                     
__________________________________________________________________________________________________
weight_normalization_26 (Weight (None, None, 32)     3137        leaky_re_lu_26[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_27 (LeakyReLU)      (None, None, 32)     0           weight_normalization_26[0][0]    
__________________________________________________________________________________________________
weight_normalization_27 (Weight (None, None, 32)     3137        leaky_re_lu_27[0][0]             
__________________________________________________________________________________________________
add_11 (Add)                    (None, None, 32)     0           weight_normalization_27[0][0]    
                                                                 add_10[0][0]                     
__________________________________________________________________________________________________
leaky_re_lu_28 (LeakyReLU)      (None, None, 32)     0           add_11[0][0]                     
__________________________________________________________________________________________________
weight_normalization_28 (Weight (None, None, 1)      452         leaky_re_lu_28[0][0]             
==================================================================================================
Total params: 4,646,912
Trainable params: 4,646,658
Non-trainable params: 254
__________________________________________________________________________________________________

Creating the discriminator

def create_discriminator(input_shape):
    inp = keras.Input(input_shape)
    out_map1 = discriminator_block(inp)
    pool1 = layers.AveragePooling1D()(inp)
    out_map2 = discriminator_block(pool1)
    pool2 = layers.AveragePooling1D()(pool1)
    out_map3 = discriminator_block(pool2)
    return keras.Model(inp, [out_map1, out_map2, out_map3])


# We use a dynamic input shape for the discriminator
# This is done because the input shape for the generator is unknown
discriminator = create_discriminator((None, 1))

discriminator.summary()
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            [(None, None, 1)]    0                                            
__________________________________________________________________________________________________
average_pooling1d (AveragePooli (None, None, 1)      0           input_2[0][0]                    
__________________________________________________________________________________________________
average_pooling1d_1 (AveragePoo (None, None, 1)      0           average_pooling1d[0][0]          
__________________________________________________________________________________________________
weight_normalization_29 (Weight (None, None, 16)     273         input_2[0][0]                    
__________________________________________________________________________________________________
weight_normalization_36 (Weight (None, None, 16)     273         average_pooling1d[0][0]          
__________________________________________________________________________________________________
weight_normalization_43 (Weight (None, None, 16)     273         average_pooling1d_1[0][0]        
__________________________________________________________________________________________________
leaky_re_lu_29 (LeakyReLU)      (None, None, 16)     0           weight_normalization_29[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_35 (LeakyReLU)      (None, None, 16)     0           weight_normalization_36[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_41 (LeakyReLU)      (None, None, 16)     0           weight_normalization_43[0][0]    
__________________________________________________________________________________________________
weight_normalization_30 (Weight (None, None, 64)     10625       leaky_re_lu_29[0][0]             
__________________________________________________________________________________________________
weight_normalization_37 (Weight (None, None, 64)     10625       leaky_re_lu_35[0][0]             
__________________________________________________________________________________________________
weight_normalization_44 (Weight (None, None, 64)     10625       leaky_re_lu_41[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_30 (LeakyReLU)      (None, None, 64)     0           weight_normalization_30[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_36 (LeakyReLU)      (None, None, 64)     0           weight_normalization_37[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_42 (LeakyReLU)      (None, None, 64)     0           weight_normalization_44[0][0]    
__________________________________________________________________________________________________
weight_normalization_31 (Weight (None, None, 256)    42497       leaky_re_lu_30[0][0]             
__________________________________________________________________________________________________
weight_normalization_38 (Weight (None, None, 256)    42497       leaky_re_lu_36[0][0]             
__________________________________________________________________________________________________
weight_normalization_45 (Weight (None, None, 256)    42497       leaky_re_lu_42[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_31 (LeakyReLU)      (None, None, 256)    0           weight_normalization_31[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_37 (LeakyReLU)      (None, None, 256)    0           weight_normalization_38[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_43 (LeakyReLU)      (None, None, 256)    0           weight_normalization_45[0][0]    
__________________________________________________________________________________________________
weight_normalization_32 (Weight (None, None, 1024)   169985      leaky_re_lu_31[0][0]             
__________________________________________________________________________________________________
weight_normalization_39 (Weight (None, None, 1024)   169985      leaky_re_lu_37[0][0]             
__________________________________________________________________________________________________
weight_normalization_46 (Weight (None, None, 1024)   169985      leaky_re_lu_43[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_32 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_32[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_38 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_39[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_44 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_46[0][0]    
__________________________________________________________________________________________________
weight_normalization_33 (Weight (None, None, 1024)   169985      leaky_re_lu_32[0][0]             
__________________________________________________________________________________________________
weight_normalization_40 (Weight (None, None, 1024)   169985      leaky_re_lu_38[0][0]             
__________________________________________________________________________________________________
weight_normalization_47 (Weight (None, None, 1024)   169985      leaky_re_lu_44[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_33 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_33[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_39 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_40[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_45 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_47[0][0]    
__________________________________________________________________________________________________
weight_normalization_34 (Weight (None, None, 1024)   5244929     leaky_re_lu_33[0][0]             
__________________________________________________________________________________________________
weight_normalization_41 (Weight (None, None, 1024)   5244929     leaky_re_lu_39[0][0]             
__________________________________________________________________________________________________
weight_normalization_48 (Weight (None, None, 1024)   5244929     leaky_re_lu_45[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_34 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_34[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_40 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_41[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_46 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_48[0][0]    
__________________________________________________________________________________________________
weight_normalization_35 (Weight (None, None, 1)      3075        leaky_re_lu_34[0][0]             
__________________________________________________________________________________________________
weight_normalization_42 (Weight (None, None, 1)      3075        leaky_re_lu_40[0][0]             
__________________________________________________________________________________________________
weight_normalization_49 (Weight (None, None, 1)      3075        leaky_re_lu_46[0][0]             
==================================================================================================
Total params: 16,924,107
Trainable params: 16,924,086
Non-trainable params: 21
__________________________________________________________________________________________________

Defining the loss functions

Generator loss

The generator architecture uses a combination of two losses:

  1. Mean Squared Error

This is the standard MSE generator loss, calculated between ones and the outputs of the N-layer discriminators.

  2. Feature Matching Loss

This loss involves extracting the outputs of every layer of the discriminator for both the generator and the ground truth, and comparing each pair of layer outputs using Mean Absolute Error.

Discriminator loss

The discriminator uses mean squared error, comparing the predictions on real data with ones and the predictions on generated data with zeros.
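In symbols, with $D_k$ the $k$-th discriminator scale, $D_k^{(i)}$ its $i$-th intermediate feature map, $x$ the ground-truth wave, and $G(s)$ the generated wave, the objectives implemented below are:

$$
\begin{aligned}
\mathcal{L}_{\mathrm{gen}} &= \operatorname{mean}_{k} \operatorname{MSE}\big(\mathbf{1},\, D_k(G(s))\big) \\
\mathcal{L}_{\mathrm{FM}} &= \operatorname{mean}_{k,\,i} \operatorname{MAE}\big(D_k^{(i)}(x),\, D_k^{(i)}(G(s))\big) \\
\mathcal{L}_{G} &= \mathcal{L}_{\mathrm{gen}} + 10\, \mathcal{L}_{\mathrm{FM}} \\
\mathcal{L}_{D} &= \operatorname{mean}_{k} \operatorname{MSE}\big(\mathbf{1},\, D_k(x)\big) + \operatorname{mean}_{k} \operatorname{MSE}\big(\mathbf{0},\, D_k(G(s))\big)
\end{aligned}
$$

(The factor of 10 on the feature matching term matches the weighting used in train_step() further below.)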

# Generator loss


def generator_loss(real_pred, fake_pred):
    """Loss function for the generator.

    Args:
        real_pred: Tensor, output of the ground truth wave passed through the discriminator.
        fake_pred: Tensor, output of the generator prediction passed through the discriminator.

    Returns:
        Loss for the generator.
    """
    gen_loss = []
    for i in range(len(fake_pred)):
        gen_loss.append(mse(tf.ones_like(fake_pred[i][-1]), fake_pred[i][-1]))

    return tf.reduce_mean(gen_loss)


def feature_matching_loss(real_pred, fake_pred):
    """Implements the feature matching loss.

    Args:
        real_pred: Tensor, output of the ground truth wave passed through the discriminator.
        fake_pred: Tensor, output of the generator prediction passed through the discriminator.

    Returns:
        Feature Matching Loss.
    """
    fm_loss = []
    for i in range(len(fake_pred)):
        for j in range(len(fake_pred[i]) - 1):
            fm_loss.append(mae(real_pred[i][j], fake_pred[i][j]))

    return tf.reduce_mean(fm_loss)


def discriminator_loss(real_pred, fake_pred):
    """Implements the discriminator loss.

    Args:
        real_pred: Tensor, output of the ground truth wave passed through the discriminator.
        fake_pred: Tensor, output of the generator prediction passed through the discriminator.

    Returns:
        Discriminator Loss.
    """
    real_loss, fake_loss = [], []
    for i in range(len(real_pred)):
        real_loss.append(mse(tf.ones_like(real_pred[i][-1]), real_pred[i][-1]))
        fake_loss.append(mse(tf.zeros_like(fake_pred[i][-1]), fake_pred[i][-1]))

    # Calculating the final discriminator loss after scaling
    disc_loss = tf.reduce_mean(real_loss) + tf.reduce_mean(fake_loss)
    return disc_loss

Defining the MelGAN model for training. This subclass overrides the train_step() method to implement the training logic.

class MelGAN(keras.Model):
    def __init__(self, generator, discriminator, **kwargs):
        """MelGAN trainer class

        Args:
            generator: keras.Model, Generator model
            discriminator: keras.Model, Discriminator model
        """
        super().__init__(**kwargs)
        self.generator = generator
        self.discriminator = discriminator

    def compile(
        self,
        gen_optimizer,
        disc_optimizer,
        generator_loss,
        feature_matching_loss,
        discriminator_loss,
    ):
        """MelGAN compile method.

        Args:
            gen_optimizer: keras.optimizer, optimizer to be used for training
            disc_optimizer: keras.optimizer, optimizer to be used for training
            generator_loss: callable, loss function for generator
            feature_matching_loss: callable, loss function for feature matching
            discriminator_loss: callable, loss function for discriminator
        """
        super().compile()

        # Optimizers
        self.gen_optimizer = gen_optimizer
        self.disc_optimizer = disc_optimizer

        # Losses
        self.generator_loss = generator_loss
        self.feature_matching_loss = feature_matching_loss
        self.discriminator_loss = discriminator_loss

        # Trackers
        self.gen_loss_tracker = keras.metrics.Mean(name="gen_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="disc_loss")

    def train_step(self, batch):
        x_batch_train, y_batch_train = batch

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            # Generating the audio wave
            gen_audio_wave = self.generator(x_batch_train, training=True)

            # Generating the features using the discriminator
            real_pred = self.discriminator(y_batch_train)
            fake_pred = self.discriminator(gen_audio_wave)

            # Calculating the generator losses
            gen_loss = self.generator_loss(real_pred, fake_pred)
            fm_loss = self.feature_matching_loss(real_pred, fake_pred)

            # Calculating the final generator loss. The feature matching loss
            # is scaled by a factor of 10, as in the paper.
            gen_fm_loss = gen_loss + 10 * fm_loss

            # Calculating the discriminator losses
            disc_loss = self.discriminator_loss(real_pred, fake_pred)

        # Calculating and applying the gradients for generator and discriminator
        grads_gen = gen_tape.gradient(gen_fm_loss, self.generator.trainable_weights)
        grads_disc = disc_tape.gradient(disc_loss, self.discriminator.trainable_weights)
        self.gen_optimizer.apply_gradients(
            zip(grads_gen, self.generator.trainable_weights)
        )
        self.disc_optimizer.apply_gradients(
            zip(grads_disc, self.discriminator.trainable_weights)
        )

        self.gen_loss_tracker.update_state(gen_fm_loss)
        self.disc_loss_tracker.update_state(disc_loss)

        return {
            "gen_loss": self.gen_loss_tracker.result(),
            "disc_loss": self.disc_loss_tracker.result(),
        }

Training

The paper suggests that training with dynamic shapes takes around 400,000 steps (~500 epochs). For this example, we will run it only for a single epoch (819 steps). Longer training (greater than 300 epochs) will almost certainly provide better results.

gen_optimizer = keras.optimizers.Adam(
    LEARNING_RATE_GEN, beta_1=0.5, beta_2=0.9, clipnorm=1
)
disc_optimizer = keras.optimizers.Adam(
    LEARNING_RATE_DISC, beta_1=0.5, beta_2=0.9, clipnorm=1
)

# Start training
generator = create_generator((None, 1))
discriminator = create_discriminator((None, 1))

mel_gan = MelGAN(generator, discriminator)
mel_gan.compile(
    gen_optimizer,
    disc_optimizer,
    generator_loss,
    feature_matching_loss,
    discriminator_loss,
)
mel_gan.fit(
    train_dataset.shuffle(200).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE), epochs=1
)
819/819 [==============================] - 641s 696ms/step - gen_loss: 0.9761 - disc_loss: 0.9350

<keras.callbacks.History at 0x7f8f702fe050>

Testing the model

The trained model can now be used for real-time text-to-speech tasks. To test how fast MelGAN inference can be, let us take a sample audio mel spectrogram and convert it. Note that the actual model pipeline will not include the MelSpec layer, so this layer is disabled during inference. The inference input will be a mel spectrogram processed in the same way as by the MelSpec layer configuration.

For testing this, we will create a uniformly random tensor to simulate the behavior of the inference pipeline.

# Sampling a random tensor to mimic a batch of 128 spectrograms of shape [50, 80]
audio_sample = tf.random.uniform([128, 50, 80])

Timing the inference speed of a single sample. Running this, you can see that the average inference time per spectrogram ranges from 8 milliseconds to 10 milliseconds on a K80 GPU, which is pretty fast.

pred = generator.predict(audio_sample, batch_size=32, verbose=1)
4/4 [==============================] - 5s 280ms/step
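As a sketch of a more realistic inversion (illustrative: it assumes a sufficiently trained generator, and the file path is just one sample from the dataset), you can compute a log-mel spectrogram with the same MelSpec configuration and feed it to the generator directly, since the MelSpec layer passes its input through unchanged at inference time:

# Illustrative sketch: invert a real recording from LJSpeech.
audio = tf.audio.decode_wav(
    tf.io.read_file("LJSpeech-1.1/wavs/LJ001-0001.wav"), 1
).audio
mel = MelSpec()(audio[tf.newaxis, ...], training=True)  # (1, frames, 80)
reconstructed = generator.predict(mel)  # (1, frames * 256, 1)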

Conclusion

The MelGAN is a highly effective architecture for spectral inversion that has a Mean Opinion Score (MOS) of 3.61, considerably outperforming the Griffin-Lim algorithm, which has a MOS of just 1.57. MelGAN is also competitive with the state-of-the-art WaveGlow and WaveNet architectures on text-to-speech and speech enhancement tasks on the LJSpeech and VCTK datasets [1].

This tutorial highlights:

  1. The advantages of using dilated convolutions that grow with the filter size
  2. Implementation of a custom layer for on-the-fly conversion of audio waves to mel spectrograms
  3. Effectiveness of using the feature matching loss function for training GAN generators.

Further reading

  1. The MelGAN paper (Kundan Kumar et al.) to understand the reasoning behind the architecture and training process
  2. For an in-depth understanding of the feature matching loss, you can refer to Improved Techniques for Training GANs (Tim Salimans et al.).

Example available on HuggingFace

Trained Model | Demo