Training LLM, from Scratch, in Rust

Author: Murphy  |  Time: 2025-03-22 19:11:07

In this companion article, I'll show my implementation for training a GPT-like model from scratch in Rust. No GPUs, only CPUs, and a performance 30 times better than the native C code.

Image by GoogleDeepMind on Unsplash

In my last article, I introduced the problem of matrix multiplication, showed how the attention algorithm uses matrix multiplication to perform an averaging process, and explained how to implement a matrix multiplication function in Rust with BLAS as efficiently as I could.

In this new article, I want to show my first building block for implementing llm.c in Rust, namely training a GPT-like model from scratch using Rust. This has been my way of learning more about the Rust ecosystem and understanding how it compares with C. In particular, I want my code to be able to train a GPT-like model, starting from GPT weights, using only CPUs, so no GPUs or TPUs. My aim is to understand how far we can push these models on simple laptops, and how much the Rust ecosystem can be used for this. Eventually, this code may also be useful to fine-tune GPT models on a given input corpus.

All the relevant pieces of code can be found here.

The take-home messages I'd like everyone to have are:

  • How to manage the parameter memory and the GPT parameters in Rust
  • How the attention forward pass is implemented
  • How to leverage parallelism with Rayon, and what a thread lock (Mutex) is

1. Handle all the parameters and the memory layout

GPT-like models have a large number of parameters and tensors: embedding matrices, layer-norm parameters, query, key and value matrices, attention outputs, feed-forward layer outputs, and so on. In PyTorch all of this is handled automatically by the framework, so there is no need to worry about independent tensor objects or about how these tensors fit in memory. In our Rust implementation, on the contrary, we have to manage all these parameters ourselves.

All the parameters are stored in a single Vec<f32>, contiguous in memory, which gives us:

  • a simple memory layout: better cache locality, which speeds up matrix multiplications as well as parameter loading and saving;
  • a straightforward interface to BLAS's sgemm, which works on contiguous buffers, for faster matrix multiplication;
  • easy reads and writes from disk.

On the other hand, we must pay a price for this, mostly in the form of array slicing:

  • we must know the size of each array;
  • we need to handle the slicing carefully, picking exactly the right piece of the buffer for each tensor, to avoid index-out-of-bounds errors;
  • we must keep an offset for each parameter. For example, wte takes up vocab_size*channels floats, so it occupies params_memory[0..vocab_size*channels]. The next tensor, wpe, then takes params_memory[vocab_size*channels..(vocab_size*channels + max_seq_len*channels)]:
let wte = &self.params_memory[0..vocab_size*channels];
let wpe = &self.params_memory[vocab_size*channels..(vocab_size*channels + max_seq_len*channels)];
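To make the offset bookkeeping concrete, here is a minimal sketch that carves a flat buffer into consecutive tensors from a list of sizes, instead of writing each offset by hand. The helper carve_params and the toy sizes are hypothetical and not part of the author's code. It relies on split_at, which panics if a size overruns the remaining buffer.

```rust
/// Split a flat parameter buffer into consecutive, non-overlapping slices.
/// `sizes[i]` is the number of floats in tensor i (e.g. wte, then wpe, ...).
/// Hypothetical helper, for illustration only.
fn carve_params<'a>(mut memory: &'a [f32], sizes: &[usize]) -> Vec<&'a [f32]> {
    let mut tensors = Vec::with_capacity(sizes.len());
    for &size in sizes {
        // split_at panics if `size` exceeds what is left, so a wrong size
        // is caught at runtime instead of reading past the buffer.
        let (head, tail) = memory.split_at(size);
        tensors.push(head);
        memory = tail;
    }
    tensors
}

fn main() {
    // Toy sizes: vocab_size = 5, max_seq_len = 3, channels = 2.
    let (vocab_size, max_seq_len, channels) = (5, 3, 2);
    let sizes = [vocab_size * channels, max_seq_len * channels];
    let params_memory: Vec<f32> = (0..16).map(|x| x as f32).collect();
    let tensors = carve_params(&params_memory, &sizes);
    let (wte, wpe) = (tensors[0], tensors[1]);
    println!("wte = {:?}", wte);
    println!("wpe = {:?}", wpe);
}
```

Each tensor's offset is simply the running sum of the sizes before it, so adding a new tensor only means appending its size to the list.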

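Overall, the main risk is slicing the params_memory array incorrectly. As long as we know the sizes, we cannot run into an invalid memory access (Rust slices are bounds-checked, so an out-of-range offset panics rather than reading stray memory), and params_memory remains the single source of truth for all the parameters.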

The core function for parameter and memory allocation is gpt_build_from_checkpoint. There, we read the input file with file.read_f32::<LittleEndian>(). LittleEndian comes from the byteorder crate, which provides methods for reading and writing numbers, in either little-endian or big-endian byte order, directly to and from bytes. The parameters are then created as:

model.params.wte = model.params_memory[offset..offset + model.param_sizes[0]].to_vec();
offset += model.param_sizes[0];
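The byteorder call decodes one little-endian f32 at a time. As a dependency-free alternative, here is a minimal sketch that reads n little-endian floats from any std::io::Read source with f32::from_le_bytes. The function name read_f32_le is hypothetical; the author's code uses byteorder instead.

```rust
use std::io::{Cursor, Read};

/// Read `n` little-endian f32 values from any reader, using only std.
/// Hypothetical helper: same result as calling read_f32::<LittleEndian>() n times.
fn read_f32_le<R: Read>(reader: &mut R, n: usize) -> std::io::Result<Vec<f32>> {
    let mut bytes = vec![0u8; n * 4];
    reader.read_exact(&mut bytes)?; // errors out if the file is too short
    Ok(bytes
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect())
}

fn main() -> std::io::Result<()> {
    // Serialize two floats as little-endian bytes, as a checkpoint file would store them.
    let mut data = Vec::new();
    for x in [1.5f32, -2.0] {
        data.extend_from_slice(&x.to_le_bytes());
    }
    let mut cursor = Cursor::new(data);
    let params = read_f32_le(&mut cursor, 2)?;
    println!("{:?}", params); // [1.5, -2.0]
    Ok(())
}
```

Reading the whole block with one read_exact call and decoding afterwards also avoids one small read per float.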

2a. Encoder forward, encoder backward

The second step in building a GPT-like model is to create the word and position embeddings, namely high-dimensional vector representations of the tokens. This is done by encoder_forward, which returns an activation tensor of shape [B, T, C]. It is important to remember what the B, T and C dimensions mean. The input data is processed in batches of B sequences, and each sequence is a block of T tokens. For a text corpus, you can think of each sequence in the batch as a piece of text, such as a sentence, cut into blocks of T tokens. The channel size, C, is the "weirdest" one. Our thoughts go immediately to the image-processing world, where we have R, G and B channels. In our text corpus, the channel dimension is the dimensionality of the embeddings we are creating; in our case C = 768. This parameter is read directly from the checkpoint header:

let (max_t, v, l, nh, c) = (
    model_header[2] as usize,
    model_header[3] as usize,
    model_header[4] as usize,
    model_header[5] as usize,
    model_header[6] as usize,
);

In particular, encoder_forward processes an input vector inp that contains the IDs of our tokens, with shape [B x T]. The word embedding matrix wte has shape [V x C], where V is the vocabulary size, that is, how many unique tokens the model can represent: roughly 50,000 in our case (50,257 for the GPT-2 tokenizer). The positional embedding matrix wpe, on the other hand, has shape [max_t x C], where max_t is the maximum sequence length, 1024.

To index into these flat buffers we again use Rust slicing, for example, in encoder_forward:

let out_start_idx = b_idx * t * c + t_idx * c;            // start of out[b_idx, t_idx, :]
let out_bt = &mut out[out_start_idx..out_start_idx + c];  // C-wide output row
let ix = inp[b_idx * t + t_idx] as usize;                 // token ID at (b_idx, t_idx)
let wte_start_idx = ix * c;                               // start of row ix of wte
let wte_ix = &wte[wte_start_idx..wte_start_idx + c];      // token embedding
let wpe_start_idx = t_idx * c;                            // start of row t_idx of wpe
let wpe_t = &wpe[wpe_start_idx..wpe_start_idx + c];       // position embedding
for i in 0..c {
    out_bt[i] = wte_ix[i] + wpe_t[i];                     // sum of the two embeddings
}
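Wrapping that loop body in the loops over batch elements and positions gives a complete, if naive, encoder forward pass. This is a minimal sketch using the same flat, row-major layout; the u32 token type and the toy sizes are assumptions, and the author's signature may differ.

```rust
/// Sketch of encoder_forward: out[b, t, :] = wte[inp[b, t], :] + wpe[t, :].
/// Flat, row-major shapes: inp [B*T], wte [V*C], wpe [max_t*C], out [B*T*C].
fn encoder_forward(out: &mut [f32], inp: &[u32], wte: &[f32], wpe: &[f32], b: usize, t: usize, c: usize) {
    for b_idx in 0..b {
        for t_idx in 0..t {
            // C-wide output row for (b_idx, t_idx)
            let out_start_idx = b_idx * t * c + t_idx * c;
            let out_bt = &mut out[out_start_idx..out_start_idx + c];
            // the token ID selects a row of wte, the position selects a row of wpe
            let ix = inp[b_idx * t + t_idx] as usize;
            let wte_ix = &wte[ix * c..ix * c + c];
            let wpe_t = &wpe[t_idx * c..t_idx * c + c];
            for ((o, w), p) in out_bt.iter_mut().zip(wte_ix).zip(wpe_t) {
                *o = w + p;
            }
        }
    }
}

fn main() {
    let (b, t, c, v) = (1, 2, 3, 4);
    let wte: Vec<f32> = (0..v * c).map(|x| x as f32).collect(); // row k = [3k, 3k+1, 3k+2]
    let wpe = vec![0.5f32; t * c];
    let inp = [2u32, 0];
    let mut out = vec![0.0f32; b * t * c];
    encoder_forward(&mut out, &inp, &wte, &wpe, b, t, c);
    println!("{:?}", out); // [6.5, 7.5, 8.5, 0.5, 1.5, 2.5]
}
```

The zip chain replaces explicit indexing, so the compiler can drop the per-element bounds checks inside the innermost loop.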

To appreciate the Rust slicing, consider B=2, T=3, C=4. This means that the output has a length B x T x C = 24 so that:

out: [ out[b=0,t=0,:], out[b=0,t=1,:], out[b=0,t=2,:],
       out[b=1,t=0,:], out[b=1,t=1,:], out[b=1,t=2,:] ]

where each out[b, t, :] holds 4 elements. Thus, for b=1, t=2 (the second sequence in the batch and its third token), the slice starts at out_start_idx = b_idx x t x c + t_idx x c = 1 x 3 x 4 + 2 x 4 = 20, and covers out[20..24].
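The section title also mentions the encoder backward pass. Since the forward pass is a lookup plus a sum, its gradient simply routes each dout[b, t, :] back to the wte row of that token and to the wpe row of that position. The following is a minimal sketch under the same layout, written for illustration rather than copied from the author's code. Gradients are accumulated with +=, because the same token ID (and the same position) appears many times.

```rust
/// Sketch of encoder_backward: scatter-add dout[b, t, :] into
/// dwte[inp[b, t], :] and dwpe[t, :]. Accumulation (+=) is required
/// because a token ID can occur at several positions and batch elements.
fn encoder_backward(dwte: &mut [f32], dwpe: &mut [f32], dout: &[f32], inp: &[u32], b: usize, t: usize, c: usize) {
    for b_idx in 0..b {
        for t_idx in 0..t {
            let dout_start = b_idx * t * c + t_idx * c;
            let dout_bt = &dout[dout_start..dout_start + c];
            let ix = inp[b_idx * t + t_idx] as usize;
            for (g, d) in dwte[ix * c..ix * c + c].iter_mut().zip(dout_bt) {
                *g += d; // gradient w.r.t. the token embedding
            }
            for (g, d) in dwpe[t_idx * c..t_idx * c + c].iter_mut().zip(dout_bt) {
                *g += d; // gradient w.r.t. the position embedding
            }
        }
    }
}

fn main() {
    // B=2, T=1, C=2: the same token (ID 1) appears in both batch elements.
    let (b, t, c) = (2, 1, 2);
    let inp = [1u32, 1];
    let dout = [1.0f32, 2.0, 3.0, 4.0];
    let mut dwte = vec![0.0f32; 3 * c]; // vocabulary of 3 tokens
    let mut dwpe = vec![0.0f32; t * c];
    encoder_backward(&mut dwte, &mut dwpe, &dout, &inp, b, t, c);
    println!("dwte = {:?}, dwpe = {:?}", dwte, dwpe);
    // dwte = [0.0, 0.0, 4.0, 6.0, 0.0, 0.0], dwpe = [4.0, 6.0]
}
```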

2b. LayerNorm and attention

GPT-like models include normalization steps in their architecture to stabilize training and improve performance. LayerNorm normalises each token's activation vector along the channel dimension C to zero mean and unit variance, and then applies a learned scale (weight) and shift (bias).

To implement layernorm_forward we use a constant eps = 1e-5f32, which prevents a division by zero when computing 1 / sqrt(var + eps).
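Putting these pieces together, here is a minimal sketch of a LayerNorm forward pass over B*T rows of C channels. The argument list (including the per-row mean and rstd buffers, cached for a backward pass) is an assumption for illustration, not necessarily the author's exact signature.

```rust
/// Sketch of layernorm_forward over N = B*T rows of C channels each.
/// out = (x - mean) * rstd * weight + bias, with rstd = 1 / sqrt(var + eps).
/// mean and rstd are stored per row so that a backward pass could reuse them.
fn layernorm_forward(
    out: &mut [f32],
    mean: &mut [f32],
    rstd: &mut [f32],
    inp: &[f32],
    weight: &[f32],
    bias: &[f32],
    c: usize,
) {
    let eps = 1e-5f32;
    let rows = inp.chunks_exact(c).zip(out.chunks_exact_mut(c));
    for (row_idx, (x, y)) in rows.enumerate() {
        let m = x.iter().sum::<f32>() / c as f32;
        // biased variance (divide by C), as in the standard LayerNorm definition
        let var = x.iter().map(|v| (v - m) * (v - m)).sum::<f32>() / c as f32;
        let s = 1.0 / (var + eps).sqrt();
        for (((yi, xi), w), b) in y.iter_mut().zip(x).zip(weight).zip(bias) {
            *yi = (xi - m) * s * w + b;
        }
        mean[row_idx] = m;
        rstd[row_idx] = s;
    }
}

fn main() {
    let c = 4;
    let inp = [1.0f32, 2.0, 3.0, 4.0]; // a single row (B*T = 1)
    let (weight, bias) = (vec![1.0f32; c], vec![0.0f32; c]);
    let (mut out, mut mean, mut rstd) = (vec![0.0f32; c], vec![0.0f32; 1], vec![0.0f32; 1]);
    layernorm_forward(&mut out, &mut mean, &mut rstd, &inp, &weight, &bias, c);
    println!("out = {:?}, mean = {:?}, rstd = {:?}", out, mean, rstd);
}
```

With weight = 1 and bias = 0, the output row has (up to the small eps term) zero mean and unit variance.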

Once the normalisation is done, we can move to the attention layer. The attention in our code is multi-head self-attention:

  • The channel dimension C is split evenly across the nh heads, so each head works on hs = C / nh channels
  • We compute the query Q, the key K and the value V for each token
  • We compute the attention scores
  • Finally, we compute a weighted sum of the values based on the attention scores

It's worth remembering what attention does. Given an input of shape [B, T, C], we want the model to learn how the (up to) T tokens of each sequence are "interconnected" with each other. Attention here is causal: the fifth token may only look at itself and at the tokens before it, namely the first, second, third and fourth tokens. In this way, information always flows from earlier tokens to later ones, never backwards.

To measure how interconnected the tokens are, each token t takes a weighted average over the tokens up to and including position t. To compute this average efficiently we use a matrix multiplication trick, based on three vectors. The first two, Q and K, are the query and the key vectors. The query vector answers the question "What am I looking for?", while the key vector answers "What do I contain?". The dot product between Q and K returns how well these two vectors are aligned, that is, the alignment between token content (what do I contain?) and token associations (what am I looking for?).
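The "matrix multiplication trick" can be seen in its simplest form with uniform weights: multiplying the [T x C] token matrix by a lower-triangular [T x T] matrix whose row t holds 1/(t+1) in columns 0..=t gives, in row t, the average of all tokens up to t. Attention replaces these uniform weights with data-dependent ones computed from Q and K. This is a minimal illustrative sketch with plain loops instead of BLAS; causal_average is a hypothetical name.

```rust
/// Causal running average via a matrix product: out = W * x, where
/// W is T x T lower-triangular with W[t][t2] = 1 / (t + 1) for t2 <= t.
/// x is T x C (row-major), so row t of out is the mean of rows 0..=t of x.
fn causal_average(x: &[f32], t: usize, c: usize) -> Vec<f32> {
    // build the lower-triangular weight matrix (zeros above the diagonal)
    let mut w = vec![0.0f32; t * t];
    for row in 0..t {
        for col in 0..=row {
            w[row * t + col] = 1.0 / (row + 1) as f32;
        }
    }
    // plain triple-loop matmul: (T x T) * (T x C) = (T x C)
    let mut out = vec![0.0f32; t * c];
    for row in 0..t {
        for k in 0..t {
            let wk = w[row * t + k];
            for j in 0..c {
                out[row * c + j] += wk * x[k * c + j];
            }
        }
    }
    out
}

fn main() {
    // T = 3 tokens, C = 1 channel: values 3, 6, 9
    let x = [3.0f32, 6.0, 9.0];
    // running means are approximately 3, 4.5 and 6
    println!("{:?}", causal_average(&x, 3, 1));
}
```

The zeros above the diagonal are what enforce causality: token t never receives any contribution from tokens after it.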

We split the computation across several self-attention heads, each working on its own slice of the channels. For each head, the Q and K tensors have shape [B, T, head_size], and their product gives a weight tensor of shape [B, T, T]: for each of the B sequences in the batch, a square T x T matrix. The entry at row t and column t2 is a "statistical" weight that says how much token t should attend to token t2.

The final step is to apply the weights W, after a softmax, to the value vectors V. The value vector comes from a simple linear layer applied to the input tokens. Each head's output has shape [B, T, head_size], and our challenge is to concatenate the outputs of all the heads along the channel dimension, to get back a tensor of shape [B, T, C].
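The softmax turns each row of scores into probabilities that sum to 1. Here is a minimal sketch of the common numerically stable variant, which subtracts the row maximum before exponentiating. The function name is hypothetical, and whether the author's code subtracts the maximum is not shown in this article.

```rust
/// Numerically stable softmax, in place: subtracting the row maximum before
/// exp() avoids overflow and does not change the result mathematically.
fn softmax_inplace(row: &mut [f32]) {
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0.0f32;
    for v in row.iter_mut() {
        *v = (*v - max).exp();
        sum += *v;
    }
    for v in row.iter_mut() {
        *v /= sum;
    }
}

fn main() {
    // exp(1000.0) overflows f32; after subtracting the max, the largest term is exp(0) = 1
    let mut scores = [1000.0f32, 1001.0, 1002.0];
    softmax_inplace(&mut scores);
    println!("{:?}", scores);
}
```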

Let's go to the practical side. The input terms for our attention forward function are:

  • out, the output buffer, a tensor of shape [B, T, C]
  • preatt, a tensor of shape [B, nh, T, T] storing the pre-softmax attention scores
  • att, a tensor of the same shape storing the attention probabilities after the softmax
  • inp, the input tensor of shape [B, T, 3C], holding the query, key and value vectors of each token back to back
  • b, t, c, nh: the batch size, the sequence length, the total number of channels (the embedding dimension) and the number of attention heads, respectively

First, we prepare all the constants. We set c3 = c x 3 because each token's row in inp holds the concatenated Q, K and V vectors, so 3C is the stride between consecutive tokens.

The main loop runs over all the heads, over all the tokens and over all the batch elements. The offsets are again computed as:

let query_start = b_idx * t * c3 + t_idx * c3 + h * hs;
let preatt_start = b_idx * nh * t * t + h * t * t + t_idx * t;
let att_start = b_idx * nh * t * t + h * t * t + t_idx * t;

In this way it is possible to extract the query vector:

let query_vec = &inp[query_start..query_start + hs];

This gives the hs-dimensional query vector for the current head and token. Remember, the query represents "what this token is looking for" among the previous tokens.

Then, we construct the keys matrix:

let mut keys_mat = Vec::with_capacity((t_idx + 1) * hs);
for t2 in 0..=t_idx {
    let key_start = b_idx * t * c3 + t2 * c3 + h * hs + c; // +c to skip Q and access K
    keys_mat.extend_from_slice(&inp[key_start..key_start + hs]);
}

Here, Vec::with_capacity constructs a new, empty Vec with at least the specified capacity, so it can hold that many elements without reallocating; if the capacity is 0, the vector does not allocate. We gather the keys of all positions up to and including the current timestep t_idx, that is t_idx + 1 keys, each hs-dimensional (as the key_start offset shows), into a row-major (t_idx + 1) x hs matrix.

Next, we compute the pre-attention scores with BLAS:

let mut preatt_row = vec![0.0f32; t_idx + 1];

unsafe {
    sgemm(
        Layout::RowMajor,
        Transpose::None,
        Transpose::None,
        (t_idx + 1) as i32,
        1,
        hs as i32,
        1.0,
        &keys_mat,
        hs as i32,
        query_vec,
        1,
        0.0,
        &mut preatt_row,
        1,
    );
}

The matrix multiplication here is [(t_idx + 1) x hs] * [hs x 1] = [(t_idx + 1) x 1], which gives the alignment score between the current token's query and each previous token's key (see here for a more thorough explanation). These scores are logits: they are normalised with a softmax and stored in the att array.
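With these arguments, the sgemm call computes C = alpha * A * B + beta * C, where A = keys_mat is a row-major (t_idx + 1) x hs matrix (lda = hs), B = query_vec is hs x 1, and beta = 0. In other words, it is a matrix-vector product: one dot product per key. Here is a plain-Rust sketch of the same computation, useful for checking the BLAS result. The helper name is hypothetical. The usual scaled dot-product attention multiplies the scores by 1/sqrt(hs); the call above passes alpha = 1.0, so the sketch keeps alpha as a parameter.

```rust
/// Plain-Rust equivalent of the row-major sgemm call with M = t_idx + 1, N = 1,
/// K = hs and beta = 0: preatt_row[t2] = alpha * dot(keys_mat[t2, :], query_vec).
fn preatt_scores(keys_mat: &[f32], query_vec: &[f32], hs: usize, alpha: f32) -> Vec<f32> {
    keys_mat
        .chunks_exact(hs) // one hs-wide row per key (lda = hs)
        .map(|key| alpha * key.iter().zip(query_vec).map(|(k, q)| k * q).sum::<f32>())
        .collect()
}

fn main() {
    let hs = 2;
    // two keys (t_idx + 1 = 2), each hs-dimensional
    let keys_mat = [1.0f32, 0.0, 0.0, 1.0];
    let query_vec = [3.0f32, 4.0];
    let scale = 1.0 / (hs as f32).sqrt(); // the usual 1/sqrt(hs) attention scaling
    println!("{:?}", preatt_scores(&keys_mat, &query_vec, hs, 1.0)); // [3.0, 4.0]
    println!("{:?}", preatt_scores(&keys_mat, &query_vec, hs, scale));
}
```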

Finally, we multiply the attention probabilities by the value vectors. This gives, for each token, a weighted sum of the values of all the tokens seen so far, each weighted by its attention probability.
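Putting the steps together, here is a minimal sketch of causal self-attention for one batch element b_idx and one head h, with the same [B, T, 3C] layout for inp (Q at offset 0, K at +c, V at +2c in each token's row). It uses plain loops instead of BLAS and does not store preatt/att. The 1/sqrt(hs) scaling follows standard GPT-2 attention and is an assumption about where the scaling happens. Each head writes only columns h*hs..(h+1)*hs of the output row, which is how the heads are concatenated back to C channels.

```rust
/// Causal self-attention for one batch element `b_idx` and one head `h` (a sketch).
/// inp: [B, T, 3C], each token row holding Q, K, V back to back; out: [B, T, C].
/// Running it for every h in 0..nh fills (concatenates) all C output channels.
fn attention_head(out: &mut [f32], inp: &[f32], b_idx: usize, h: usize, t: usize, c: usize, nh: usize) {
    let (c3, hs) = (3 * c, c / nh);
    let scale = 1.0 / (hs as f32).sqrt(); // scaled dot-product attention
    for t_idx in 0..t {
        let q_start = b_idx * t * c3 + t_idx * c3 + h * hs;
        let query = &inp[q_start..q_start + hs];
        // causal mask: score only the keys at positions 0..=t_idx
        let mut att: Vec<f32> = (0..=t_idx)
            .map(|t2| {
                let k_start = b_idx * t * c3 + t2 * c3 + h * hs + c; // +c skips Q
                let key = &inp[k_start..k_start + hs];
                scale * query.iter().zip(key).map(|(q, k)| q * k).sum::<f32>()
            })
            .collect();
        // numerically stable softmax over the t_idx + 1 scores
        let max = att.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0f32;
        for a in att.iter_mut() {
            *a = (*a - max).exp();
            sum += *a;
        }
        // weighted sum of the values into this head's slice of out[b_idx, t_idx, :]
        let o_start = b_idx * t * c + t_idx * c + h * hs;
        let out_h = &mut out[o_start..o_start + hs];
        out_h.fill(0.0);
        for (t2, a) in att.iter().enumerate() {
            let v_start = b_idx * t * c3 + t2 * c3 + h * hs + 2 * c; // +2c skips Q and K
            for (o, v) in out_h.iter_mut().zip(&inp[v_start..v_start + hs]) {
                *o += (a / sum) * v;
            }
        }
    }
}

fn main() {
    // B=1, T=2, C=2, nh=1: each token row is [q0, q1, k0, k1, v0, v1]
    let inp = [
        1.0f32, 0.0, 1.0, 0.0, 10.0, 20.0, // token 0
        0.0, 1.0, 0.0, 1.0, 30.0, 40.0, // token 1
    ];
    let mut out = vec![0.0f32; 4];
    attention_head(&mut out, &inp, 0, 0, 2, 2, 1);
    // token 0 can only attend to itself, so out[0..2] == [10.0, 20.0]
    println!("{:?}", out);
}
```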

3. Rayon and trying to leverage some parallelism

Before jumping into running the code, I'd like to say a few words about Rayon and how we can leverage data parallelism with it.

Rayon is a data-parallelism library that makes it easy to run code on multiple threads. As we saw in my previous post about matmul, we can use the parallel iterators par_iter(), par_chunks() and par_chunks_mut(). These iterators partition the workload across the available threads, without us having to do the low-level work ourselves. This makes the code both simple and safe.

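In the code you will see lines like the following, a parallel loop with Rayon and its sequential counterpart:

use rayon::prelude::*;

// Parallel: Rayon hands each oc-sized row to one of its worker threads
out.par_chunks_mut(oc)
   .for_each(|row| {
       for (o, val) in row.iter_mut().enumerate() {
           *val += bias[o];
       }
   });

// Sequential version: the same loop with chunks_mut
for row in out.chunks_mut(oc) {
    // ...
}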

Rayon splits the out array into chunks of oc elements (one row each) and distributes them across its worker threads. Each chunk is a separate mutable slice handled by exactly one thread, so there is no overlap or contention over the same data. We could apply the same pattern to the layernorm and encoder functions, which would parallelise better on bigger datasets.
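Rayon is an external crate. To show the same disjoint-chunk idea with the standard library alone, here is a minimal sketch using std::thread::scope (Rust 1.63+). It splits the rows into a fixed number of groups, one thread per group. The function add_bias_parallel and its n_threads parameter are hypothetical, and Rayon's own scheduler (a work-stealing thread pool) is more sophisticated than this static split.

```rust
use std::thread;

/// Add `bias` to every row of `out` (rows of length `oc`), splitting the rows
/// into at most `n_threads` disjoint groups (n_threads must be > 0). Each thread
/// owns its own `&mut` sub-slice, so no locking is needed: the idea behind par_chunks_mut.
fn add_bias_parallel(out: &mut [f32], bias: &[f32], oc: usize, n_threads: usize) {
    let n_rows = out.len() / oc;
    let rows_per_thread = ((n_rows + n_threads - 1) / n_threads).max(1); // ceiling division
    thread::scope(|s| {
        for block in out.chunks_mut(rows_per_thread * oc) {
            // `block` is moved into the thread; `bias` is a shared, read-only borrow
            s.spawn(move || {
                for row in block.chunks_mut(oc) {
                    for (val, b) in row.iter_mut().zip(bias) {
                        *val += b;
                    }
                }
            });
        }
    }); // all spawned threads are joined here
}

fn main() {
    let oc = 3;
    let mut out = vec![0.0f32; 4 * oc]; // 4 rows
    let bias = [1.0f32, 2.0, 3.0];
    add_bias_parallel(&mut out, &bias, oc, 2);
    println!("{:?}", out);
}
```

The borrow checker accepts this without any Mutex because chunks_mut hands out non-overlapping mutable slices.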

However, all that glitters is not gold. Some operations, like accumulating gradients into a single array or summing up statistics across multiple threads, require shared state. With shared state we cannot split the data into disjoint pieces, because several threads need to write to the same memory, so we need synchronization. This is tricky, as we must prevent threads from writing to the same memory address at the same time without coordination. For this reason, we need a Mutex. A Mutex provides mutual exclusion: only one thread can lock it at a time, which guarantees that this thread is the only one modifying the contained data.

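use rayon::prelude::*;
use std::sync::Mutex;

// Hypothetical helper: sum per-batch vectors into one shared vector of length `size`
fn sum_batches(batches: &[Vec<f32>], size: usize) -> Vec<f32> {
    let shared_data = Mutex::new(vec![0.0f32; size]);
    batches.par_iter().for_each(|batch| {
        // lock() blocks until this thread has exclusive access
        let mut guard = shared_data.lock().unwrap();
        for (g, val) in guard.iter_mut().zip(batch.iter()) {
            *g += val;
        }
        // the guard is dropped here, releasing the lock
    });
    shared_data.into_inner().unwrap()
}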

If you look at my attention_backward function, you'll see it is split into multiple sub-chunks. This is mainly to avoid the borrow-checker error "cannot borrow as mutable more than once at a time". Moreover, this function relies heavily on Rayon and Mutex to allow some parallelism.

In the backward pass we need to compute the gradients with respect to the input, dinp, the pre-softmax attention scores, dpreatt, and the attention probabilities, datt. This is done over the entire batch, so, unsurprisingly, we need to parallelize to avoid bottlenecks (you'll see below that the backward pass is the most time-consuming step). Ideally, we want to process each batch element and each attention head independently and in parallel. However, again to avoid the "cannot borrow as mutable more than once at a time" error, each thread computes its own local results, and all of them are then merged into the final global gradients. To do that I use a Mutex for each global array:

let global_dinp = Mutex::new(vec![0.0f32; dinp.len()]);
let global_datt = Mutex::new(vec![0.0f32; datt.len()]);
let global_dpreatt = Mutex::new(vec![0.0f32; dpreatt.len()]);

This lets me use Rayon for the parallel work in the loop:

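(0..b).into_par_iter().for_each(|b_idx| {
    let mut local_dinp = vec![0.0f32; dinp.len()];
    let mut local_datt = vec![0.0f32; datt.len()];
    let mut local_dpreatt = vec![0.0f32; dpreatt.len()];
    // ... compute the gradients for batch element b_idx into the local buffers,
    // then merge them into the global arrays
});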

Each thread works on a different b_idx and computes its gradients locally. All of this happens in isolation, as each thread works on its own local arrays.

After the local computations, we combine the local arrays into the global gradient arrays:

{
    let mut g_dinp = global_dinp.lock().unwrap();
    g_dinp
        .iter_mut()
        .zip(local_dinp.iter())
        .for_each(|(g, l)| *g += l);
}

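This lock mechanism prevents threads from interleaving their writes inconsistently: when a thread holds the mutex lock, it has exclusive access. The surrounding braces limit the lifetime of the guard g_dinp, so the lock is released as soon as the merge is done.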

In the final step, we copy the merged global arrays into the output slices, so that the results propagate to the final arrays:

dinp.copy_from_slice(&global_dinp.lock().unwrap());
datt.copy_from_slice(&global_datt.lock().unwrap());
dpreatt.copy_from_slice(&global_dpreatt.lock().unwrap());
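The whole pattern (compute locally without locking, merge once under the lock, then copy out) can be sketched end to end with the standard library alone. In this minimal sketch, std::thread::scope stands in for Rayon, and fake_grad is a hypothetical placeholder for the real per-batch backward computation, used only to show the data flow.

```rust
use std::sync::Mutex;
use std::thread;

/// Hypothetical stand-in for the real per-batch backward computation.
fn fake_grad(b_idx: usize, i: usize) -> f32 {
    (b_idx + i) as f32
}

/// Local-accumulate-then-merge: one thread per batch index fills a private
/// buffer without locking, then adds it into the Mutex-protected global buffer.
fn accumulate_grads(dinp: &mut [f32], b: usize) {
    let n = dinp.len();
    let global_dinp = Mutex::new(vec![0.0f32; n]);
    thread::scope(|s| {
        for b_idx in 0..b {
            let global_dinp = &global_dinp; // share a reference to the Mutex
            s.spawn(move || {
                // 1. local computation: no shared state, no lock
                let local_dinp: Vec<f32> = (0..n).map(|i| fake_grad(b_idx, i)).collect();
                // 2. merge under the lock; the guard is dropped at the end of this block
                {
                    let mut guard = global_dinp.lock().unwrap();
                    guard.iter_mut().zip(local_dinp.iter()).for_each(|(g, l)| *g += l);
                }
            });
        }
    });
    // 3. all threads have joined once scope returns: copy the merged result out
    dinp.copy_from_slice(&global_dinp.lock().unwrap());
}

fn main() {
    let mut dinp = vec![0.0f32; 4];
    accumulate_grads(&mut dinp, 2);
    println!("{:?}", dinp); // [1.0, 3.0, 5.0, 7.0]
}
```

Each local buffer is as large as the global one, so this approach trades memory for freedom from borrow conflicts. Each thread holds the lock only for the short merge, not for the whole computation.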

Time for fun: code, performance and inference!

It's now time to play with the code. All the runs below were done on a MacBook Pro with an M2 chip and 16 GB of memory.

First, download the needed data with python prepro_tinyshakespeare.py. This downloads the input corpus into a data folder and converts the text into training and validation tokens (tiny_shakespeare_train.bin and tiny_shakespeare_val.bin, respectively), using the GPT-2 tokenizer. Then, you can build the Rust code with:

cd Llm 
bash build.sh 

After 2000 steps, you may get an inference output similar to:

3792, Is
340,  it
922,  good
11, ,
611,  if
345,  you
423,  have
26246,  pity
11, ,
284,  to
423,  have
281,  an
45618,  ornament
11, ,
257,  a
1486,  design
11, ,
198,

2514, To
9280,  dance
11, ,
7365,  bat
258, he
11, ,
18044,  breathe
290,  and
4545,  teach
30, ?
440,  O
11, ,
611,  if
340,  it
307,  be
2081,  true
11, ,
198,

1026, It
318,  is
2081,  true
356,  we
743,  may
307,  be
991,  still
2877,  living
11, ,
611,  if
340,  it
307,  be
2081,  true
25, :
198,

46, O
11, ,
2652,  stay
11, ,
393,  or
314,  I
2236,  shall
307,  be
2636,  dead
13, .
628,

Each line shows the token ID and its text value. The code ran on 16 threads. To select the number of threads, you can modify this line in the code and this one in the bash build script.

Fig. 1 shows the forward and backward pass times, in milliseconds, for each step. The forward pass is reasonably well optimised, with an average time of 272.01 +/- 57.71 ms. The backward pass needs more work, with an average of 472.63 +/- 51.75 ms. These timings are 30 times better than the original commit of Karpathy's llm.c, my main source of inspiration for the Rust code, which takes an average of 30 seconds to perform a single step.

Fig.1: Timings for the forward and backward pass for LLM training. The times are in milliseconds (ms), and the code has run for 2000 steps against the tiny Shakespeare dataset. [ Image by the author ]

We can also track the training loss, shown in Fig. 2. The loss trends downward, from an initial average of about 4.5 to about 3.2 over the last steps.

Fig.2: Training loss after 2000 steps. [ Image by the author ]

A further example of inference, generated after 2000 steps and decoded with the GPT-2 tokenizer:

Tis come; 
you'll bear it, 
this fierce protest or.. 

JULIET: We will. O' the loaning makers.
First watch man: You and your lads, your actions must be controll'd by Sir John

It may not be the best output an LLM can produce, but it comes after only 2000 steps, roughly 30 minutes of training, as a fine-tuning result on the tiny Shakespeare dataset.


Conclusions

If you've made it this far, thanks very much for reading my article. I hope you enjoyed a comprehensive look at my code and that you're ready to fine-tune GPT models.

This article has shown my way of getting deeper into Rust and learning how to optimize code for training a GPT-like model. In particular, we learned:

  • how to implement a GPT-like model in Rust, what the major blockers are (e.g. matrix multiplication, setting up the parameters), and how to deal with memory management, tensor slicing and parallelisation;
  • how to leverage Rayon with data-parallel loops, synchronising data with Mutex;
  • that GPT training can run on a simple CPU-based laptop, which was the goal of the project, and that even on this hardware the Rust code runs about 30 times faster than the original C code. Many more projects can start from this initial seed, from fine-tuning to further optimization in Rust.

If you want to get in touch with me, you can drop an email to [email protected].

Tags: Gpt Hands On Tutorials Llm Machine Learning Rust
