How to Implement State-of-the-Art Masked AutoEncoders (MAE)

Hi everyone! For those who do not know me yet, my name is Francois, and I am a Research Scientist at Meta. I have a passion for explaining advanced AI concepts and making them more accessible.

Today, I'm excited to delve into one of the most significant breakthroughs in Computer Vision post-Vision Transformers: Masked Autoencoders (MAE). This article serves as the practical implementation companion to my previous post: The Ultimate Guide to Masked Autoencoders (MAE)

For the following tutorial, we will use the code available on this github repository:

GitHub – FrancoisPorcher/awesome-ai-tutorials: The best collection of AI tutorials to make you a…

Here is a brief reminder of how it works:

Image from the paper "Masked Autoencoders Are Scalable Vision Learners" (He et al., 2022)

The methodology proceeds as follows:

  1. The image is split into patches.
  2. A subset of these patches is randomly masked.
  3. Only the visible patches are fed into the encoder (this is crucial).
  4. The decoder receives the compressed representation from the encoder and attempts to reconstruct the entire image using both visible and masked patches.
  5. The loss is computed only on the masked patches.

Let's dive into the code!

Imports

  • einops: used for its repeat function.
  • architectures.vit: a standard Vision Transformer (ViT) architecture; I am using the one from my previous post, How to train a ViT?

import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat

from architectures.vit import Transformer

Setting up the MAE Class:

class MAE(nn.Module):
    def __init__(
        self,
        *,
        encoder,
        decoder_dim,
        masking_ratio=0.75,
        decoder_depth=1,
        decoder_heads=8,
        decoder_dim_head=64
    ):
        super().__init__()
        # Ensure the masking ratio is valid
        assert 0 < masking_ratio < 1, 'masking ratio must be between 0 and 1'
        self.masking_ratio = masking_ratio

We define a class MAE that inherits from PyTorch's nn.Module.

  • encoder: Our Vision Transformer model.
  • decoder_dim: The dimension of the decoder's embedding space (512 for example).
  • masking_ratio: The fraction of patches to mask (the paper found that 75% is optimal).
  • Other decoder configurations such as depth, heads, and head dimension, which are standard for Transformers (a possible instantiation is sketched right after this list).
  • We assert that the masking ratio is strictly between 0 and 1.
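
To make these arguments concrete, here is a minimal instantiation sketch. The ViT constructor arguments below are assumptions based on the architecture from the How to train a ViT? post; adapt them to whatever your architectures.vit actually exposes, and note that it only works once the MAE class below is complete.

# Hypothetical instantiation sketch -- the ViT arguments are assumptions,
# adapt them to your own architectures.vit implementation.
from architectures.vit import ViT

encoder = ViT(
    image_size=224,
    patch_size=16,
    dim=1024,          # large encoder dimension
    depth=24,
    heads=16,
    mlp_dim_ratio=4,
)

mae = MAE(
    encoder=encoder,
    decoder_dim=512,     # lightweight decoder, smaller than the encoder
    masking_ratio=0.75,  # the paper's optimal ratio
    decoder_depth=1,
    decoder_heads=8,
    decoder_dim_head=64,
)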

Patches:

        # Save the encoder (a Vision Transformer to be trained)
        self.encoder = encoder

        # Extract the number of patches and the encoder's dimensionality from the positional embeddings
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]

        # Separate the patch embedding layers from the encoder
        # The first layer converts the image into patches
        self.to_patch = encoder.to_patch_embedding[0]
        # The remaining layers embed the patches
        self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])

What's Going On?

  • We store the encoder and extract essential information such as the number of patches and the encoder's output dimension.

We separate the patch embedding process:

  • self.to_patch: This layer splits the image into smaller patches.
  • self.patch_to_emb: This embeds each patch into a vector space.

# Determine the dimensionality of the pixel values per patch
pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]

  • We calculate how many pixel values are in each patch, which we'll need later as the reconstruction target (see the quick shape check below).
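
To make these numbers concrete, here is a quick sanity check of the patch shapes. The image and patch sizes below are illustrative, not values taken from the repository.

# Illustrative shape check (assuming 224x224 RGB images and 16x16 patches)
image_size, patch_size, channels = 224, 16, 3

num_patches = (image_size // patch_size) ** 2         # 14 * 14 = 196 patches
pixel_values_per_patch = channels * patch_size ** 2   # 3 * 16 * 16 = 768 values per patch

print(num_patches, pixel_values_per_patch)  # 196 768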

Setting up the decoder

  • self.enc_to_dec: If the encoder and decoder have different dimensions, we map between them. The encoder is usually large, with a higher dimension (for example 1024), while the decoder can be shallower, with a smaller dimension (for example 512), so we need an adapter to project from the encoder's dimension to the decoder's.
  • self.mask_token: A learnable token that represents the masked patches for the decoder. This is the token the decoder sees wherever a patch was masked.
  • We initialize the decoder transformer and the other layers needed for reconstruction.

self.decoder = Transformer(
    dim=decoder_dim,
    depth=decoder_depth,
    heads=decoder_heads,
    dim_head=decoder_dim_head,
    mlp_dim_ratio=4
)
# Positional embeddings for the decoder tokens
self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
# Linear layer to reconstruct pixel values from decoder outputs
self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)

So far, your MAE class's __init__ should look like this:

class MAE(nn.Module):
    def __init__(
        self,
        *,
        encoder,
        decoder_dim,
        masking_ratio=0.75,
        decoder_depth=1,
        decoder_heads=8,
        decoder_dim_head=64
    ):
        super().__init__()
        # Ensure the masking ratio is valid
        assert 0 < masking_ratio < 1, 'masking ratio must be between 0 and 1'
        self.masking_ratio = masking_ratio

        # Save the encoder (a Vision Transformer to be trained)
        self.encoder = encoder

        # Extract the number of patches and the encoder's dimensionality from the positional embeddings
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]

        # Separate the patch embedding layers from the encoder
        # The first layer converts the image into patches
        self.to_patch = encoder.to_patch_embedding[0]
        # The remaining layers embed the patches
        self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])

        # Determine the dimensionality of the pixel values per patch
        pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]

        # Set up decoder parameters
        self.decoder_dim = decoder_dim
        # Map encoder dimensions to decoder dimensions if they differ
        self.enc_to_dec = (
            nn.Linear(encoder_dim, decoder_dim)
            if encoder_dim != decoder_dim
            else nn.Identity()
        )
        # Learnable mask token for masked patches
        self.mask_token = nn.Parameter(torch.randn(decoder_dim))
        # Define the decoder transformer
        self.decoder = Transformer(
            dim=decoder_dim,
            depth=decoder_depth,
            heads=decoder_heads,
            dim_head=decoder_dim_head,
            mlp_dim_ratio=4
        )
        # Positional embeddings for the decoder tokens
        self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
        # Linear layer to reconstruct pixel values from decoder outputs
        self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)

Great! Now let's see how we use these different parts in the forward pass; it's a bit like assembling a puzzle.

Forward Pass

Let's walk through the forward function, which defines how our model processes input data.

def forward(self, img):
    device = img.device

    # Convert the input image into patches
    patches = self.to_patch(img)  # Shape: (batch_size, num_patches, pixel_values_per_patch)
    batch_size, num_patches, *_ = patches.shape

    # Embed the patches using the encoder's patch embedding layers
    tokens = self.patch_to_emb(patches)  # Shape: (batch_size, num_patches, encoder_dim)

The beginning is very standard; we just have to separate the "patchify the image" step from the "project patches to tokens" step, because we need the original patches as ground truth to compute the loss.

  • The forward method takes an image tensor img as input.
  • We get the device (CPU or GPU) where the tensor is located.
  • We split the image into patches.
  • We obtain the batch size and number of patches.
  • Each patch is embedded into a vector.

Positional Encoding:

# Add positional embeddings to the tokens
if self.encoder.pool == "cls":
    # If using CLS token, skip the first positional embedding
    tokens += self.encoder.pos_embedding[:, 1 : num_patches + 1]
elif self.encoder.pool == "mean":
    # If using mean pooling, use all positional embeddings
    tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype)

  • We add positional information to each token so the model knows where each patch came from. If the encoder uses a CLS token, its positional embedding sits at index 0 and is not part of the image, so we skip it when indexing the patch positions (see the small illustration below).
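
As a small illustration of that slicing, assume the encoder stores its positional embedding with shape (1, num_patches + 1, dim), reserving index 0 for the CLS token (this layout is an assumption about the underlying ViT):

import torch

# Toy illustration: pos_embedding with a CLS slot at index 0 (assumed layout)
num_patches, dim = 4, 8
pos_embedding = torch.randn(1, num_patches + 1, dim)

# Skip the CLS position and keep only the patch positions
patch_positions = pos_embedding[:, 1 : num_patches + 1]
print(patch_positions.shape)  # torch.Size([1, 4, 8])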

Masking and Encoding

Now we come to the most interesting part, masking the image.

# Determine the number of patches to mask
num_masked = int(self.masking_ratio * num_patches)

# Generate random indices for masking
rand_indices = torch.rand(batch_size, num_patches, device=device).argsort(dim=-1)
masked_indices = rand_indices[:, :num_masked]
unmasked_indices = rand_indices[:, num_masked:]

  • We calculate how many patches to mask based on our masking ratio.
  • We generate an independent random permutation of patch indices for each image in the batch (argsort over random noise).
  • We split it into masked_indices and unmasked_indices accordingly (a toy example follows below).
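
Here is a tiny, self-contained example of the argsort trick on a toy batch (the sizes are illustrative):

import torch

# Toy example: a batch of 2 images, 8 patches each, masking 75% of them
batch_size, num_patches, masking_ratio = 2, 8, 0.75
num_masked = int(masking_ratio * num_patches)  # 6

# argsort over random noise gives an independent random permutation per image
rand_indices = torch.rand(batch_size, num_patches).argsort(dim=-1)
masked_indices = rand_indices[:, :num_masked]    # 6 indices to mask per image
unmasked_indices = rand_indices[:, num_masked:]  # 2 indices kept visible

print(masked_indices.shape, unmasked_indices.shape)
# torch.Size([2, 6]) torch.Size([2, 2])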
# Select the tokens corresponding to unmasked patches
batch_range = torch.arange(batch_size, device=device)[:, None]
tokens = tokens[batch_range, unmasked_indices]

# Select the original patches that are masked (for reconstruction loss)
masked_patches = patches[batch_range, masked_indices]

# Encode the unmasked tokens using the encoder's transformer
encoded_tokens = self.encoder.transformer(tokens)

  • We keep only the tokens of the unmasked patches; these are the only tokens the encoder will see.
  • We also select the original masked_patches with the masked_indices we have just defined; they will serve as the reconstruction targets.
  • We pass the unmasked tokens through the encoder's transformer. (The per-sample indexing with batch_range is illustrated in the toy example below.)
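
If the batch_range indexing looks unfamiliar, here is a small standalone example of the same gather pattern (the sizes are illustrative):

import torch

# Toy example of the per-sample gather used above
batch_size, num_patches, dim = 2, 4, 3
tokens = torch.arange(batch_size * num_patches * dim, dtype=torch.float)
tokens = tokens.reshape(batch_size, num_patches, dim)

# Keep a different pair of patches for each image in the batch
unmasked_indices = torch.tensor([[0, 2],
                                 [1, 3]])
batch_range = torch.arange(batch_size)[:, None]  # shape (2, 1), broadcasts to (2, 2)

kept = tokens[batch_range, unmasked_indices]  # shape (2, 2, 3)
print(kept.shape)  # torch.Size([2, 2, 3])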

Decoding

Now let's jump into the most exciting, but also the hardest, part: decoding!

# Map encoded tokens to decoder dimensions if necessary
decoder_tokens = self.enc_to_dec(encoded_tokens)

# Add positional embeddings to the decoder tokens of unmasked patches
unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)

# Create mask tokens for the masked patches and add positional embeddings
mask_tokens = repeat(self.mask_token, 'd -> b n d', b=batch_size, n=num_masked)
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)

# Initialize the full sequence of decoder tokens
decoder_sequence = torch.zeros(
  batch_size, num_patches, self.decoder_dim, device=device
)
# Place unmasked decoder tokens and mask tokens in their original positions
decoder_sequence[batch_range, unmasked_indices] = unmasked_decoder_tokens
decoder_sequence[batch_range, masked_indices] = mask_tokens

# Decode the full sequence
decoded_tokens = self.decoder(decoder_sequence)

# Extract the decoded tokens corresponding to the masked patches
masked_decoded_tokens = decoded_tokens[batch_range, masked_indices]

  • We adjust the encoded tokens to match the decoder's expected input size with self.enc_to_dec.
  • We add positional embeddings to the decoder tokens.
  • For the masked positions, we use the mask token and add positional embeddings.
  • We reconstruct the full sequence by placing the unmasked and masked tokens back into their original positions.
  • We pass the full sequence through the decoder.
  • We extract only the decoded tokens that correspond to the masked patches.

# Reconstruct the pixel values from the masked decoded tokens
pred_pixel_values = self.to_pixels(masked_decoded_tokens)

# Compute the reconstruction loss (mean squared error)
recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
return recon_loss

  • We reconstruct the original pixel values of the masked patches with a linear projection.
  • We compute the mean squared error (L2 loss) between the reconstructed patches and the original masked patches, and return it. (A minimal pre-training step is sketched below.)
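
To tie everything together, here is a minimal sketch of a pre-training step. It assumes the mae instance from the instantiation sketch earlier; the optimizer settings and dummy batch are illustrative, not the repository's training recipe.

import torch

# Minimal pre-training step sketch (hyperparameters are illustrative)
optimizer = torch.optim.AdamW(mae.parameters(), lr=1.5e-4, weight_decay=0.05)

images = torch.randn(8, 3, 224, 224)  # dummy batch; replace with your dataloader

loss = mae(images)      # the forward pass directly returns the reconstruction loss
optimizer.zero_grad()
loss.backward()
optimizer.step()

# After pre-training, keep mae.encoder and fine-tune it on your downstream task.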

Congratulations, you've made it!

Thanks for reading! Before you go:

For more awesome tutorials, check out my compilation of AI tutorials on GitHub:

GitHub – FrancoisPorcher/awesome-ai-tutorials: The best collection of AI tutorials to make you a…

You should get my articles in your inbox. Subscribe here.

If you want access to premium articles on Medium, a membership costs only $5 a month. If you sign up with my link, you support me with a part of your fee at no additional cost to you.


If you found this article insightful and beneficial, please consider following me and leaving a clap for more in-depth content! Your support helps me continue producing content that aids our collective understanding.

References

  • He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022). Masked Autoencoders Are Scalable Vision Learners. CVPR.
