BatchNormalization 类keras.layers.BatchNormalization(
axis=-1,
momentum=0.99,
epsilon=0.001,
center=True,
scale=True,
beta_initializer="zeros",
gamma_initializer="ones",
moving_mean_initializer="zeros",
moving_variance_initializer="ones",
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
synchronized=False,
**kwargs
)
对输入进行归一化的层。
Batch Normalization 应用一种变换,使输出的均值接近 0,输出的标准差接近 1。
重要的是,Batch Normalization 在训练和推理时工作方式不同。
在训练时 (即使用 fit() 或在调用层/模型时附带参数 training=True),该层会使用当前输入批次的均值和标准差对其输出进行归一化。也就是说,对于每个正在归一化的通道,该层返回 gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta,其中
epsilon 是一个小的常数(可在构造函数参数中配置)。gamma 是一个学习到的缩放因子(初始化为 1),通过向构造函数传递 scale=False 可以禁用它。beta 是一个学习到的偏移因子(初始化为 0),通过向构造函数传递 center=False 可以禁用它。在推理时 (即使用 evaluate() 或 predict(),或在调用层/模型时附带参数 training=False(这是默认值)),该层会使用它在训练期间看到的批次的均值和标准差的移动平均值来对其输出进行归一化。也就是说,它返回 gamma * (batch - self.moving_mean) / sqrt(self.moving_var+epsilon) + beta。
self.moving_mean 和 self.moving_var 是非可训练变量,它们会在该层以训练模式调用时被更新,因此
moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)moving_var = moving_var * momentum + var(batch) * (1 - momentum)因此,该层将在推理时对输入进行归一化,但前提是在训练过与推理数据统计相似的数据之后。
参数
data_format="channels_first" 的 Conv2D 层之后,使用 axis=1。True,则向归一化张量添加 beta 的偏移量。如果为 False,则忽略 beta。True,则乘以 gamma。如果为 False,则不使用 gamma。当下一层是线性的时,可以禁用此选项,因为缩放将由下一层完成。True,则在分布式训练策略的每个训练步骤中,跨所有设备同步该层的全局批次统计数据(均值和方差)。如果为 False,则每个副本使用其本地批次统计数据。name 和 dtype)。调用参数
training=True:该层将使用当前批次输入的均值和方差对其输入进行归一化。training=False:该层将使用其在训练期间学习到的移动统计数据的均值和方差对其输入进行归一化。inputs 张量可广播的二元张量,其中 True 值表示应计算均值和方差的位置。在训练期间,不会考虑当前输入中被遮蔽的元素来计算均值和方差。任何之前的未被遮蔽的元素值将一直被考虑,直到其动量过期。参考
关于在 BatchNormalization 层上设置 layer.trainable = False
设置 layer.trainable = False 的意思是冻结该层,即其内部状态在训练期间不会改变:其可训练权重不会在 fit() 或 train_on_batch() 期间被更新,并且其状态更新也不会被运行。
通常,这并不一定意味着该层以推理模式运行(这通常由调用层时可以传递的 training 参数控制)。“冻结状态”和“推理模式”是两个不同的概念。
然而,对于 BatchNormalization 层,将该层上的 trainable 设置为 False 意味着该层随后将以推理模式运行(这意味着它将使用移动均值和移动方差来归一化当前批次,而不是使用当前批次的均值和方差)。
请注意:
trainable 会递归地设置所有内部层的 trainable 值。compile() 后更改了 trainable 属性的值,则新值在再次调用 compile() 之前不会对该模型生效。