Replace Manual Normalization with Batch Normalization in Vision AI Models

Author: Murphy  |  View: 23241  |  Time: 2025-03-23 18:35:25

Co-authored with Naresh.


Channel Normalization for image pre-processing

Colour images usually have 3 channels (RGB). It's common for vision AI models to pre-process images so that the pixels in each channel have a mean of 0.0 and a variance of 1.0. Normalization is performed per channel since each channel can have its own statistics. Batch Normalization, which applies the same kind of normalization to the activations inside a network, is a widely used best practice in vision models; it was originally introduced to reduce a phenomenon its authors called internal covariate shift.

What is covariate shift?

Covariate shift is a phenomenon that occurs when the distribution of the input features (i.e., covariates) changes between the training and testing phases of a machine learning model. This can degrade the model's performance, as the assumptions made during training may not hold during testing. Covariate shift can occur due to changes in the data collection process, in the population being sampled, or in the environment in which the model is used. To address it, techniques such as domain adaptation and importance weighting can be used to adapt the model, or reweight its training data, to account for the change in the input distribution.

However, these techniques are fairly complex, and require a deeper understanding of the input data distribution.

How does batch normalization help with covariate shift?

Batch normalization was designed to reduce what its authors called internal covariate shift: the change in the distribution of each layer's inputs during training, as the parameters of the preceding layers are updated. It does this by normalizing the activations of each layer, so that the layers that follow see inputs with a consistent mean and variance even while earlier layers are still changing. This is different from a shift between the training and testing datasets, which batch normalization does not correct by itself. Batch normalization can be applied across many domains and architectures with little or no modification.

More specifically, during training, batch normalization centers and scales each channel's activations using the mean and variance computed over the current mini-batch. This gives the activations zero mean and unit variance, after which a learnable per-channel scale and shift is applied; together, these help stabilize and accelerate training. The layer also tracks a running mean and variance during the training phase.

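During testing, the running mean and variance tracked during training are used to normalize the activations instead of the statistics of the test batch. The normalization is therefore a fixed transform that is consistent with the training data, and the model's output for one input does not depend on the other inputs in the batch.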
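To make the training/testing distinction concrete, here is a minimal sketch of the computation BatchNorm2d performs, checked against PyTorch's own layer. It assumes PyTorch's defaults (momentum=0.1, eps=1e-5) and its documented convention of normalizing with the biased batch variance while storing the unbiased variance in running_var. batchnorm2d_manual is a hypothetical helper written only for illustration.

```python
import torch
import torch.nn as nn


def batchnorm2d_manual(x, running_mean, running_var, weight, bias,
                       training, momentum=0.1, eps=1e-5):
    """Illustrative re-implementation of BatchNorm2d's forward pass.

    x has shape (N, C, H, W); the other tensors have shape (C,).
    running_mean and running_var are updated in place when training.
    """
    if training:
        # Per-channel statistics over the batch and spatial dimensions.
        mean = x.mean(dim=(0, 2, 3))
        var = x.var(dim=(0, 2, 3), unbiased=False)  # biased: used to normalize
        n = x.numel() / x.size(1)                   # elements per channel
        with torch.no_grad():
            # Exponential moving averages; running_var stores the unbiased estimate.
            running_mean.mul_(1 - momentum).add_(momentum * mean)
            running_var.mul_(1 - momentum).add_(momentum * var * n / (n - 1))
    else:
        # Evaluation: use the statistics accumulated during training.
        mean, var = running_mean, running_var
    shape = (1, -1, 1, 1)  # broadcast per-channel values over N, H and W
    x_hat = (x - mean.view(shape)) / torch.sqrt(var.view(shape) + eps)
    return weight.view(shape) * x_hat + bias.view(shape)


torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
x = torch.randn(8, 3, 4, 4) * 2 + 1
running_mean, running_var = torch.zeros(3), torch.ones(3)
w, b = bn.weight.detach(), bn.bias.detach()

# Training mode: normalize with batch statistics and update running estimates.
train_ok = torch.allclose(
    batchnorm2d_manual(x, running_mean, running_var, w, b, training=True),
    bn(x), atol=1e-5)
stats_ok = (torch.allclose(running_mean, bn.running_mean, atol=1e-6)
            and torch.allclose(running_var, bn.running_var, atol=1e-6))

# Evaluation mode: normalize with the running estimates instead.
bn.eval()
eval_ok = torch.allclose(
    batchnorm2d_manual(x, running_mean, running_var, w, b, training=False),
    bn(x), atol=1e-5)
print(train_ok, stats_ok, eval_ok)
```

The only difference between the two modes is where the mean and variance come from, which is why a model must be switched to evaluation mode before inference.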

By stabilizing the distribution of activations during training, batch normalization often allows faster training and can improve the generalization performance of neural networks.


Why pre-process inputs?

This is done to allow a model to converge faster and generalize well to the input data. Models perform relatively better when the input data distribution is in a consistent and well-specified range. Specifically, it has the following benefits:

  1. To prevent overflow or underflow: Machine learning models often involve mathematical operations such as addition, multiplication, and exponentiation. If the input values are too large or too small, these operations can overflow or underflow, leading to inaccurate or undefined results. A related problem is loss of precision: adding a small floating point (fp32) number to a much larger one can leave the larger number unchanged, effectively ignoring the small one.
  2. To ensure efficient learning: Neural networks are usually trained with gradient-based optimization, using the backpropagation algorithm to compute the gradients that update the weights. These gradients can become very small or very large if the input values are not normalized, making the learning process slower and less efficient.
  3. To improve generalization: When a model is trained on a specific range of input values, it may not perform well on inputs outside of that range. Normalizing the inputs can help the model generalize to new and unseen data.
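Both numeric effects from the first item are easy to reproduce in fp32. This is only an illustration; the values used (1e8, 1.0, 100.0 and the divisor 255.0) are arbitrary choices.

```python
import torch

big = torch.tensor(1e8, dtype=torch.float32)
small = torch.tensor(1.0, dtype=torch.float32)

# fp32 has a 24-bit significand, so around 1e8 adjacent representable values
# are 8 apart, and adding 1.0 is rounded away entirely.
absorbed = bool(big + small == big)

# exp() of a large, unnormalized value exceeds fp32's maximum (about 3.4e38).
overflowed = bool(torch.isinf(torch.exp(torch.tensor(100.0))))

# The same value rescaled into a small range (here by 255, as for 8-bit pixels)
# stays comfortably finite.
rescaled = torch.exp(torch.tensor(100.0) / 255.0)

print(absorbed, overflowed, rescaled.item())
```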

Batch normalization can help address some of these issues by normalizing the inputs to each layer of the network during training. This can improve the stability of the model and make it less sensitive to the scale of the inputs. However, batch normalization is not always a drop-in replacement for input normalization done manually with hard-coded constants.


What is Batch normalization?

The topic of what Batch normalization is and how it helps models has been covered extensively in numerous articles, so we shall link to the ones that provide the most detailed insights and allow the reader to develop an intuition for the operation. We also provide some links that compare batch normalization with other normalization techniques.
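For reference, the PyTorch documentation defines the per-channel BatchNorm2d transform as follows, where $\mathrm{E}[x]$ and $\mathrm{Var}[x]$ are the channel's mean and biased variance over the mini-batch during training (the running estimates are used in evaluation mode), $\epsilon$ is a small constant for numerical stability (1e-5 by default), and $\gamma$ and $\beta$ are learnable per-channel scale and shift parameters:

$$
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta
$$

With $\gamma = 1$ and $\beta = 0$, this is the same operation that a per-channel Normalize step performs with mean $\mathrm{E}[x]$ and standard deviation $\sqrt{\mathrm{Var}[x] + \epsilon}$.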


How is input normalization typically done?

Typically, the person training the model is responsible for computing the per-channel statistics (mean and variance) for the entire training dataset, and normalizing the input before training the vision AI model. This normalization is also expected to be performed during inference.

The code for this pre-processing may look like this when using torchvision transforms.

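```python
import torch
from torchvision import transforms

# Named `preprocess` so that it does not shadow the `transforms` module.
preprocess = torch.nn.Sequential(
    transforms.CenterCrop(10),
    transforms.Normalize(
        # Channel means (the commonly used ImageNet statistics)
        #    R,     G,     B
        (0.485, 0.456, 0.406),
        # Channel standard deviations
        (0.229, 0.224, 0.225),
    ),
)
```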

You can read more about the Normalize transform in the torchvision documentation.

You'll notice that the means and standard deviations are pre-computed and then hard-coded into the pre-processing pipeline. There's a lot of discussion online about how to do this both correctly and efficiently. For example,

  1. Computing the mean and std of dataset
  2. Kaggle Notebook: Computing Dataset Mean and STD (using PyTorch)
  3. How to calculate mean and standard deviation of images in PyTorch
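For reference, here is a minimal sketch of the manual approach those discussions describe: a single pass over a DataLoader that accumulates per-channel sums and sums of squares. channel_stats is a hypothetical helper, and the random TensorDataset stands in for a real image dataset. Accumulating in float64 is a precaution against rounding error on large datasets.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_stats(loader):
    """Per-channel mean and std over every pixel of every image in `loader`."""
    total = sq_total = None
    count = 0
    for (images,) in loader:                       # images: (N, C, H, W)
        images = images.double()                   # float64 accumulation
        s = images.sum(dim=(0, 2, 3))
        sq = (images ** 2).sum(dim=(0, 2, 3))
        total = s if total is None else total + s
        sq_total = sq if sq_total is None else sq_total + sq
        count += images.numel() // images.size(1)  # pixels per channel
    mean = total / count
    std = (sq_total / count - mean ** 2).sqrt()    # population std
    return mean.float(), std.float()


# Hypothetical stand-in dataset: 256 random 3x32x32 "images" in [0, 1].
torch.manual_seed(0)
data = TensorDataset(torch.rand(256, 3, 32, 32))
mean, std = channel_stats(DataLoader(data, batch_size=64))
print(mean, std)
```

The returned values are the kind of constants that end up hard-coded in transforms.Normalize, and the whole pass has to be repeated whenever the dataset changes.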

How to make this less painful?

Now that we know how painful this process is, we'll see a neat trick to make this less painful for you!

Simply stick a BatchNorm2d layer as the first layer of your vision AI model, and drop the Normalize transform from your pre-processing steps!

By using batch normalization as the first layer of the model, the input data is normalized automatically: with the statistics of each mini-batch during training, and with the running estimates accumulated during training at inference time. You no longer need to compute and hard-code per-channel statistics yourself, which saves coding time and reduces the chance of introducing errors in the normalization process.
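A minimal sketch of the idea, using a small hypothetical classifier (TinyClassifier and its layer sizes are illustrative, not from any library). One design choice to note: by default BatchNorm2d also learns a per-channel scale and shift (affine=True). This sketch passes affine=False so that the input layer only normalizes, which is the closer analogue of Normalize; keeping the default is also a reasonable option.

```python
import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical model whose first layer normalizes the raw input channels."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.input_norm = nn.BatchNorm2d(3, affine=False)  # stands in for Normalize
        self.body = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(16, num_classes),
        )

    def forward(self, x):
        # x: un-normalized pixels, e.g. floats in [0, 1] from ToTensor().
        return self.body(self.input_norm(x))


torch.manual_seed(0)
model = TinyClassifier()
pixels = torch.rand(16, 3, 32, 32)   # no Normalize step in pre-processing

model.train()
logits = model(pixels)               # batch statistics; running stats updated
model.eval()
with torch.no_grad():
    eval_logits = model(pixels[:1])  # running statistics, so a batch of 1 works
print(logits.shape, eval_logits.shape, model.input_norm.running_mean)
```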


The rough equivalence of manual normalization and Batch normalization

Here is some code that demonstrates the rough equivalence of the two approaches.

We'll take 2000 batches of 20 randomly generated 1×1 images with 3 channels each, and check whether the manually computed mean and variance are similar to the ones computed by PyTorch's BatchNorm2d layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(21)
num_channels = 3

# Example tensor so that we can use randn_like() below.
y = torch.randn(20, num_channels, 1, 1)

model = nn.BatchNorm2d(num_channels)

# nb is a dict containing the buffers (non-trainable state) of the
# BatchNorm2d layer. Since these are not trainable parameters, we don't
# need to run a backward pass to update them. They are updated in place
# during the forward pass itself (the layer is in training mode by default).
nb = dict(model.named_buffers())
print(f"Buffers in BatchNorm2d: {nb.keys()}\n")

inputs = []
for i in range(2000):
    x = torch.randn_like(y)
    y_hat = model(x)
    # Keep every input batch so that we can compute the mean and
    # variance over all of them later.
    inputs.append(x)
stacked = torch.cat(inputs, dim=0)

print(f"Shape of stacked tensor: {stacked.shape}\n")
smean = stacked.mean(dim=(0, 2, 3))
svar = stacked.var(dim=(0, 2, 3))
print("Manually Computed:")
print("------------------")
print(f"Mean: {smean}\nVariance: {svar}\n")
print("Computed by BatchNorm2d:")
print("------------------------")
rm, rv = nb['running_mean'], nb['running_var']
print(f"Mean: {rm}\nVariance: {rv}\n")
print("Mean Absolute Differences:")
print("--------------------------")
print(f"Mean: {(smean-rm).abs().mean():.4f}, Variance: {(svar-rv).abs().mean():.4f}")
```

Running this code produces the following output.

```text
Buffers in BatchNorm2d: dict_keys(['running_mean', 'running_var', 'num_batches_tracked'])

Shape of stacked tensor: torch.Size([40000, 3, 1, 1])

Manually Computed:
------------------
Mean: tensor([0.0039, 0.0015, 0.0095])
Variance: tensor([1.0029, 1.0026, 0.9947])

Computed by BatchNorm2d:
------------------------
Mean: tensor([-0.0628,  0.0649,  0.0600])
Variance: tensor([1.0812, 1.0318, 1.0721])

Mean Absolute Differences:
--------------------------
Mean: 0.0602, Variance: 0.0616
```

We generated the inputs with torch.randn_like(), which samples from a standard normal distribution, so over a sufficiently large number of samples (40,000 here) we expect the mean and variance to tend to 0.0 and 1.0 respectively. The manually computed values agree with this.

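The mean and variance estimated by BatchNorm2d's running averages differ from those computed manually over the entire input by about 0.06 on average (0.0602 for the mean and 0.0616 for the variance). In relative terms the means can differ by a large factor (up to about 40x here), but only because the true means are close to 0.0; in absolute terms the differences are small. They arise because BatchNorm2d keeps an exponential moving average (momentum=0.1 by default), which is dominated by the most recent batches rather than weighting all 40,000 samples equally. For normalizing inputs, a difference of this size should not matter in practice.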
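The gap between the two sets of estimates comes from how the running statistics are accumulated. The PyTorch documentation notes that momentum can be set to None to get a cumulative (simple) moving average instead of the default exponential one. The sketch below reruns a shortened version of the experiment with that setting (500 batches; the batch count is an arbitrary choice). Because every batch has the same size, the running mean should then match the mean over all samples up to floating-point error, while the running variance is an average of per-batch variances and so only approximately equals the variance over all samples.

```python
import torch
import torch.nn as nn

torch.manual_seed(21)
num_channels = 3
bn = nn.BatchNorm2d(num_channels, momentum=None)  # cumulative moving average

batches = []
for _ in range(500):
    x = torch.randn(20, num_channels, 1, 1)
    bn(x)                     # training mode: updates the running statistics
    batches.append(x)
stacked = torch.cat(batches, dim=0)

mean_gap = (stacked.mean(dim=(0, 2, 3)) - bn.running_mean).abs().max().item()
var_gap = (stacked.var(dim=(0, 2, 3)) - bn.running_var).abs().max().item()
print(f"max |mean gap| = {mean_gap:.2e}, max |variance gap| = {var_gap:.2e}")
```

The default exponential average is usually what you want during real training, since it tracks recent statistics as the data pipeline and earlier layers evolve; the cumulative average is shown here only to explain the gap.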


Caveats

Substituting batch normalization for manual normalization has both pros and cons, and this article wouldn't be complete without comparing the two in terms of when each one is or isn't appropriate.

Transfer learning

When using transfer learning, it is often recommended to keep the normalization method used in the pre-trained model to avoid introducing unnecessary changes. In this case, it may not be appropriate to replace manual normalization with batch normalization.

For example, here's what the torchvision page has to say on this subject.

"Before using the pre-trained models, one must preprocess the image (resize with right resolution/interpolation, apply inference transforms, rescale the values etc). There is no standard way to do this as it depends on how a given model was trained. It can vary across model families, variants or even weight versions. Using the correct preprocessing method is critical and failing to do so may lead to decreased accuracy or incorrect outputs."

Training Efficiency

When considering training efficiency, it may be beneficial to pre-compute the mean and variance of your dataset once and hard-code them as a pre-processing step. This avoids recomputing these statistics for every mini-batch during training.

Note that either method uses the same amount of compute during inference: the inputs are normalized with a fixed mean and variance, either before they are fed into the model (manual normalization) or as the first step of the model (batch normalization).
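Because an evaluation-mode BatchNorm2d layer is a fixed per-channel affine transform, its learned statistics can also be folded back into conventional Normalize constants, for example if you later want the pre-processing outside the model. A minimal sketch of that algebra, with a hypothetical helper bn_to_normalize_params and made-up layer statistics; it assumes the layer's weights are non-zero. Writing s = sqrt(running_var + eps), the layer computes weight * (x - running_mean) / s + bias, which equals (x - mean') / std' with mean' = running_mean - bias * s / weight and std' = s / weight.

```python
import torch
import torch.nn as nn


def bn_to_normalize_params(bn):
    """Return (mean, std) such that (x - mean) / std equals bn(x) in eval mode."""
    s = torch.sqrt(bn.running_var + bn.eps)
    if bn.affine:
        gamma, beta = bn.weight.detach(), bn.bias.detach()
    else:
        gamma, beta = torch.ones_like(s), torch.zeros_like(s)
    # gamma * (x - mu) / s + beta == (x - (mu - beta * s / gamma)) / (s / gamma)
    return bn.running_mean - beta * s / gamma, s / gamma


torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
with torch.no_grad():  # made-up values standing in for trained statistics
    bn.running_mean.uniform_(0.3, 0.6)
    bn.running_var.uniform_(0.04, 0.08)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.2, 0.2)
bn.eval()

mean, std = bn_to_normalize_params(bn)
x = torch.rand(4, 3, 8, 8)
manual = (x - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)
with torch.no_grad():
    print(torch.allclose(manual, bn(x), atol=1e-5))
```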

Human Efficiency

Based on what we've seen above, it's much easier for a human to add a batch normalization layer than to compute the statistics manually.

Data Augmentation

When augmenting inputs before feeding them into the model, one must be careful to apply the normalization after all the augmentation and other pre-processing steps are done, and to base its statistics on the augmented data, to avoid computing incorrect statistics. For example, a ColorJitter transform randomly changes brightness, contrast, saturation, and hue, which can shift the per-channel statistics meaningfully.

This leads to another interesting question: when should the mean and variance of the dataset be computed when using data augmentation?

This is tricky. Computing the mean and variance accurately across augmented inputs would require knowing a priori which images will be augmented with which transforms, and then computing the statistics and applying the augmentations consistently during model training. That is hard to do in general, since augmentations are applied to input images at random, and the same image is augmented differently in every training epoch to prevent the model from overfitting the training dataset. In this case, batch normalization computes the mean and variance on the augmented images rather than on the original un-augmented ones, which may also lead to better model accuracy.
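A small illustration of the point, using a simulated brightness jitter (each image multiplied by a random factor, a hypothetical stand-in for ColorJitter) so that only PyTorch is needed. Statistics computed on the un-augmented images no longer describe what the model sees, while a BatchNorm2d layer in training mode normalizes the augmented batch it actually receives.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
images = torch.rand(64, 3, 16, 16)  # un-augmented, values in [0, 1]

# Hypothetical brightness jitter: scale each image by a factor in [0.4, 1.0].
factors = torch.empty(64, 1, 1, 1).uniform_(0.4, 1.0)
augmented = images * factors

# Normalizing with statistics precomputed on the original images.
orig_mean = images.mean(dim=(0, 2, 3))
orig_std = images.std(dim=(0, 2, 3))
manual = (augmented - orig_mean.view(1, -1, 1, 1)) / orig_std.view(1, -1, 1, 1)

# A BatchNorm2d layer (training mode) sees the augmented batch directly.
bn = nn.BatchNorm2d(3, affine=False)
with torch.no_grad():
    bn_out = bn(augmented)

print("manual mean:", manual.mean(dim=(0, 2, 3)))
print("BatchNorm2d mean:", bn_out.mean(dim=(0, 2, 3)))
```

The manually normalized channels end up noticeably below zero on average, while the BatchNorm2d output is centered.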

Mini-batch distribution

Since batch normalization normalizes with the mean and variance of each mini-batch during training, it is important to shuffle the input data so that the statistics of each mini-batch are reasonably representative of the entire training dataset. If your mini-batches are small or biased (for example, drawn from data sorted by class or by capture conditions), consider un-biasing them or using manual normalization instead.

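This problem doesn't exist at test or inference time, as long as the model is in evaluation mode (model.eval()), because the layer then uses its running estimates rather than the statistics of the current batch.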
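To see why shuffling matters, consider a hypothetical dataset stored in order, with all the dark images first and all the bright ones after. Without shuffling, each mini-batch's mean is far from the dataset mean that BatchNorm2d is implicitly estimating; with a random permutation (which is what DataLoader(..., shuffle=True) draws each epoch) the batch means stay much closer to it. The dataset, batch size, and worst_batch_error helper are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
# Hypothetical dataset sorted by brightness: 500 dark images, then 500 bright ones.
dark = torch.rand(500, 3, 8, 8) * 0.3
bright = 0.7 + torch.rand(500, 3, 8, 8) * 0.3
data = torch.cat([dark, bright])
dataset_mean = data.mean(dim=(0, 2, 3))


def worst_batch_error(order, batch_size=50):
    """Largest gap between a mini-batch's channel means and the dataset means."""
    errors = []
    for start in range(0, len(order), batch_size):
        batch = data[order[start:start + batch_size]]
        errors.append((batch.mean(dim=(0, 2, 3)) - dataset_mean).abs().max())
    return torch.stack(errors).max().item()


sequential = worst_batch_error(torch.arange(len(data)))
shuffled = worst_batch_error(torch.randperm(len(data)))
print(f"unshuffled: {sequential:.3f}, shuffled: {shuffled:.3f}")
```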


Conclusion

Normalization is an important pre-processing step in any vision AI pipeline. Computing the normalization statistics correctly and efficiently by hand can be tedious and error-prone. Using batch normalization as the first layer of your vision AI model is a viable substitute for manual normalization in many scenarios.

