NT-Xent (Normalized Temperature-Scaled Cross-Entropy) Loss Explained and Implemented in PyTorch


Co-authored with Naresh Singh.

Formula for NT-Xent loss. Source: Papers with code (CC-BY-SA)

Introduction

Recent advances in self-supervised learning and contrastive learning have excited researchers and practitioners in Machine Learning (ML) to explore this space with renewed interest.

In particular, the SimCLR paper, which presents a simple framework for contrastive learning of visual representations, has gained a lot of attention in the self-supervised and contrastive learning space.

The central idea behind the paper is simple: let the model learn whether a pair of images was derived from the same underlying image or from two different images.

Figure 1: The high-level idea behind SimCLR. Source: SimCLR paper

The SimCLR approach encodes each input image i as a feature vector zi. There are 2 cases to consider:

  1. Positive pairs: The same source image is augmented twice, using two different randomly sampled augmentations, and the resulting feature vectors zi and zj are compared. The loss function forces these feature vectors to be similar.
  2. Negative pairs: Two different source images are augmented, and the resulting feature vectors zi and zk are compared. The loss function forces these feature vectors to be dissimilar, as sketched below.
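As a minimal illustration of what the loss function is asked to do, the snippet below compares the cosine similarity of a few feature vectors. The vectors here are random placeholders; in SimCLR they would come from the encoder and projection head applied to augmented views.

```python
import torch
import torch.nn.functional as F

# Purely illustrative feature vectors; in SimCLR these come from the
# encoder + projection head applied to augmented views of images.
z_i = F.normalize(torch.randn(128), dim=0)  # view 1 of image A
z_j = F.normalize(torch.randn(128), dim=0)  # view 2 of image A -> positive pair with z_i
z_k = F.normalize(torch.randn(128), dim=0)  # a view of image B  -> negative pair with z_i

pos_sim = torch.dot(z_i, z_j)  # the loss pushes this towards +1
neg_sim = torch.dot(z_i, z_k)  # the loss pushes this to stay low relative to pos_sim
print(pos_sim.item(), neg_sim.item())
```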

The rest of this article focuses on explaining and understanding this loss function and its efficient implementation in PyTorch.


The NT-Xent Loss

At a high level, the contrastive learning model is fed 2N images, originating from N underlying images. Each of the N underlying images is augmented using a random set of image augmentations to produce 2 augmented images. This is how we end up with 2N images in a single training batch fed to the model.

Figure 2: A batch of 6 images in a single training batch for contrastive learning. The number below each image is the index of that image in the input batch when fed into a contrastive learning model. Image Source: Oxford Visual Geometry Group (CC-SA).
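Here is a rough sketch of how such a batch might be assembled, assuming torchvision-style augmentations applied to a list of PIL images. The specific augmentations below are placeholders and not the exact SimCLR pipeline.

```python
import torch
import torchvision.transforms as T

# Placeholder augmentation pipeline; SimCLR uses random crops, color jitter,
# grayscale, and blur, but any stochastic pipeline illustrates the idea.
augment = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.4, 0.4, 0.4, 0.1),
    T.ToTensor(),
])

def make_contrastive_batch(images):
    """Turn a list of N PIL images into a (2N, C, H, W) tensor where
    indices (2i, 2i+1) are two independently augmented views of image i."""
    views = []
    for img in images:
        views.append(augment(img))  # first augmented view
        views.append(augment(img))  # second, independently augmented view
    return torch.stack(views)
```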

In the following sections, we will dive deep into the following aspects of the NT-Xent loss.

  1. The effect of temperature on SoftMax and Sigmoid
  2. A simple and intuitive interpretation of the NT-Xent loss
  3. A step-by-step implementation of NT-Xent in PyTorch
  4. Motivating the need for a multi-label loss function (NT-BXent)
  5. A step-by-step implementation of NT-BXent in PyTorch

All the code for steps 2–5 can be found in [this notebook](https://github.com/dhruvbird/ml-notebooks/blob/main/nt-xent-loss/SoftMax%20and%20Sigmoid%20with%20temperature.ipynb). The code for step 1 can be found in this notebook.


The effect of temperature on SoftMax and Sigmoid

To understand all the moving parts of the contrastive loss function we'll be studying in this article, we need to first understand the effect of temperature on the SoftMax and Sigmoid activation functions.

Typically, temperature scaling is applied to the input to SoftMax or Sigmoid to either smooth out or accentuate the output of those activation functions. The input logits are divided by the temperature before passing into the activation functions. You can find all the code for this section in this notebook.

SoftMax: For SoftMax, a high temperature reduces the variance of the output distribution, softening the output probabilities towards a uniform distribution. A low temperature increases the variance of the output distribution and makes the maximum value stand out from the other values. See the charts below for the effect of temperature on SoftMax when fed the input tensor [0.1081, 0.4376, 0.7697, 0.1929, 0.3626, 2.8451].

Figure 3: Effect of temperature on SoftMax. Source: Author(s)
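The snippet below reproduces this behaviour with the same input tensor; the temperature values 0.1, 1.0, and 10.0 are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([0.1081, 0.4376, 0.7697, 0.1929, 0.3626, 2.8451])

for temperature in (0.1, 1.0, 10.0):
    # Divide the logits by the temperature before applying SoftMax.
    probs = F.softmax(logits / temperature, dim=0)
    print(f"T={temperature:>4}: {[round(p, 4) for p in probs.tolist()]}")

# T= 0.1: the largest logit (2.8451) dominates; the output is nearly one-hot.
# T=10.0: the output is close to uniform (each value near 1/6).
```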

Sigmoid: For Sigmoid, a high temperature squashes the inputs towards zero, pulling the outputs towards 0.5. A low temperature stretches the inputs to larger magnitudes, pushing the outputs closer to either 0.0 or 1.0 depending on the sign of the input and how far from zero it lies.

Figure 4: Effect of temperature on Sigmoid. Source: Author(s)
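An analogous snippet for Sigmoid; the input values here are arbitrary, and only the division by temperature matters.

```python
import torch

inputs = torch.tensor([-2.0, -0.5, 0.5, 2.0])

for temperature in (0.1, 1.0, 10.0):
    # Divide the inputs by the temperature before applying Sigmoid.
    outputs = torch.sigmoid(inputs / temperature)
    print(f"T={temperature:>4}: {[round(v, 4) for v in outputs.tolist()]}")

# T= 0.1: outputs saturate towards 0.0 (negative inputs) or 1.0 (positive inputs).
# T=10.0: all inputs are squashed towards zero, so the outputs cluster around 0.5.
```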

Now that we understand the effect of various temperature values on the SoftMax and Sigmoid functions, let's see how this applies to our understanding of the NT-Xent loss.


Interpreting the NT-Xent loss

The NT-Xent loss is best understood by unpacking the individual terms in its name.

  1. Normalized: Cosine similarity produces a normalized score in the range [-1.0, +1.0]
  2. Temperature-scaled: The all-pairs cosine similarity is scaled by a temperature before computing the cross-entropy loss
  3. Cross-entropy loss: The underlying loss is a multi-class (single-label) cross-entropy loss

As mentioned above, we assume that for a batch of size 2N, the feature vectors at index pairs (0, 1), (2, 3), (4, 5), (6, 7), … represent positive pairs, and all the remaining combinations represent negative pairs. This convention is important to keep in mind throughout the interpretation of the NT-Xent loss as it relates to SimCLR.
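Under this convention, the index of each vector's positive partner can be computed directly, which is handy when building cross-entropy targets later. The helper below is a small sketch of ours, not code from the paper.

```python
import torch

def positive_pair_targets(batch_size_2n: int) -> torch.Tensor:
    """For the pairing (0, 1), (2, 3), ..., return the index of each
    vector's positive partner: [1, 0, 3, 2, 5, 4, ...]."""
    targets = torch.arange(batch_size_2n)
    targets[0::2] += 1  # even indices point to the next element
    targets[1::2] -= 1  # odd indices point to the previous element
    return targets

print(positive_pair_targets(6))  # tensor([1, 0, 3, 2, 5, 4])
```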

Now that we understand what the terms mean in the context of the NT-Xent loss, let's take a look at the mechanical steps needed to compute the NT-Xent loss on a batch of feature vectors.

  1. The all-pairs Cosine Similarity score is computed for each of the 2N vectors produced by the SimCLR model. This results in (2N)² similarity scores represented as a 2N x 2N matrix
  2. Comparison results between an element and itself (i, i) are discarded, since a vector is trivially similar to itself and that comparison can't teach the model anything useful
  3. Each value (cosine similarity) is scaled by a temperature parameter
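Here is a small sketch of these three steps in PyTorch. The variable names and the default temperature are ours, and filling the diagonal with -inf is one common way of "discarding" the (i, i) comparisons before a softmax-based loss.

```python
import torch
import torch.nn.functional as F

def scaled_similarity_matrix(z: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """z: (2N, d) feature vectors produced by the model.
    Returns the temperature-scaled all-pairs cosine similarity matrix,
    with the self-similarity entries (i, i) masked out."""
    z = F.normalize(z, dim=1)   # step 1: cosine similarity = dot product of unit vectors
    sim = z @ z.t()             # (2N, 2N) all-pairs similarity matrix
    sim = sim / temperature     # step 3: scale each value by the temperature
    # Step 2: discard the (i, i) comparisons so they can't dominate the loss.
    mask = torch.eye(z.size(0), dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(mask, float("-inf"))
    return sim

sim = scaled_similarity_matrix(torch.randn(6, 128), temperature=0.1)
print(sim.shape)  # torch.Size([6, 6])
```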
