Understanding and comparing Batch Norm with all its variations
In this article we explain what batch normalization (BN) is and how all its variations work. We try to explain the differences both conceptually and mathematically. By the end you will have a deeper understanding of how normalization is done in modern neural networks.
Introduction
Batch Normalization is a widely used method in deep learning to make training faster and more stable.
The main idea is to normalize the output values of each internal layer of a deep neural network.
To understand this better, have a look at the image below.
Without a BN layer, the output values of a layer are not normalized.
To normalize the values, BN calculates the normalization parameters (mean and standard deviation) per channel, using all the data in the batch.
The normalization formula is the following.
- xᵢ input values
- 𝜇ᵢ mean
- 𝜎ᵢ standard deviation (𝜎ᵢ² is the variance)
- 𝜀 small constant added for numerical stability
xᵢ ← (xᵢ - 𝜇ᵢ) / √(𝜎ᵢ² + 𝜀)
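As a minimal sketch of the formula above, here is the same computation in plain Python. The function name normalize, the default value of eps and the toy input are assumptions made for illustration. The variance divides by N (population variance), which is the convention used in the worked examples later in this article.

```python
import math

def normalize(values, eps=1e-5):
    """Normalize a flat list of numbers to (roughly) zero mean and unit variance."""
    mean = sum(values) / len(values)
    # Population variance: divide by N, not N - 1.
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]

out = normalize([1.0, 2.0, 3.0, 4.0])
print([round(v, 3) for v in out])
```

Which values go into the list is exactly what changes between the methods compared below.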
Problem
Batch Normalization has some problems:
- It does not work well with small batch sizes (1, 2, 4) during training.
- It can cause a difference in accuracy between training and inference. To partially avoid this, at inference time deep learning frameworks (e.g. PyTorch) normalize with running averages of the means and variances obtained during training instead of the statistics of the current batch.
To overcome these problems other methods were proposed.
Example
This section explains how the different methods work mathematically, using the simple batch below.
Suppose we have 3 tensors a, b, c with the following values. Together, a, b and c can be seen as a small batch of your data.
Each tensor has shape (4,1,2) -> (C,H,W), so the batch has shape (3,4,1,2) -> (N,C,H,W).
a = [ [[2, 3]], [[5, 7]], [[11, 13]], [[17, 19]] ]
b = [ [[0, 1]], [[1, 2]], [[3, 5]], [[8, 13]] ]
c = [ [[1, 2]], [[3, 4]], [[5, 6]], [[7, 8]] ]
Methods Comparison
Batch Normalization in detail
To better understand Batch Normalization and how the other variations differ from it, we need to go a little deeper.
Suppose you have a tensor of shape (m,C,H,W), where m is the batch size. The normalization parameters are calculated for each channel c as follows:
𝜇c = (1 / (m·H·W)) · Σ x, summing over every value x of channel c in the batch
𝜎c² = (1 / (m·H·W)) · Σ (x - 𝜇c)²
x̂ = (x - 𝜇c) / √(𝜎c² + 𝜀)
We also have learnable parameters γc and βc for each channel, such that
y = γc · x̂ + βc
γc and βc are there so that the outputs of every internal layer are not forced into one fixed normalized range (zero mean and unit variance).
With γc and βc the optimizer can tune the scale and shift of the normalized values for each layer independently of the others.
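To make the role of γc and βc concrete, here is a small sketch of the scale-and-shift step applied to values that have already been normalized. The function name scale_and_shift and the numbers are made up for illustration. In a real layer γc and βc are learned by the optimizer.

```python
def scale_and_shift(x_hat, gamma, beta):
    """Apply y = gamma * x_hat + beta to the normalized values of one channel."""
    return [gamma * v + beta for v in x_hat]

x_hat = [-1.0, 0.0, 1.0]  # already normalized values (illustrative)

identity = scale_and_shift(x_hat, gamma=1.0, beta=0.0)  # leaves the values unchanged
shifted = scale_and_shift(x_hat, gamma=2.0, beta=0.5)   # wider range, moved mean
print(identity, shifted)
```

With gamma=1 and beta=0 the layer is a pure normalization. Any other values let the network choose a different scale and offset for that channel.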
To visualize what batch norm does, have a look at the image below.
As you can see, a batch norm layer calculates the normalization parameters for each channel using all the data in the batch (N).
For a deeper introduction to batch norm, have a look at this great article [4].
Example
Recall the tensors a,b,c above.
1. Batch normalization computes the mean and standard deviation for each channel, using the values of that channel from every tensor in the batch.
𝜇ᵢ = mean(2, 3, 0, 1, 1, 2) = 1.5
𝜎ᵢ² = var(2, 3, 0, 1, 1, 2) = 0.917
Note that these values are taken from the first channel of all the tensors in the batch.
2. Calculate the normalized values
#Normalization formula for first channel
aᵢ ← (2 - 1.5) / √(0.917 + 0.00001) = 0.522
a ← [ [[0.522, 1.567]], [[0.676, 1.690]], [[1.071, 1.630]], [[1.066, 1.492]] ]
This process is repeated for all channels.
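The two steps above can be sketched in plain Python on the example batch. This is only an illustration of the arithmetic, not how a framework implements it: the function name batch_norm is made up, eps is assumed to be 0.00001 as in the formula above, and the learnable γc and βc are left out.

```python
import math

# Example batch from the text: three tensors of shape (C=4, H=1, W=2).
a = [[[2, 3]], [[5, 7]], [[11, 13]], [[17, 19]]]
b = [[[0, 1]], [[1, 2]], [[3, 5]], [[8, 13]]]
c = [[[1, 2]], [[3, 4]], [[5, 6]], [[7, 8]]]
batch = [a, b, c]  # shape (N=3, C=4, H=1, W=2)

def batch_norm(batch, eps=1e-5):
    """Normalize each channel with statistics taken over the whole batch."""
    num_channels = len(batch[0])
    stats = []
    for ch in range(num_channels):
        # Gather every value of this channel from every tensor in the batch.
        values = [v for sample in batch for row in sample[ch] for v in row]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        stats.append((mean, var))
    out = [
        [
            [[(v - stats[ch][0]) / math.sqrt(stats[ch][1] + eps) for v in row]
             for row in sample[ch]]
            for ch in range(num_channels)
        ]
        for sample in batch
    ]
    return out, stats

normalized, stats = batch_norm(batch)
print(stats[0])       # mean and variance of the first channel: 1.5 and about 0.917
print(normalized[0])  # normalized tensor a, first value about 0.522
```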
Momentum Batch Normalization (MBN)
Momentum Batch Normalization is a more recent technique that works like Batch Normalization but introduces an additional parameter, the momentum, to control the effect of normalization.
This can reduce the dependency on the batch size. MBN [1] maintains the same accuracy as Batch Norm for big batch sizes (>8), while improving it for small batch sizes (2, 4).
PyTorch
In PyTorch you can control the momentum of batch norm:
torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
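By default, PyTorch's batch norm uses momentum=0.1. Note that in PyTorch this argument only controls how the running statistics used at inference are updated. During training the layer still normalizes with the statistics of the current batch, so it is related to, but not the same as, MBN [1].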
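As a rough sketch of what the momentum argument does to the running statistics, the snippet below applies the update rule given in the PyTorch documentation, new_running = (1 - momentum) * running + momentum * batch_statistic. The helper name update_running and the batch means are hypothetical and only serve the illustration.

```python
def update_running(running, batch_stat, momentum=0.1):
    """One running-statistic update, following the rule in the PyTorch docs."""
    return (1 - momentum) * running + momentum * batch_stat

running_mean = 0.0  # running statistics start at 0 for the mean
for batch_mean in [1.5, 1.5, 1.5]:  # made-up per-batch means
    running_mean = update_running(running_mean, batch_mean)

print(round(running_mean, 4))  # 0.4065: slowly moving towards 1.5
```

A small momentum means the running statistics change slowly, so one unusual batch has little effect on what is used at inference.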
Layer Normalization(LN)
Layer normalization is similar to batch normalization, but to calculate the mean and standard deviation it uses only the values of a single tensor, not all the tensors in the batch.
Layer normalization uses all the values from all the channels of a single tensor for normalization.
As you can see in the image above, the statistics are computed across the channels of one tensor and never across the batch.
As in batch norm, we have the same 2 learnable parameters γc and βc.
Example
For layer normalization:
1. Calculate the mean and standard deviation using all the values from all the channels of a single tensor (here, tensor a).
𝜇ᵢ = mean(2, 3, 5, 7, 11, 13, 17, 19) = 9.625
𝜎ᵢ² = var(2, 3, 5, 7, 11, 13, 17, 19) = 35.734
2. Apply the normalization formula
a ← [ [[-1.276, -1.108]], [[-0.773, -0.439]], [[0.230, 0.565]], [[1.234, 1.568]] ]
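The same steps in plain Python, as a sketch on tensor a only. The function name layer_norm is made up for illustration, eps is assumed to be 0.00001, and γc and βc are left out. Notice that the batch never appears: the statistics come from one tensor.

```python
import math

a = [[[2, 3]], [[5, 7]], [[11, 13]], [[17, 19]]]  # shape (C=4, H=1, W=2)

def layer_norm(sample, eps=1e-5):
    """Normalize one tensor with statistics taken over all its channels."""
    values = [v for channel in sample for row in channel for v in row]
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    std = math.sqrt(var + eps)
    out = [[[(v - mean) / std for v in row] for row in channel] for channel in sample]
    return out, mean, var

ln_a, mean_a, var_a = layer_norm(a)
print(mean_a, round(var_a, 3))  # 9.625 35.734
print(ln_a)                     # first value about -1.276
```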
Group Normalization(GN)
Group Normalization (GN) is a normalization layer that divides the channels into groups (G) and normalizes the values within each group.
GN does not exploit the batch dimension, so its computation is independent of the batch size.
GN outperforms Batch Normalization for small batch sizes (2, 4), but not for bigger batch sizes (64, 128, ...).
G is a hyperparameter that sets the number of groups the channels are split into.
For example, if Channels (C) = 4 and Groups (G) = 2, each group contains C/G = 2 channels, and these are used together to calculate the normalization parameters.
Example
For the tensors a, b, c we will use G=2. With 4 channels, each group contains C/G = 2 channels, and the normalization parameters are calculated separately for each group.
We take one tensor from the batch at a time and consider only the values in one group of 2 channels. For example, for tensor a the first group is made of the first 2 channels.
𝜇ᵢ = mean(2, 3, 5, 7) = 4.25
𝜎ᵢ² = var(2, 3, 5, 7) = 3.687
Then we normalize the values using this mean and standard deviation:
aᵢ ← (2 - 4.25) / √(3.687 + 0.00001) = -1.172
We repeat the same process for the other group of channels and for every tensor in the batch.
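Here is a sketch of the group computation in plain Python for tensor a with G=2. The function name group_norm is made up for illustration, eps is assumed to be 0.00001, and the learnable scale and shift are left out.

```python
import math

a = [[[2, 3]], [[5, 7]], [[11, 13]], [[17, 19]]]  # shape (C=4, H=1, W=2)

def group_norm(sample, num_groups, eps=1e-5):
    """Normalize one tensor, computing statistics separately per group of channels."""
    num_channels = len(sample)
    if num_channels % num_groups != 0:
        raise ValueError("the number of channels must be divisible by the number of groups")
    per_group = num_channels // num_groups
    out = []
    for g in range(num_groups):
        group = sample[g * per_group:(g + 1) * per_group]
        values = [v for channel in group for row in channel for v in row]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        std = math.sqrt(var + eps)
        for channel in group:
            out.append([[(v - mean) / std for v in row] for row in channel])
    return out

gn_a = group_norm(a, num_groups=2)
print(gn_a)  # first value about -1.172
```

With num_groups=1 this sketch uses all the channels of the tensor at once, which gives the same numbers as the layer normalization example.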
Weight Normalization (WN)
Weight normalization is similar in spirit to batch normalization, but it avoids any dependence on the number of samples in the batch.
It reparameterizes each weight vector w in terms of a vector v and a scalar g, w = g · v / ‖v‖, so that the norm of w is fixed by g regardless of the scale of v.
It is not used much in practice.
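A minimal sketch of this reparameterization, w = g · v / ‖v‖, where v gives the direction and the scalar g gives the length of the weight vector. The function name weight_norm and the numbers are made up for illustration. In training, v and g would both be learned.

```python
import math

def weight_norm(v, g):
    """Return w = g * v / ||v||, so the length of w is set by g alone."""
    norm = math.sqrt(sum(x * x for x in v))
    return [g * x / norm for x in v]

w = weight_norm([3.0, 4.0], g=2.0)  # ||v|| = 5
w_length = math.sqrt(sum(x * x for x in w))
print(w, w_length)  # the length of w equals g, whatever the scale of v
```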
Weight Standardization (WS)
WS is usually used together with Group Normalization to outperform batch normalization for small batch sizes (1, 2).
During training the weights are updated by the optimizer and can lose their normalization, even if the output values are normalized by GN.
WS keeps the weight values normalized.
This method differs from the previous ones because it normalizes the parameters of the neural network, not its output values.
Figure: results from combining WS and GN, from [9].
For example, in a Convolutional Neural Network the parameters are called filters. Filters are convolved over the whole image to obtain a feature map. Filters usually have 4 dimensions:
(Cout,Cin,height, width)
- Channel Out
- Channel In
- Height
- Width
WS normalizes the values inside each filter, that is, all the input weights (Cin × height × width) that correspond to one output feature map [9].
To understand how it works, have a look at the image below, taken from [10]. The image visualizes how a CNN works.
Cin = 3, Cout = 2
W0 and W1 make up a tensor of shape 2x3x3x3, which corresponds to (Cout,Cin,Height,Width).
WS normalizes W0 and W1 separately, using only the values from W0 and from W1 respectively.
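As a sketch of this idea, the snippet below standardizes a toy weight tensor of shape (Cout=2, Cin=3, 3, 3), one output filter at a time. The function name weight_standardize and the filter values are made up for illustration, and eps is an assumed small constant. In practice the standardized weights are what the convolution uses.

```python
import math

def weight_standardize(filters, eps=1e-5):
    """Standardize each output filter over its own (Cin, height, width) values."""
    out = []
    for filt in filters:  # one filter per output channel (W0, W1, ...)
        values = [w for cin in filt for row in cin for w in row]
        mean = sum(values) / len(values)
        var = sum((w - mean) ** 2 for w in values) / len(values)
        std = math.sqrt(var + eps)
        out.append([[[(w - mean) / std for w in row] for row in cin] for cin in filt])
    return out

# Toy weights: W0 holds the values 0..26, W1 holds 27..53.
filters = [
    [[[float(o * 27 + i * 9 + r * 3 + c) for c in range(3)] for r in range(3)]
     for i in range(3)]
    for o in range(2)
]
standardized = weight_standardize(filters)
print(standardized[0][0][0])  # first row of the standardized W0
```

W0 and W1 start from very different ranges, but after standardization each one has zero mean and roughly unit variance on its own.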
Short Comparison
Batch Normalization [6]
- (+) Stable if the batch size is large [6]
- (+) Robust (in train) to the scale & shift of input data [6]
- (+) Robust to the scale of weight vector [6]
- (+) Scale of update decreases while training [6]
- (-) Not good for online learning [6]
- (-) Not good for RNN, LSTM [6]
- (-) Different calculation between train and test [6]
Layer Normalization [6]
- (+) Effective for RNNs with small mini-batches [6]
- (+) Robust to the scale of input [6]
- (+) Robust to the scale and shift of weight matrix [6]
- (+) Scale of update decreases while training [6]
- (-) Might not be good for CNNs (Batch Norm is better in some cases) [6]
- (-) LayerNorm doesn’t have the special regularization effects that BatchNorm has from normalizing across data points. [7]
Group Normalization
- (+) Better than batch norm for small batch sizes [6]. This holds when GN is combined with WS [8]
- (-) Performs worse than BN for larger batch sizes
Weight Normalization
- (+) Lower computational cost on CNNs [6]
- (+) Weight initialization is well considered [6]
- (+) Easy to implement [6]
- (+) Robust to the scale of weight vector [6]
- (-) Compared with the others, training might be unstable [6]
- (-) High dependence on input data [6]
Conclusion
Batch normalization is still a widely used method for training deep neural networks.
What works best?
My suggestion is to build a deeper understanding of what each method does and then try them out. There is no perfect answer.
References
[1] https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123570222.pdf
[2] https://leimao.github.io/blog/Batch-Normalization/
[3] https://leimao.github.io/blog/Layer-Normalization/
[4] https://towardsdatascience.com/batch-normalization-in-3-levels-of-understanding-14c2da90a338
[5] https://towardsdatascience.com/what-is-group-normalization-45fe27307be7
[7] https://tungmphung.com/deep-learning-normalization-methods/
[9] https://arxiv.org/pdf/1903.10520.pdf