算子合集:BatchNorm2d

Batch Normalization,批处理化。

介绍

Batch Normalization 主要是为了解决深度神经网络训练过程中的内部协变量偏移问题(Internal Covariate Shift)。内部协变量偏移是指网络层输入分布的改变,这会导致训练过程变得复杂和不稳定。

在每个 mini-batch 的数据通过网络层时,Batch Normalization 会对每个特征进行归一化,使其均值为 0,方差为 1。然后,通过学习两个参数(缩放因子 γ 和偏移因子 β),对归一化后的数据进行线性变换,以适应网络的需要。

总的公式如下:

image

应用于4D输入(带有额外通道维度的2D输入的小批量,说白了就是NCHW这样的),具体的看这个文章Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift中的描述。

均值和标准差是根据小批量数据在每个维度上计算的(are calculated per-dimension over the mini-batches),γ和β是大小为C的可学习参数向量(其中C是输入大小)。默认情况下,γ的元素设置为1,β的元素设置为0。标准差是通过有偏估计器计算的,等同于torch.var(input, unbiased=False)。

同样,默认情况下,在训练期间,此层保持其计算的均值和方差的运行估计值,这些估计值随后用于评估期间的归一化。这些运行估计值的默认momentum为0.1。

如果track_running_stats设置为False,则该层不再保持运行估计值,并且在评估时也使用批量统计数据。

Because the Batch Normalization is done over the C dimension, computing statistics on (N, H, W) slices, it’s common terminology to call this Spatial Batch Normalization.

参数:

  • num_features (int) – C from an expected input of size (N,C,H,W)
  • eps (float) – a value added to the denominator for numerical stability. Default: 1e-5
  • momentum (float) – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1
  • affine (bool) – a boolean value that when set to True, this module has learnable affine parameters. Default: True
  • track_running_stats (bool) – a boolean value that when set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics, and initializes statistics buffers running_mean and running_var as None. When these buffers are None, this module always uses batch statistics. in both training and eval modes. Default: True

输入输出维度:

  • Input: (N,C,H,W)
  • Output: (N,C,H,W) (same shape as input)

示例

# With Learnable Parameters
m = nn.BatchNorm2d(100)
# Without Learnable Parameters
m = nn.BatchNorm2d(100, affine=False)
input = torch.randn(20, 100, 35, 45)
output = m(input)

附录

方差(Variance)和标准差(Standard Deviation)都是描述数据分布离散程度的统计量。它们的区别在于计算方式和单位:

  1. 方差(Variance):

    • 定义:方差是每个数据点与均值的差的平方的平均值。
    • 计算公式 Var(X) = \frac{1}{N}\sum_{i=1}^{N}(X_i - \mu)^2
    • 单位:方差的单位是数据单位的平方。
  2. 标准差(Standard Deviation):

    • 定义:标准差是方差的平方根。
    • 计算公式 SD(X) = \sqrt{Var(X)}
    • 单位:标准差的单位与原数据单位相同。