BatchNormalization
类keras.layers.BatchNormalization(
axis=-1,
momentum=0.99,
epsilon=0.001,
center=True,
scale=True,
beta_initializer="zeros",
gamma_initializer="ones",
moving_mean_initializer="zeros",
moving_variance_initializer="ones",
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
synchronized=False,
**kwargs
)
对输入进行归一化的层。
批归一化(Batch normalization)应用一种变换,使输出的均值接近 0,输出的标准差接近 1。
重要的是,批归一化在训练和推断期间的工作方式不同。
训练期间(即使用 fit()
或以参数 training=True
调用层/模型时),该层使用当前输入批次的均值和标准差对其输出进行归一化。也就是说,对于每个被归一化的通道,该层返回 gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta
,其中
epsilon
是一个小的常数(可在构造函数参数中配置)gamma
是一个学习到的缩放因子(初始化为 1),可以通过向构造函数传递 scale=False
来禁用。beta
是一个学习到的偏移因子(初始化为 0),可以通过向构造函数传递 center=False
来禁用。推断期间(即使用 evaluate()
或 predict()
或以参数 training=False
调用层/模型时(这是默认值)),该层使用在训练期间看到的批次的均值和标准差的移动平均值对其输出进行归一化。也就是说,它返回 gamma * (batch - self.moving_mean) / sqrt(self.moving_var+epsilon) + beta
。
self.moving_mean
和 self.moving_var
是非训练变量,每次在训练模式下调用该层时都会更新,更新方式如下:
moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)
moving_var = moving_var * momentum + var(batch) * (1 - momentum)
因此,只有在对与推断数据具有相似统计特征的数据进行训练之后,该层才会在推断期间对其输入进行归一化。
参数
data_format="channels_first"
的 Conv2D
层之后,使用 axis=1
。True
,则向归一化张量添加 beta
偏移。如果为 False
,则忽略 beta
。True
,则乘以 gamma
。如果为 False
,则不使用 gamma
。当下一层是线性层时,可以禁用此项,因为缩放将由下一层完成。True
,在分布式训练策略中,在每个训练步骤将该层的全局批统计量(均值和方差)跨所有设备同步。如果为 False
,每个副本使用其自己的本地批统计量。name
和 dtype
)。调用参数
training=True
: 该层将使用当前输入批次的均值和方差对其输入进行归一化。training=False
: 该层将使用其在训练期间学习到的移动统计量(均值和方差)对其输入进行归一化。inputs
张量形状可广播的二进制张量,其中 True
值表示应计算均值和方差的位置。训练期间,当前输入中的被掩码元素不计入均值和方差的计算。任何先前未被掩码的元素值将在其动量过期前被计入。参考
关于在 BatchNormalization
层上设置 layer.trainable = False
设置 layer.trainable = False
的含义是冻结该层,即其内部状态在训练期间不会改变:其可训练权重不会在 fit()
或 train_on_batch()
期间更新,其状态更新也不会运行。
通常,这并不一定意味着该层在推断模式下运行(通常由调用层时可传递的 training
参数控制)。“冻结状态”和“推断模式”是两个独立的概念。
然而,对于 BatchNormalization
层,在该层上设置 trainable = False
意味着该层随后将在推断模式下运行(这意味着它将使用移动均值和移动方差来归一化当前批次,而不是使用当前批次的均值和方差)。
注意
trainable
将递归地设置所有内部层的 trainable
值。compile()
后更改 trainable
属性的值,新值对该模型直到再次调用 compile()
后才生效。