How to Implement State-of-the-Art Masked AutoEncoders (MAE)
Hi everyone! For those who do not know me yet, my name is Francois, and I am a Research Scientist at Meta. I have a passion for explaining advanced AI concepts and making them more accessible.
Today, I'm excited to delve into one of the most significant breakthroughs in Computer Vision since Vision Transformers: Masked Autoencoders (MAE). This article is the practical implementation companion to my previous post: The Ultimate Guide to Masked Autoencoders (MAE).
For the following tutorial, we will use the code available in this GitHub repository:
GitHub – FrancoisPorcher/awesome-ai-tutorials: The best collection of AI tutorials to make you a…
Here is a brief reminder of how the methodology works:
- The image is split into patches.
- A subset of these patches is randomly masked.
- Only the visible patches are fed into the encoder (this is crucial).
- The decoder receives the compressed representation from the encoder and attempts to reconstruct the entire image using both visible and masked patches.
- The loss is computed only on the masked patches.
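To make the masking concrete, here is a quick back-of-the-envelope sketch. The 224×224 resolution and 16×16 patch size are assumptions (the usual ViT defaults), not values fixed by this article:
image_size = 224          # assumed input resolution (typical ViT default)
patch_size = 16           # assumed patch size (typical ViT default)
masking_ratio = 0.75      # the ratio used throughout this article

num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196 patches
num_masked = int(masking_ratio * num_patches)  # 147 patches are masked
num_visible = num_patches - num_masked         # only 49 patches reach the encoder

print(num_patches, num_masked, num_visible)    # 196 147 49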
Let's dive into the code!
Imports
- einops: used for its "repeat" function.
- architectures.vit: the architecture of a standard ViT; I am using the one provided in How to train a ViT?
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from architectures.vit import Transformer
Setting up the MAE Class:
class MAE(nn.Module):
def __init__(
self,
*,
encoder,
decoder_dim,
masking_ratio=0.75,
decoder_depth=1,
decoder_heads=8,
decoder_dim_head=64
):
super().__init__()
# Ensure the masking ratio is valid
assert 0 < masking_ratio < 1, 'masking ratio must be between 0 and 1'
self.masking_ratio = masking_ratio
We define a class MAE that inherits from PyTorch's nn.Module.
- encoder: our Vision Transformer model.
- decoder_dim: the dimension of the decoder's embedding space (512 for example).
- masking_ratio: the fraction of patches to mask (the paper found that 75% is optimal).
- Other decoder settings like depth, heads, and head dimension, which are standard for Transformers.
- We assert that the masking ratio is between 0 and 1.
Patches:
# Save the encoder (a Vision Transformer to be trained)
self.encoder = encoder
# Extract the number of patches and the encoder's dimensionality from the positional embeddings
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
# Separate the patch embedding layers from the encoder
# The first layer converts the image into patches
self.to_patch = encoder.to_patch_embedding[0]
# The remaining layers embed the patches
self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])
What's Going On?
- We store the encoder and extract essential information like the number of patches and the encoder's output dimension.
- We separate the patch embedding process into two steps:
- self.to_patch: this layer splits the image into smaller patches.
- self.patch_to_emb: this embeds each patch into a vector space.
# Determine the dimensionality of the pixel values per patch
pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]
- We calculate how many pixel values are in each patch, which we'll need later.
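To see why this indexing works, here is a minimal sketch of a lucidrains-style patch embedding. This is an assumption about the ViT in architectures.vit, based on the vit-pytorch repository cited in the references; adapt the index if your ViT builds its embedding differently:
from torch import nn
from einops.layers.torch import Rearrange

patch_size, channels, encoder_dim = 16, 3, 1024  # assumed values for illustration
patch_dim = channels * patch_size * patch_size   # 768 pixel values per patch

# Assumed structure of encoder.to_patch_embedding (vit-pytorch style)
to_patch_embedding = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),  # [0] patchify
    nn.LayerNorm(patch_dim),            # [1]
    nn.Linear(patch_dim, encoder_dim),  # [2] weight shape: (encoder_dim, patch_dim)
    nn.LayerNorm(encoder_dim),          # [3]
)

print(to_patch_embedding[2].weight.shape[-1])  # 768, i.e. the pixel values per patch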
Setting up the decoder
- self.enc_to_dec: if the encoder and decoder have different dimensions, we map between them. Usually the encoder is large, with a higher dimension (for example 1024), while the decoder can be shallower, with a smaller dimension (for example 512), so we need an adapter that maps the encoder's dimension to the decoder's.
- self.mask_token: a learnable token that represents the masked patches for the decoder. This is the token the decoder sees wherever a patch is masked.
- We initialize the decoder transformer and the other layers needed for reconstruction.
self.decoder = Transformer(
dim=decoder_dim,
depth=decoder_depth,
heads=decoder_heads,
dim_head=decoder_dim_head,
mlp_dim_ratio=4
)
# Positional embeddings for the decoder tokens
self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
# Linear layer to reconstruct pixel values from decoder outputs
self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)
So far, your MAE class initialization should look like this:
class MAE(nn.Module):
def __init__(
self,
*,
encoder,
decoder_dim,
masking_ratio=0.75,
decoder_depth=1,
decoder_heads=8,
decoder_dim_head=64
):
super().__init__()
# Ensure the masking ratio is valid
assert 0 < masking_ratio < 1, 'masking ratio must be between 0 and 1'
self.masking_ratio = masking_ratio
# Save the encoder (a Vision Transformer to be trained)
self.encoder = encoder
# Extract the number of patches and the encoder's dimensionality from the positional embeddings
num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]
# Separate the patch embedding layers from the encoder
# The first layer converts the image into patches
self.to_patch = encoder.to_patch_embedding[0]
# The remaining layers embed the patches
self.patch_to_emb = nn.Sequential(*encoder.to_patch_embedding[1:])
# Determine the dimensionality of the pixel values per patch
pixel_values_per_patch = encoder.to_patch_embedding[2].weight.shape[-1]
# Set up decoder parameters
self.decoder_dim = decoder_dim
# Map encoder dimensions to decoder dimensions if they differ
self.enc_to_dec = (
nn.Linear(encoder_dim, decoder_dim)
if encoder_dim != decoder_dim
else nn.Identity()
)
# Learnable mask token for masked patches
self.mask_token = nn.Parameter(torch.randn(decoder_dim))
# Define the decoder transformer
self.decoder = Transformer(
dim=decoder_dim,
depth=decoder_depth,
heads=decoder_heads,
dim_head=decoder_dim_head,
mlp_dim_ratio=4
)
# Positional embeddings for the decoder tokens
self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
# Linear layer to reconstruct pixel values from decoder outputs
self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)
Great! Now let's see how we use these different parts in the forward pass; it's a bit like assembling a puzzle.
Forward Pass
Let's walk through the forward function, which defines how our model processes input data.
def forward(self, img):
device = img.device
# Convert the input image into patches
patches = self.to_patch(img) # Shape: (batch_size, num_patches, patch_size)
batch_size, num_patches, *_ = patches.shape
# Embed the patches using the encoder's patch embedding layers
tokens = self.patch_to_emb(patches) # Shape: (batch_size, num_patches, encoder_dim)
The beginning is very standard; we just have to separate the "patchify the image" operation from the "project to tokens" operation, because we use the original patches as the ground truth when computing the loss.
- The forward method takes an image tensor img as input.
- We get the device (CPU or GPU) where the tensor is located.
- We split the image into patches.
- We obtain the batch size and the number of patches.
- Each patch is embedded into a vector.
Positional Encoding:
# Add positional embeddings to the tokens
if self.encoder.pool == "cls":
# If using CLS token, skip the first positional embedding
tokens += self.encoder.pos_embedding[:, 1 : num_patches + 1]
elif self.encoder.pool == "mean":
# If using mean pooling, use all positional embeddings
tokens += self.encoder.pos_embedding.to(device, dtype=tokens.dtype)
- We add positional information to each token so the model knows where each patch came from. If there is an additional CLS token, we skip its positional embedding because it is not part of the image.
Masking and Encoding
Now we come to the most interesting part, masking the image.
# Determine the number of patches to mask
num_masked = int(self.masking_ratio * num_patches)
# Generate random indices for masking
rand_indices = torch.rand(batch_size, num_patches, device=device).argsort(dim=-1)
masked_indices = rand_indices[:, :num_masked]
unmasked_indices = rand_indices[:, num_masked:]
- We calculate how many patches we'll mask based on our masking ratio.
- We generate a random permutation of the patch indices for each image in the batch.
- We split it into masked_indices and unmasked_indices accordingly (see the toy example below).
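If the argsort trick is unfamiliar, here is a tiny standalone example: sorting random noise along the last dimension gives an independent random permutation of the patch indices for every image in the batch.
import torch

batch_size, num_patches, masking_ratio = 2, 6, 0.5  # toy values for illustration
num_masked = int(masking_ratio * num_patches)

# Each row of rand_indices is a random permutation of 0..num_patches-1
rand_indices = torch.rand(batch_size, num_patches).argsort(dim=-1)
masked_indices = rand_indices[:, :num_masked]    # first part: masked
unmasked_indices = rand_indices[:, num_masked:]  # rest: kept visible

print(rand_indices)  # two independent permutations of the indices 0..5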
# Select the tokens corresponding to unmasked patches
batch_range = torch.arange(batch_size, device=device)[:, None]
tokens = tokens[batch_range, unmasked_indices]
# Select the original patches that are masked (for reconstruction loss)
masked_patches = patches[batch_range, masked_indices]
# Encode the unmasked tokens using the encoder's transformer
encoded_tokens = self.encoder.transformer(tokens)
- We select the masked_patches using the masked_indices we have just defined (the indexing trick is illustrated below).
- We keep only the tokens of the unmasked patches and pass them through the encoder's transformer.
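The tokens[batch_range, unmasked_indices] line relies on PyTorch advanced indexing: the (batch_size, 1) column of batch indices broadcasts against the (batch_size, num_unmasked) index tensor, so each row gathers patches from its own sample. A toy example with made-up shapes:
import torch

batch_size, num_patches, dim = 2, 4, 3  # toy values for illustration
tokens = torch.arange(batch_size * num_patches * dim, dtype=torch.float)
tokens = tokens.view(batch_size, num_patches, dim)

batch_range = torch.arange(batch_size)[:, None]  # shape (2, 1), broadcasts per row
unmasked_indices = torch.tensor([[0, 2],
                                 [1, 3]])        # keep different patches per sample

selected = tokens[batch_range, unmasked_indices]
print(selected.shape)  # torch.Size([2, 2, 3]) -> (batch_size, num_unmasked, dim)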
Decoding
Now let's jump into the most exciting but also hardest part, decoding!
# Map encoded tokens to decoder dimensions if necessary
decoder_tokens = self.enc_to_dec(encoded_tokens)
# Add positional embeddings to the decoder tokens of unmasked patches
unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)
# Create mask tokens for the masked patches and add positional embeddings
mask_tokens = repeat(self.mask_token, 'd -> b n d', b=batch_size, n=num_masked)
mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)
# Initialize the full sequence of decoder tokens
decoder_sequence = torch.zeros(
batch_size, num_patches, self.decoder_dim, device=device
)
# Place unmasked decoder tokens and mask tokens in their original positions
decoder_sequence[batch_range, unmasked_indices] = unmasked_decoder_tokens
decoder_sequence[batch_range, masked_indices] = mask_tokens
# Decode the full sequence
decoded_tokens = self.decoder(decoder_sequence)
# Extract the decoded tokens corresponding to the masked patches
masked_decoded_tokens = decoded_tokens[batch_range, masked_indices]
- We adjust the encoded tokens to match the decoder's expected input size with self.enc_to_dec.
- We add positional embeddings to the decoder tokens.
- For the masked positions, we use the learnable mask token and add positional embeddings as well.
- We reconstruct the full sequence by placing the unmasked and masked tokens back into their original positions.
- We pass the full sequence through the decoder.
- We extract only the decoded tokens that correspond to the masked patches.
# Reconstruct the pixel values from the masked decoded tokens
pred_pixel_values = self.to_pixels(masked_decoded_tokens)
# Compute the reconstruction loss (mean squared error)
recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
return recon_loss
- We attempt to reconstruct the original pixel values of the masked patches.
- We calculate the L2 loss by comparing the reconstructed patches to the original masked patches.
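To put everything together, here is a minimal sketch of a single training step. The ViT constructor below is hypothetical (it follows the vit-pytorch style from the references); check the actual signature of the ViT in the companion repository before running it.
import torch
from architectures.vit import ViT  # assumed import path; adapt to the companion repo

# Hypothetical ViT configuration (vit-pytorch style argument names)
vit = ViT(
    image_size=224,
    patch_size=16,
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048,
)

mae = MAE(encoder=vit, decoder_dim=512, masking_ratio=0.75, decoder_depth=6)
optimizer = torch.optim.AdamW(mae.parameters(), lr=1.5e-4)

images = torch.randn(8, 3, 224, 224)  # dummy batch of images
loss = mae(images)                    # reconstruction loss on the masked patches only
loss.backward()
optimizer.step()
optimizer.zero_grad()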
Congratulations, you've made it!
Thanks for reading! Before you go:
For more awesome tutorials, check out my compilation of AI tutorials on GitHub:
GitHub – FrancoisPorcher/awesome-ai-tutorials: The best collection of AI tutorials to make you a…
You should get my articles in your inbox. Subscribe here.
If you want access to premium articles on Medium, you only need a membership for $5 a month. If you sign up with my link, you support me with a part of your fee at no additional cost.
If you found this article insightful and beneficial, please consider following me and leaving a clap for more in-depth content! Your support helps me continue producing content that aids our collective understanding.
References
- Kaiming He, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, and Ross B. Girshick. Masked Autoencoders Are Scalable Vision Learners. arXiv:2111.06377, 2021. https://arxiv.org/abs/2111.06377
- lucidrains, vit-pytorch. https://github.com/lucidrains/vit-pytorch