Batch Normalization,批处理化。
介绍
Batch Normalization 主要是为了解决深度神经网络训练过程中的内部协变量偏移问题(Internal Covariate Shift)。内部协变量偏移是指网络层输入分布的改变,这会导致训练过程变得复杂和不稳定。
在每个 mini-batch 的数据通过网络层时,Batch Normalization 会对每个特征进行归一化,使其均值为 0,方差为 1。然后,通过学习两个参数(缩放因子 γ 和偏移因子 β),对归一化后的数据进行线性变换,以适应网络的需要。
总的公式如下:
应用于4D输入(带有额外通道维度的2D输入的小批量,说白了就是NCHW这样的),具体的看这个文章Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift中的描述。
均值和标准差是根据小批量数据在每个维度上计算的(are calculated per-dimension over the mini-batches),γ和β是大小为C的可学习参数向量(其中C是输入大小)。默认情况下,γ的元素设置为1,β的元素设置为0。标准差是通过有偏估计器计算的,等同于torch.var(input, unbiased=False)。
同样,默认情况下,在训练期间,此层保持其计算的均值和方差的运行估计值,这些估计值随后用于评估期间的归一化。这些运行估计值的默认momentum
为0.1。
如果track_running_stats
设置为False
,则该层不再保持运行估计值,并且在评估时也使用批量统计数据。
Because the Batch Normalization is done over the C dimension, computing statistics on (N, H, W) slices, it’s common terminology to call this Spatial Batch Normalization.
参数:
- num_features (int) – C from an expected input of size (N,C,H,W)
- eps (float) – a value added to the denominator for numerical stability. Default: 1e-5
- momentum (float) – the value used for the running_mean and running_var computation. Can be set to
None
for cumulative moving average (i.e. simple average). Default: 0.1 - affine (bool) – a boolean value that when set to
True
, this module has learnable affine parameters. Default:True
- track_running_stats (bool) – a boolean value that when set to
True
, this module tracks the running mean and variance, and when set toFalse
, this module does not track such statistics, and initializes statistics buffersrunning_mean
andrunning_var
asNone
. When these buffers areNone
, this module always uses batch statistics. in both training and eval modes. Default:True
输入输出维度:
- Input: (N,C,H,W)
- Output: (N,C,H,W) (same shape as input)
示例
# With Learnable Parameters
m = nn.BatchNorm2d(100)
# Without Learnable Parameters
m = nn.BatchNorm2d(100, affine=False)
input = torch.randn(20, 100, 35, 45)
output = m(input)
附录
方差(Variance)和标准差(Standard Deviation)都是描述数据分布离散程度的统计量。它们的区别在于计算方式和单位:
-
方差(Variance):
- 定义:方差是每个数据点与均值的差的平方的平均值。
- 计算公式 Var(X) = \frac{1}{N}\sum_{i=1}^{N}(X_i - \mu)^2
- 单位:方差的单位是数据单位的平方。
-
标准差(Standard Deviation):
- 定义:标准差是方差的平方根。
- 计算公式 SD(X) = \sqrt{Var(X)}
- 单位:标准差的单位与原数据单位相同。