Can Transformers Learn to Strategize?

Author:Murphy  |  View: 25958  |  Time: 2025-03-23 12:45:43

Although most work on board games relies on convolutional neural nets or other geometrically inspired architectures, the fact that we can represent board states as strings raises the question of whether transformers can be applied to such games directly. Here, we'll see if we can answer this in the context of a simple game of tic-tac-toe. Whilst this may not seem very practical (almost everyone knows that there is a closed-form Nash equilibrium strategy in this game that is very simple to learn), it is a useful testbed for our questions: the game is simple enough that we can easily train a transformer to play it, but complex enough that it is not immediately obvious what the best strategy is.

Implementing the game

We'll start by implementing a TicTacToe class. This is fairly straightforward. We want to be able to represent the board as a string of 9 characters, one for each square, using X for the first player, O for the second player, and - for an empty square. We'll also keep track of whose turn it is, and provide methods to check whether the game has been won (and by whom) or drawn, and to list the legal moves. Finally, we'll include a method to pretty-print the board so we don't have to stare at strings when debugging.

class TicTacToe:
    def __init__(self):
        # Initialise an empty board
        self.board = ['-' for _ in range(9)]
        self.current_player = 'X'  # X will start

    def make_move(self, position):
        """Make a move on the board."""
        if self.board[position] == '-':
            self.board[position] = self.current_player
            self.switch_player()
            return True
        else: return False  # illegal move

    def switch_player(self):
        """Switch the current player."""
        self.current_player = 'O' if self.current_player == 'X' else 'X'

    def check_winner(self):
        """Check if there is a winner."""
        # Rows, columns, diagonals
        winning_positions = [
            [0, 1, 2], [3, 4, 5], [6, 7, 8],  # Rows
            [0, 3, 6], [1, 4, 7], [2, 5, 8],  # Columns
            [0, 4, 8], [2, 4, 6]  # Diagonals
        ]

        for positions in winning_positions:
            values = [self.board[pos] for pos in positions]
            if values[0] == values[1] == values[2] and values[0] != '-':
                return values[0]

        return None  # No winner yet

    def is_draw(self):
        """Check if the game is a draw."""
        return all(cell != '-' for cell in self.board)

    def get_board_string(self):
        """Get the current board state as a string."""
        return ''.join(self.board)

    def get_legal_moves(self):
        """Get the positions of all legal moves."""
        return [i for i, cell in enumerate(self.board) if cell == '-']

    def pretty_print_board(self):
        """Pretty-print the board."""
        for i in range(0, 9, 3):
            print(f"{self.board[i]} | {self.board[i+1]} | {self.board[i+2]}")
            if i < 6:
                print("- "*5)

# Test the pretty_print_board method
tic_tac_toe = TicTacToe()
print("Initial board:")
tic_tac_toe.pretty_print_board()

# Make some moves
tic_tac_toe.make_move(0)
tic_tac_toe.make_move(4)
tic_tac_toe.make_move(8)
print("\nBoard after some moves:")
tic_tac_toe.pretty_print_board()

Initial board:
- | - | -
- - - - - 
- | - | -
- - - - - 
- | - | -

Board after some moves:
X | - | -
- - - - - 
- | O | -
- - - - - 
- | - | X

Creating our training data

We want our transformer to take in a board state and output a move: an integer from 0 to 8 giving the square in which it wants to place its piece. To do so, we'll create a dataset of (board state, move) pairs. We'll do this by simulating every possible game, keeping the games that our player (X) goes on to win or draw, and recording the moves that got us there. The hope is that the transformer will learn what constitutes a good move from any given board state.

To implement this, the simulate_all_games function generates the data, which we later split into training and validation sets. Specifically, the function recursively simulates all possible Tic-Tac-Toe games, exploring every permutation of moves for both players ('X' and 'O'), so the dataset covers every game that can actually occur. Whenever a simulated game ends in a win for X or a draw, the function replays the game to reconstruct the sequence of board states and records each board state together with the move played from it. These board states and moves are then encoded as integers suitable for training our transformer. The aim is for the model to learn not just to win, but to output an appropriate move from any board state along the way. You can think of this as akin to the ability of a regular language transformer to output an appropriate token given any length of context, from just one token (i.e. our empty starting board) up to the EOS token (i.e. our winning move yielding the final board state).
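Stripped of the TicTacToe class, the core idea fits in a few dozen lines. The following is a minimal, self-contained sketch of that idea, not the author's implementation: the names winner and collect_pairs are hypothetical, and it omits the winning-move shortcut introduced later. It records each (board, move) pair at the moment X moves, inside the recursion itself, so every board is paired with the square actually played from it.

```python
from typing import List, Optional, Tuple

# The eight winning lines: rows, columns, diagonals
WIN_LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
             (0, 3, 6), (1, 4, 7), (2, 5, 8),
             (0, 4, 8), (2, 4, 6)]

def winner(board: List[str]) -> Optional[str]:
    """Return 'X' or 'O' if that player has three in a row, else None."""
    for a, b, c in WIN_LINES:
        if board[a] != '-' and board[a] == board[b] == board[c]:
            return board[a]
    return None

def collect_pairs(board: List[str], player: str,
                  history: List[Tuple[str, int]],
                  out: List[Tuple[str, int]]) -> None:
    """Explore every continuation of `board` with `player` to move.

    `history` holds the (board string, square) pairs X has played on this
    path; they are copied into `out` only if the game ends in an X win or a draw.
    """
    w = winner(board)
    if w == 'X' or (w is None and '-' not in board):
        out.extend(history)
        return
    if w == 'O':
        return  # discard games that O wins
    for pos in range(9):
        if board[pos] != '-':
            continue
        new_board = board.copy()
        new_board[pos] = player
        if player == 'X':
            # Pair the position *before* X moves with the square X chose
            new_history = history + [(''.join(board), pos)]
        else:
            new_history = history
        collect_pairs(new_board, 'O' if player == 'X' else 'X', new_history, out)

pairs: List[Tuple[str, int]] = []
collect_pairs(['-'] * 9, 'X', [], pairs)
print(len(pairs))
```

Because the pair is captured before the move is applied, each recorded board always has as many X's as O's (it is X's turn) and the target square is always empty, which makes the data easy to check.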

from copy import deepcopy
from itertools import product
import numpy as np
import torch

# Define character to integer mapping
chars = sorted(list(set('XO-')))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {0: '-', 1: 'X', 2: 'O'}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

input_sequences = []
output_sequences = []

# Function to simulate all possible games recursively
def simulate_all_games(game, x_moves, o_moves):
    global input_sequences, output_sequences

    # Check if the game has reached a terminal state
    winner = game.check_winner()
    if winner == 'X' or game.is_draw():
        # Add the sequence of board states and moves leading to this win for 'X' or draw
        board = ['-' for _ in range(9)]
        for i, x_move in enumerate(x_moves):
            input_sequences.append(encode(''.join(board)))
            output_sequences.append(x_move)
            board[x_move] = 'X'
            if i < len(o_moves):
                board[o_moves[i]] = 'O'
        return
    elif winner == 'O':
        return  # We don't add these to our training data

    # Otherwise, continue simulating the game
    legal_moves = game.get_legal_moves()
    for move in legal_moves:
        # Create a copy of the game to simulate the move
        new_game = deepcopy(game)
        was_legal = new_game.make_move(move)

        # If the move was legal, continue simulating
        if was_legal:
            if new_game.current_player == 'X':
                simulate_all_games(new_game, x_moves + [move], o_moves)
            else:
                simulate_all_games(new_game, x_moves, o_moves + [move])

# Create an initial empty game
initial_game = TicTacToe()

# Simulate all possible games starting with 'X'
simulate_all_games(initial_game, [], [])

# Convert to PyTorch tensors
input_tensor = torch.tensor(input_sequences, dtype=torch.long)
output_tensor = torch.tensor(output_sequences, dtype=torch.long)

# Show some sample input-output pairs
print(input_tensor[:10], output_tensor[:10])

print("Number of input-output pairs:", len(input_sequences))

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 2, 0, 0, 0, 0, 0, 0, 0],
        [1, 2, 1, 2, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 2, 0, 0, 0, 0, 0, 0, 0],
        [1, 2, 1, 2, 0, 0, 0, 0, 0],
        [1, 2, 1, 2, 1, 2, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 2, 0, 0, 0, 0, 0, 0, 0],
        [1, 2, 1, 2, 0, 0, 0, 0, 0]]) 
tensor([1, 3, 5, 1, 3, 5, 6, 1, 3, 5])

Number of input-output pairs: 658224
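To read these tensors it helps to know exactly how characters map to integers. sorted(set('XO-')) orders the characters in ASCII order, '-', 'O', 'X', so stoi sends '-' to 0, 'O' to 1 and 'X' to 2. A minimal sketch that derives the inverse mapping from stoi, rather than writing it out by hand, so the two mappings can never disagree, together with a round-trip check:

```python
# Build both mappings from one sorted vocabulary so itos is the exact inverse of stoi
chars = sorted(set('XO-'))            # ['-', 'O', 'X'] because '-' < 'O' < 'X' in ASCII
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}

def encode(s: str) -> list:
    """Map a 9-character board string to a list of token ids."""
    return [stoi[c] for c in s]

def decode(ids: list) -> str:
    """Map token ids back to a board string."""
    return ''.join(itos[i] for i in ids)

board = 'X-O-X-O--'
print(encode(board))                  # [2, 0, 1, 0, 2, 0, 1, 0, 0]
assert decode(encode(board)) == board
```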

This gives us about 650,000 examples to train on. These tensors look roughly correct, but it's hard to tell without seeing the boards visually. Let's recycle our pretty_print_board method as a standalone function to see what a random board state, and the board state after the recorded move, look like:

def pretty_print_board(board: str):
    """Pretty-print the board."""
    for i in range(0, 9, 3):
        print(f"{board[i]} | {board[i+1]} | {board[i+2]}")
        if i < 6:
            print("- "*5)

rand_idx = torch.randint(len(input_tensor), (1,))[0]
random_game = input_tensor[rand_idx].tolist()
print("Current game state:")
decoded_game = decode(random_game)
pretty_print_board(decoded_game)
print()

move = output_tensor[rand_idx].item()
decoded_game = decoded_game[:move] + 'X' + decoded_game[move+1:]
print("New game state:")
pretty_print_board(decoded_game)

Current game state:
X | - | -
- - - - - 
- | X | O
- - - - - 
- | O | -

New game state:
X | - | X
- - - - - 
- | X | O
- - - - - 
- | O | -

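This seems plausible, but I've noticed that some positions have a winning move available, and yet the simulation records a different move (one that still ends up in a win). This occurred in the example above: X could have won immediately at square 8, completing the 0-4-8 diagonal, but instead played square 2. Let's change the simulate_all_games function so that, whenever an immediately winning move is available, it takes that move and stops exploring the alternatives.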
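The shortcut amounts to asking, for each empty square, whether playing there completes a line of three. As a standalone sketch on board strings (find_winning_move is a hypothetical helper, not part of the article's code), applied to the example position:

```python
from typing import Optional

# The eight lines of three squares that win the game
WIN_LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
             (0, 3, 6), (1, 4, 7), (2, 5, 8),
             (0, 4, 8), (2, 4, 6)]

def find_winning_move(board: str, player: str) -> Optional[int]:
    """Return the lowest-numbered square that immediately wins for `player`, or None."""
    for pos in range(9):
        if board[pos] != '-':
            continue
        trial = board[:pos] + player + board[pos + 1:]
        if any(trial[a] == trial[b] == trial[c] == player for a, b, c in WIN_LINES):
            return pos
    return None

# The position from the example above: X on 0 and 4, O on 5 and 7
print(find_winning_move('X---XO-O-', 'X'))  # 8
print(find_winning_move('X---XO-O-', 'O'))  # None
```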

input_sequences = []
output_sequences = []

def simulate_all_games(game, x_moves, o_moves):
    global input_sequences, output_sequences

    # Check if the game has reached a terminal state
    winner = game.check_winner()
    if winner == 'X' or game.is_draw():
        # Add the sequence of board states and moves leading to this win for 'X' or draw
        board = ['-' for _ in range(9)]
        for i, x_move in enumerate(x_moves):
            input_sequences.append(encode(''.join(board)))
            output_sequences.append(x_move)
            board[x_move] = 'X'
            if i < len(o_moves):
                board[o_moves[i]] = 'O'
        return
    elif winner == 'O':
        return  # We don't add these to our training data

    # Before simulating further moves, check if a winning move is available
    legal_moves = game.get_legal_moves()
    for move in legal_moves:
        test_game = deepcopy(game)
        test_game.make_move(move)
        if test_game.check_winner() == game.current_player:
            # This move is a winning move, so we make it and end further simulation
            if test_game.current_player == 'X':
                simulate_all_games(test_game, x_moves + [move], o_moves)
            else:
                simulate_all_games(test_game, x_moves, o_moves + [move])
            return  # End further exploration for this branch

    # If no immediate winning move is found, continue simulating the game
    for move in legal_moves:
        # Create a copy of the game to simulate the move
        new_game = deepcopy(game)
        was_legal = new_game.make_move(move)

        # If the move was legal, continue simulating
        if was_legal:
            if new_game.current_player == 'X':
                simulate_all_games(new_game, x_moves + [move], o_moves)
            else:
                simulate_all_games(new_game, x_moves, o_moves + [move])

# Create an initial empty game
initial_game = TicTacToe()

# Simulate all possible games starting with 'X'
simulate_all_games(initial_game, [], [])

# Convert to PyTorch tensors
input_tensor = torch.tensor(input_sequences, dtype=torch.long)
output_tensor = torch.tensor(output_sequences, dtype=torch.long)

Finally, let's see how many moves we have to train on:

print("Number of input-output pairs:", len(input_sequences))

Number of input-output pairs: 147104

About 150,000 examples. This seems like a reasonable start.

Transformer architecture with multi-head attention

Attention is a mechanism whereby the model can learn to focus on certain parts of the input sequence when making predictions. The transformer architecture uses multi-head self-attention: several attention "heads" run in parallel, each with its own learned projections, so each head can learn to attend to the input sequence in a different way and pick up a different relationship between positions. In tic-tac-toe, for example, one head might learn to relate squares on the same row while another relates squares on the same diagonal. This is what allows the model to learn complex relationships between the parts of its input.

But how does this actually work? In the original Attention Is All You Need paper, attention over a query matrix Q, a key matrix K, and a value matrix V is defined as:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$

where $d_k$ is the dimension of the keys. Dividing by $\sqrt{d_k}$ keeps the dot products at roughly unit variance, so the softmax does not saturate into a near one-hot distribution. Let's break down what's actually happening here. Suppose we have an input of dimension (B,T,C), where B is the batch size, T is the sequence length, and C is the number of channels. We can think of this as a batch of B sequences of length T, each position carrying C channels:

import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)

Then, to implement a single head of self-attention, we need to create the query, the key, and the value. Each of these is a linear layer that projects the C input channels down to head_size channels, where head_size is just how wide we want the projection to be. Following the standard implementation, these layers have no bias term.
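The snippets that follow use q, k and a value layer without defining them. A minimal sketch of those three projections, assuming head_size = 16 to match the (B, T, 16) shape comments:

```python
import torch
import torch.nn as nn

torch.manual_seed(1337)
B, T, C = 4, 8, 32       # batch, time, channels (as above)
x = torch.randn(B, T, C)

head_size = 16           # assumed projection width, matching the shapes used below
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)    # (B, T, 16): what each position offers to be matched against
q = query(x)  # (B, T, 16): what each position is looking for
print(k.shape, q.shape)
```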

wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # scaled by 1/sqrt(d_k): (B, T, 16) @ (B, 16, T) -> (B, T, T)

However, a language transformer is trained to predict the next token in a sequence. Since it can't use tokens it hasn't seen yet during generation, it mustn't be allowed to "cheat" by looking at them during training either, so we mask out future positions. We do this by creating a lower-triangular mask of shape (T,T), where T is the sequence length, and setting every attention score above the diagonal to negative infinity. After the softmax these masked positions receive a weight of exactly 0, so the model cannot attend to them.
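To see the effect of the mask in isolation, here is a small sketch with T = 4 and all raw scores set to zero (an arbitrary choice so the result is easy to read). Each position spreads its weight uniformly over itself and the positions before it, and every entry above the diagonal is exactly 0:

```python
import torch
import torch.nn.functional as F

T = 4
scores = torch.zeros(T, T)                  # pretend all raw attention scores are equal
tril = torch.tril(torch.ones(T, T))         # 1 on and below the diagonal, 0 above
masked = scores.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(masked, dim=-1)         # row i is uniform over positions 0..i
print(weights)
```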

Finally, after masking and applying the softmax, we multiply the attention weights by the values to get the output of a single head of self-attention:

tril = torch.tril(torch.ones((T,T)))
wei = wei.masked_fill(tril==0, float("-inf"))
wei = F.softmax(wei, dim=-1)

v = value(x) # (B,T,16)
out = wei @ v
out.shape  # torch.Size([4, 8, 16])

And that's a single head of self-attention. To create multiple heads, we simply run several such heads in parallel and concatenate their outputs to get the final output of the multi-head self-attention layer. We also add residual connections and layer normalisation to make this relatively deep model easier to optimise. For a full walkthrough of similar code and the mechanisms behind decoder-only transformers, I highly recommend Andrej Karpathy's nanoGPT lecture.

from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F

# Hyperparameters
batch_size = 128  # How many independent sequences will we process in parallel?
block_size = 9  # Sequence length: one token per square of the board
max_iters = 10000
eval_interval = 500
learning_rate = 1e-3
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
eval_iters = 100
n_embd = 32  # Embedding dimension (kept small for this simple task)
n_head = 2  # Number of attention heads per block
n_layer = 2  # Number of transformer blocks
dropout = 0.1

print(f'Training on {device}')

# Initialize random seed
torch.manual_seed(1337)

# Split into training and validation sets
n = int(0.90 * len(input_tensor))  # 90% for training
train_input = input_tensor[:n]
train_output = output_tensor[:n]
val_input = input_tensor[n:]
val_output = output_tensor[n:]

# Updated data loading function
def get_batch(split):
    input_data = train_input if split == 'train' else val_input
    output_data = train_output if split == 'train' else val_output
    # Choose index locs for batch_size sequences
    ix = torch.randint(len(input_data) - block_size + 1, (batch_size,))
    # Get the input and output sequences
    x = input_data[ix]
    y = output_data[ix]
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class Transformer(nn.Module):

    def __init__(self):
        super().__init__()
        # embedding tables for the board tokens ('-', 'O', 'X') and the 9 square positions
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, 9)

        # small normal initialisation for linear and embedding weights (as in nanoGPT)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx is a (B,T) tensor of board tokens; targets (if given) is a (B,) tensor of squares
        tok_emb = self.token_embedding_table(idx)  # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T,C)
        x = tok_emb + pos_emb  # (B,T,C)
        x = self.blocks(x)  # (B,T,C)
        x = self.ln_f(x)  # (B,T,C)
        logits = self.lm_head(x)  # (B,T,9): one logit per square
        # Take the logits corresponding to the last time step T
        logits = logits[:, -1, :]  # Now logits is (B, 9)

        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # forward() already returns only the last time step's logits, shape (B, 9)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, 9)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

To make sure our architecture works as intended, let's try passing a single batch through.

xb, yb = get_batch('train')
print(xb.shape, yb.shape)
m = Transformer().to(device)
logits, loss = m(xb, yb)
print(logits.shape)
print(f"Loss: {loss.item():.3f}")

torch.Size([128, 9]) torch.Size([128])
torch.Size([128, 9])
Loss: 2.203

A good sanity check on this initial forward pass is whether the loss is roughly what we'd expect from random guessing. Since we have 9-dimensional logits and we're using cross-entropy loss, which is the negative log-likelihood of the correct class, a model that spreads its probability uniformly over the 9 squares should have a loss of about:

$$-\ln\left(\frac{1}{9}\right) = \ln 9 \approx 2.197,$$

which is very close to the 2.203 we observe.
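This baseline can also be checked numerically. In the minimal sketch below, all-zero logits give a uniform softmax over the 9 squares, so the cross-entropy equals ln 9 whatever the targets are:

```python
import math
import torch
import torch.nn.functional as F

batch_size, n_squares = 128, 9
logits = torch.zeros(batch_size, n_squares)        # identical logits = uniform prediction
targets = torch.randint(n_squares, (batch_size,))  # arbitrary "correct" squares
loss = F.cross_entropy(logits, targets)
print(f"{loss.item():.4f} vs ln 9 = {math.log(9):.4f}")
```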

Training the model

Using quite a small transformer (about 26,000 parameters), we achieve the following loss (note that I'm using a small amount of weight decay and dropout):

model = Transformer()
model = model.to(device)

# Print the number of parameters in the model
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
train_loss_history = []
val_loss_history = []

# Training loop
for iter in tqdm(range(max_iters)):
    # Evaluate the loss on train and val sets occasionally
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        val_loss_history.append(losses['val'])

    # Sample a batch of data
    xb, yb = get_batch('train')

    # Evaluate the loss
    logits, loss = model(xb, yb)
    train_loss_history.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0.025961 M parameters
  0%|          | 5/10000 [00:00<24:42,  6.74it/s]  
step 0: train loss 2.2033, val loss 2.2106
  5%|▌         | 504/10000 [00:14<12:01, 13.15it/s]
step 500: train loss 1.9162, val loss 2.0215
 10%|█         | 1008/10000 [00:27<08:27, 17.73it/s]
step 1000: train loss 1.7846, val loss 1.8570
 15%|█▌        | 1505/10000 [00:40<10:34, 13.39it/s]
step 1500: train loss 1.7370, val loss 1.7648
 20%|██        | 2007/10000 [00:53<07:35, 17.55it/s]
step 2000: train loss 1.7188, val loss 1.7770
 25%|██▌       | 2506/10000 [01:05<07:11, 17.36it/s]
step 2500: train loss 1.6957, val loss 1.7456
 30%|███       | 3006/10000 [01:18<06:35, 17.69it/s]
step 3000: train loss 1.6965, val loss 1.7448
 35%|███▌      | 3506/10000 [01:31<06:12, 17.41it/s]
step 3500: train loss 1.6961, val loss 1.7809
 40%|████      | 4005/10000 [01:43<07:41, 12.98it/s]
step 4000: train loss 1.6819, val loss 1.7256
 45%|████▌     | 4506/10000 [01:56<05:18, 17.24it/s]
step 4500: train loss 1.6892, val loss 1.7066
 50%|█████     | 5005/10000 [02:09<05:14, 15.88it/s]
step 5000: train loss 1.6846, val loss 1.7141
 55%|█████▌    | 5508/10000 [02:23<04:37, 16.19it/s]
step 5500: train loss 1.6835, val loss 1.6998
 60%|██████    | 6004/10000 [02:36<05:19, 12.51it/s]
step 6000: train loss 1.6828, val loss 1.7095
 65%|██████▌   | 6506/10000 [02:49<03:23, 17.13it/s]
step 6500: train loss 1.6722, val loss 1.7151
 70%|███████   | 7008/10000 [03:02<03:05, 16.17it/s]
step 7000: train loss 1.6656, val loss 1.7158
 75%|███████▌  | 7505/10000 [03:15<02:30, 16.54it/s]
step 7500: train loss 1.6672, val loss 1.7078
 80%|████████  | 8007/10000 [03:28<02:01, 16.38it/s]
step 8000: train loss 1.6808, val loss 1.7120
 85%|████████▌ | 8505/10000 [03:41<01:47, 13.94it/s]
step 8500: train loss 1.6733, val loss 1.7144
 90%|█████████ | 9007/10000 [03:54<00:56, 17.54it/s]
step 9000: train loss 1.6714, val loss 1.7031
 95%|█████████▌| 9506/10000 [04:07<00:28, 17.39it/s]
step 9500: train loss 1.6707, val loss 1.7073
100%|██████████| 10000/10000 [04:20<00:00, 38.43it/s]
step 9999: train loss 1.6664, val loss 1.7506

That doesn't seem very good: the loss plateaus at around 1.7, not far below the ≈2.2 we'd expect from random guessing. Let's plot it to see what's going on.

import matplotlib.pyplot as plt

def plot_transformer_loss(loss_history, val_loss_history):

    # Two horizontal figures side-by-side
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # First plot = training loss
    ax1.plot(loss_history, lw=0.5)
    # Plot moving average of loss
    window_size = 100
    ax1.plot(np.convolve(loss_history, np.ones(window_size) / window_size, mode='valid'), label='Moving average')
    ax1.set_xlabel('Iteration')
    ax1.set_ylabel('Cross-entropy Loss')
    ax1.set_title('Training Loss')
    ax1.legend()

    # Second plot = validation loss
    # Set marker style to be circles at each data point
    indices = np.arange(0, len(val_loss_history) * eval_interval, eval_interval)
    ax2.plot(indices, val_loss_history, marker='o')
    ax2.set_title('Validation Loss')
    ax2.set_xlabel('Iteration')

    plt.show()

plot_transformer_loss(train_loss_history, val_loss_history)
Initial losses for our transformer trained on the tic-tac-toe data (image by author).

We can play against the transformer using this code to see if it's any good:

import random
import torch
from IPython.display import clear_output

def play_game(model, stoi, itos, device):
    model.eval()  # disable dropout while the model is playing
    game = TicTacToe()

    # Randomly decide who goes first
    game.current_player = random.choice(['X', 'O'])

    while game.check_winner() is None and not game.is_draw():
        #clear_output(wait=True)

        print(f"{game.current_player}'s turn.")
        game.pretty_print_board()

        current_board_str = game.get_board_string()

        if game.current_player == 'X':
            print("Model's turn...")
            current_board_encoded = torch.tensor([stoi[c] for c in current_board_str], dtype=torch.long).unsqueeze(0).to(device)
            with torch.no_grad():
                logits, _ = model(current_board_encoded)

            # Move logits to cpu
            logits = logits.cpu()

            # Mask out illegal moves by setting their logits to -inf
            # (multiplying by a 0/1 mask would leave them at 0, which can still win the argmax)
            legal_moves = game.get_legal_moves()
            mask = torch.zeros(9, dtype=torch.bool)
            mask[legal_moves] = True
            masked_logits = logits.masked_fill(~mask, float('-inf'))

            # Get the model's move
            predicted_move = masked_logits.argmax(dim=-1).item()

            # Make the model's move
            game.make_move(predicted_move)

        else:
            print("Your turn!")
            legal_moves = game.get_legal_moves()
            print("Legal moves:", legal_moves)

            user_move = int(input("Enter your move: "))
            if user_move in legal_moves:
                game.make_move(user_move)
            else:
                print("Illegal move. Try again.")
                continue

        winner = game.check_winner()
        if winner is not None:
            #clear_output(wait=True)
            print(f"{winner} wins!")
            game.pretty_print_board()
            break
        elif game.is_draw():
            #clear_output(wait=True)
            print("It's a draw!")
            game.pretty_print_board()
            break

O's turn.
- | - | -
- - - - - 
- | - | -
- - - - - 
- | - | -
Your turn!
Legal moves: [0, 1, 2, 3, 4, 5, 6, 7, 8]
X's turn.
- | - | -
- - - - - 
- | - | -
- - - - - 
- | - | O
Model's turn...
O's turn.
- | - | X
- - - - - 
- | - | -
- - - - - 
- | - | O
Your turn!
Legal moves: [0, 1, 3, 4, 5, 6, 7]
X's turn.
O | - | X
- - - - - 
- | - | -
- - - - - 
- | - | O
Model's turn...
O's turn.
O | - | X
- - - - - 
X | - | -
- - - - - 
- | - | O
Your turn!
Legal moves: [1, 4, 5, 6, 7]
O wins!
O | - | X
- - - - - 
X | O | -
- - - - - 
- | - | O

Okay, so I easily beat the model: it never blocked the 0-4-8 diagonal I was building. Something is going wrong.
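Before looking at the model itself, one detail worth isolating is how illegal moves are excluded before taking the argmax. A small sketch with made-up logits (not taken from the trained model) shows why multiplying by a 0/1 mask is not enough when the legal logits are negative, whereas filling illegal entries with -inf is:

```python
import torch

# Made-up scores for the 9 squares, shape (1, 9) like the model's output
logits = torch.tensor([[-1.2, -0.3, 0.8, -0.5, -2.0, -0.1, -0.7, -0.4, -0.9]])
legal_moves = [0, 1, 3, 4]                 # squares 2, 5, 6, 7 and 8 are occupied
mask = torch.zeros(9, dtype=torch.bool)
mask[legal_moves] = True

# Multiplying sets illegal logits to 0, which beats every (negative) legal logit
naive_move = (logits * mask.float()).argmax(dim=-1).item()

# Filling illegal logits with -inf guarantees the argmax lands on a legal square
safe_move = logits.masked_fill(~mask, float('-inf')).argmax(dim=-1).item()

print(naive_move in legal_moves, safe_move)  # False 1
```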

Improving the transformer

So at the moment, the transformer can't even reliably learn the simple winning moves from any given position. I can think of a few reasons for this:

  • The transformer is only trained on X's moves from games that X goes on to win or draw, so maybe it hasn't learned how to play when I follow a decent strategy that leaves no winning move available. Theoretically, to counter this we should give it more training examples from games that are destined to end in a draw.
  • The transformer is overparametrised. Getting tens of thousands of parameters to coordinate on a simple strategy may take a long time to train, and may rely on grokking and other phenomena to reach a part of the optimisation landscape that generalises.
  • The transformer is underparametrised. Maybe it needs more parameters to learn a good strategy. That said, I very much doubt that scaling up to millions of parameters will help if tens of thousands don't do the job.
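For scale, the parameter count of the Transformer class can be worked out by hand from the hyperparameters. This sketch (transformer_param_count is a hypothetical helper) tallies each layer and reproduces the 0.025961 M parameters printed during training:

```python
def transformer_param_count(vocab_size: int, block_size: int, n_embd: int,
                            n_head: int, n_layer: int, n_out: int) -> int:
    """Count the parameters of the Transformer class layer by layer."""
    head_size = n_embd // n_head
    embeddings = vocab_size * n_embd + block_size * n_embd         # token + position tables
    attn = n_head * 3 * n_embd * head_size                         # key/query/value, no bias
    proj = (head_size * n_head) * n_embd + n_embd                  # output projection + bias
    ffwd = (n_embd * 4 * n_embd + 4 * n_embd) + (4 * n_embd * n_embd + n_embd)
    norms = 2 * 2 * n_embd                                         # ln1 and ln2 (weight + bias)
    per_block = attn + proj + ffwd + norms
    final = 2 * n_embd + (n_embd * n_out + n_out)                  # ln_f and lm_head
    return embeddings + n_layer * per_block + final

n_params = transformer_param_count(vocab_size=3, block_size=9, n_embd=32,
                                   n_head=2, n_layer=2, n_out=9)
print(n_params)  # 25961
```

Most of the budget sits in the two feed-forward layers of each block (8,352 of the 12,608 parameters per block).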

State-space analysis

Before I move on, I want to analyse the number of neurons theoretically required to encode the entire winning strategy for tic-tac-toe. To do this, we need to consider the state space of the game and the complexity of the decision-making process.

In tic-tac-toe, the game board is a 3 × 3 grid, and each cell can be in one of three states: 'X', 'O', or empty ('-'). So, the total number of possible board states is 3⁹ = 19,683. However, not all of these are valid states in a real game; some of them are unreachable or illegal (e.g., a board filled entirely with 'X's). The number of legal states is actually 5,478, but for the purpose of this analysis, we'll consider the upper limit, i.e., 3⁹.
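The 5,478 figure can be checked with a short search. This sketch (independent of the TicTacToe class; the helper names are hypothetical) walks every game from the empty board, stopping at wins and full boards, and counts the distinct positions it visits, including the empty board:

```python
from typing import Set

WIN_LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
             (0, 3, 6), (1, 4, 7), (2, 5, 8),
             (0, 4, 8), (2, 4, 6)]

def has_winner(board: str) -> bool:
    """True if either player has three in a row."""
    return any(board[a] != '-' and board[a] == board[b] == board[c] for a, b, c in WIN_LINES)

def reachable_states() -> Set[str]:
    """Every board reachable in legal play, with X moving first and play stopping at a win."""
    seen: Set[str] = set()
    stack = ['-' * 9]
    while stack:
        board = stack.pop()
        if board in seen:
            continue
        seen.add(board)
        if has_winner(board) or '-' not in board:
            continue  # terminal position: no further moves
        player = 'X' if board.count('X') == board.count('O') else 'O'
        for pos in range(9):
            if board[pos] == '-':
                stack.append(board[:pos] + player + board[pos + 1:])
    return seen

states = reachable_states()
print(len(states), 3 ** 9)  # 5478 19683
```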

Each state requires a decision: where to place the next 'X' (since we're considering a winning strategy for 'X'). There are 9 possible positions, but the number of legal moves is often fewer than 9, depending on how many cells are already occupied. A neural network would have to map each possible board state to a correct move. One way to achieve this mapping is through a fully connected network that takes the board state as input and outputs a probability distribution over the 9 possible moves. The input layer would have 3 × 3 = 9 neurons (one for each cell), and the output layer would have 9 neurons (one for each possible move). The hidden layers in between would perform the complex task of learning the winning strategy.

Given that the input and output layers each have 9 neurons, the number of neurons in the hidden layers is what we're interested in. Theoretically, we could use a single hidden layer with 3⁹ neurons, one per possible board state, acting as a lookup table that maps each state to a winning move. This is an upper bound, and likely far more than necessary, because of the unreachable/illegal states and the inherent symmetries of tic-tac-toe (rotations and reflections of the board), which reduce the number of genuinely distinct states. So, in this upper-bound scenario, the total number of neurons would be 9 + 19,683 + 9 = 19,701. Let's try implementing a simple vanilla feed-forward neural network to see how it does on our task. Rather than having one layer with thousands of neurons, we'll use three hidden layers.
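To make the upper-bound construction concrete: reading the board as a 9-digit base-3 number gives each of the 3⁹ states its own index, so a hidden layer with one neuron per index could act as a lookup table. A minimal sketch (the digit assignment is an arbitrary choice, and board_index and one_hot_state are hypothetical names):

```python
from itertools import product

DIGIT = {'-': 0, 'O': 1, 'X': 2}  # arbitrary assignment of cell states to base-3 digits

def board_index(board: str) -> int:
    """Map a board string to a unique integer in [0, 3**9) by reading it in base 3."""
    index = 0
    for cell in board:
        index = index * 3 + DIGIT[cell]
    return index

def one_hot_state(board: str) -> list:
    """The 19,683-dimensional one-hot vector a lookup-table hidden layer would see."""
    vec = [0] * 3 ** 9
    vec[board_index(board)] = 1
    return vec

# Every one of the 3**9 boards gets a distinct index
all_indices = {board_index(''.join(cells)) for cells in product('-OX', repeat=9)}
print(board_index('---------'), board_index('XXXXXXXXX'), len(all_indices))  # 0 19682 19683
```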

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

device = 'cpu'

class TicTacToeNN(nn.Module):
    def __init__(self):
        super(TicTacToeNN, self).__init__()
        self.fc1 = nn.Linear(9, 16)  # Input layer to hidden layer 1
        self.fc2 = nn.Linear(16, 32)  # Hidden layer 1 to hidden layer 2
        self.fc3 = nn.Linear(32, 16)  # Hidden layer 2 to hidden layer 3
        self.fc4 = nn.Linear(16, 9)  # Hidden layer 3 to output layer

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

def get_batch(split):
    input_data = train_input if split == 'train' else val_input
    output_data = train_output if split == 'train' else val_output
    # Sample batch_size random indices; each row is an independent (board, move) pair,
    # so every index from 0 to len(input_data) - 1 is a valid choice
    ix = torch.randint(len(input_data), (batch_size,))
    # Get the input and output sequences
    x = input_data[ix].float()
    y = output_data[ix]
    x, y = x.to(device), y.to(device)
    return x, y

# Create an initial empty game
initial_game = TicTacToe()

# Simulate all possible games starting with 'X'
simulate_all_games(initial_game, [], [])

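# Convert to PyTorch tensors
input_tensor = torch.tensor(input_sequences, dtype=torch.long)
output_tensor = torch.tensor(output_sequences, dtype=torch.long)

# Split into training and validation sets (90% / 10%) for get_batch
n = int(0.90 * len(input_tensor))
train_input, train_output = input_tensor[:n], output_tensor[:n]
val_input, val_output = input_tensor[n:], output_tensor[n:]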

nn_model = TicTacToeNN()
nn_model.to(device)

# Print the number of parameters in the model
print(sum(p.numel() for p in nn_model.parameters()), 'parameters')

# Create a PyTorch optimizer
optimizer = torch.optim.AdamW(nn_model.parameters(), lr=learning_rate, weight_decay=1e-4)
train_loss_history = []
val_loss_history = []

# Training loop
max_iters = 1000000
for iter in tqdm(range(max_iters)):
    # Sample a batch of training data and compute the cross-entropy loss
    xb, yb = get_batch('train')
    logits = nn_model(xb)
    loss = F.cross_entropy(logits, yb)
    train_loss_history.append(loss.item())

    # Record the validation loss at every step (no gradients needed)
    with torch.no_grad():
        xb_val, yb_val = get_batch('val')
        val_loss = F.cross_entropy(nn_model(xb_val), yb_val)
    val_loss_history.append(val_loss.item())

    # Backpropagate and update the weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

1385 parameters
100%|██████████| 1000000/1000000 [08:08<00:00, 2048.42it/s]
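Neuron counts understate model size: what the optimiser actually tunes is the weights and biases. A quick sketch (the helper name is hypothetical) reproduces the 1385 parameters printed above for the 9 → 16 → 32 → 16 → 9 network and compares it with the single-hidden-layer upper bound discussed earlier.

```python
def mlp_param_count(layer_sizes):
    """Weights plus biases for a stack of fully connected (nn.Linear-style) layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))

print(mlp_param_count([9, 16, 32, 16, 9]))  # 1385: 160 + 544 + 528 + 153
print(mlp_param_count([9, 3 ** 9, 9]))      # 373986 for the lookup-table upper bound
```

So the small network has roughly 270 times fewer parameters than the one-neuron-per-state bound.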

And let's see what the loss looks like:

Training loss for a vanilla feed-forward neural network (image by author).

Clearly, performance saturates: the loss plateaus instead of continuing to fall. There's something about the task and how we've set it up that prevents the model from learning an appropriate strategy. To change tack a bit, I'm going to try giving the model training data that consists only of moves from the optimal strategy.
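One thing worth checking, under the assumption that the earlier data generator can record several different moves for the same board, is the lowest loss any model could reach. If a board appears with conflicting target moves, the best possible predictor can only output the empirical distribution of those moves, so the mean cross-entropy cannot fall below the conditional entropy H(Y | X). The helper below is a hypothetical diagnostic, not part of the original code.

```python
import math
from collections import Counter, defaultdict

def cross_entropy_floor(pairs):
    """Lowest mean cross-entropy (in nats) any model can reach on (input, label) pairs.

    The optimal predictor outputs each input's empirical label distribution,
    so the floor equals the conditional entropy H(Y | X) of the data.
    """
    by_input = defaultdict(Counter)
    for x, y in pairs:
        by_input[x][y] += 1
    n = len(pairs)
    floor = 0.0
    for counts in by_input.values():
        total = sum(counts.values())
        for c in counts.values():
            floor -= (c / n) * math.log(c / total)
    return floor

# Board 'a' is labelled with two different moves, so the loss cannot reach 0
print(cross_entropy_floor([('a', 0), ('a', 1), ('b', 2)]))  # (2/3) * ln 2
```

If the floor computed on the real (board, move) pairs were well above zero, the plateau would reflect the labels rather than the model's capacity.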

Optimal strategy training data

Newell and Simon's 1972 tic-tac-toe program describes a perfect strategy (one that always wins or at least draws): at each turn, play the first applicable move from the following list of preferences:

  1. Win: If you have two in a row, play the third to get three in a row.
  2. Block: If the opponent has two in a row, play the third to block them.
  3. Fork: Create an opportunity where you can win in two ways.
  4. Block Opponent's Fork: either create two in a row to force a defence (provided that defence doesn't give the opponent a fork), or block their potential fork.
  5. Centre: Play the centre.
  6. Opposite Corner: If the opponent is in a corner, play the opposite corner.
  7. Empty Corner: Play an empty corner.
  8. Empty Side: Play an empty side.

Let's rewrite our data generator so that, in every reachable position, it records the move this strategy prescribes for X while exploring every possible reply by O. We'll also simulate all games for both possible first players.

from copy import deepcopy

import torch

# Helper function to find if there's a winning move or a move that blocks the opponent from winning
def find_winning_or_blocking_move(board, player):
    winning_positions = [
        [0, 1, 2], [3, 4, 5], [6, 7, 8],  # Rows
        [0, 3, 6], [1, 4, 7], [2, 5, 8],  # Columns
        [0, 4, 8], [2, 4, 6]  # Diagonals
    ]
    for positions in winning_positions:
        values = [board[pos] for pos in positions]
        if values.count(player) == 2 and values.count('-') == 1:
            return positions[values.index('-')]
    return None

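# Helper function for checking for fork opportunities.
# Note: for each empty square i, this counts follow-up squares j after which
# `player` has at least one line with two marks and an empty third cell. That is
# a looser test than a strict fork (two simultaneous threats after playing i alone).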
def find_fork_move(board, player):
    fork_move = None
    for i in range(9):
        if board[i] == '-':
            temp_board = board[:]
            temp_board[i] = player
            winning_moves = 0
            for j in range(9):
                if temp_board[j] == '-':
                    temp_board_2 = temp_board[:]
                    temp_board_2[j] = player
                    if find_winning_or_blocking_move(temp_board_2, player) is not None:
                        winning_moves += 1
            if winning_moves >= 2:
                fork_move = i
                break
    return fork_move

# Helper function to find the optimal move according to a defined strategy
def optimal_strategy(board, player):
    opponent = 'O' if player == 'X' else 'X'

    # 1. Win: If you have two in a row, play the third to get three in a row.
    win_move = find_winning_or_blocking_move(board, player)
    if win_move is not None:
        return win_move

    # 2. Block: If the opponent has two in a row, play the third to block them.
    block_move = find_winning_or_blocking_move(board, opponent)
    if block_move is not None:
        return block_move

    # 3. Fork: Create an opportunity where you can win in two ways.
    fork_move = find_fork_move(board, player)
    if fork_move is not None:
        return fork_move

    # 4. Block Opponent's Fork
    opponent_fork_move = find_fork_move(board, opponent)
    if opponent_fork_move is not None:
        return opponent_fork_move

    # 5. Center: Play the center.
    if board[4] == '-':
        return 4

    # 6. Opposite Corner: If the opponent is in a corner, play the opposite corner.
    corners = [(0, 8), (2, 6), (8, 0), (6, 2)]
    for corner1, corner2 in corners:
        if board[corner1] == opponent and board[corner2] == '-':
            return corner2

    # 7. Empty Corner: Play an empty corner.
    for corner in [0, 2, 6, 8]:
        if board[corner] == '-':
            return corner

    # 8. Empty Side: Play an empty side.
    for side in [1, 3, 5, 7]:
        if board[side] == '-':
            return side

# Function to simulate all games according to the optimal strategy
def simulate_all_games_optimal_v2(game, x_starts=True):
    global input_sequences, output_sequences

    # Check for terminal state
    winner = game.check_winner()
    if winner or game.is_draw():
        return

    # If it's X's turn, apply the optimal strategy and save the board state and move
    if game.current_player == 'X':
        move = optimal_strategy(game.board, 'X')
        if move is None:
            move = game.get_legal_moves()[0]  # fallback
        input_sequences.append(encode(''.join(game.board)))
        output_sequences.append(move)
        new_game = deepcopy(game)
        new_game.make_move(move)
        simulate_all_games_optimal_v2(new_game, x_starts)
    else:
        # If it's O's turn, explore all possible legal moves
        for move in game.get_legal_moves():
            new_game = deepcopy(game)
            new_game.make_move(move)
            simulate_all_games_optimal_v2(new_game, x_starts)

# Character to integer mapping
chars = sorted(list(set('XO-')))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}  # inverse of stoi, so decode(encode(s)) == s
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# Reset and re-simulate
input_sequences = []
output_sequences = []

# 'X' starts
initial_game = TicTacToe()
simulate_all_games_optimal_v2(initial_game, True)

# 'O' starts
initial_game = TicTacToe()
initial_game.current_player = 'O'
simulate_all_games_optimal_v2(initial_game, False)

# Convert to PyTorch tensors
input_tensor = torch.tensor(input_sequences, dtype=torch.long)
output_tensor = torch.tensor(output_sequences, dtype=torch.long)

print("Number of input-output pairs:", len(input_sequences))

Number of input-output pairs: 1017
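For comparison, here is a stricter reading of rule 3, sketched under my own assumptions; `count_threats` and `find_strict_fork_move` are hypothetical names, and this is not the helper that generated the 1017 pairs above. It places a mark on each empty square and counts how many lines then hold two of the player's marks plus an empty third cell; two or more such threats make a fork.

```python
# Every line of three cells that wins the game, as board indices 0-8
LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),
         (0, 3, 6), (1, 4, 7), (2, 5, 8),
         (0, 4, 8), (2, 4, 6)]

def count_threats(board, player):
    """Number of lines where `player` has two marks and the third cell is empty."""
    threats = 0
    for line in LINES:
        values = [board[p] for p in line]
        if values.count(player) == 2 and values.count('-') == 1:
            threats += 1
    return threats

def find_strict_fork_move(board, player):
    """First empty square that gives `player` two or more threats at once, else None."""
    for i in range(9):
        if board[i] == '-':
            temp = list(board)
            temp[i] = player
            if count_threats(temp, player) >= 2:
                return i
    return None

print(find_strict_fork_move('-' * 9, 'X'))      # None: one mark cannot make two threats
print(find_strict_fork_move('X---O---X', 'X'))  # 2: threatens the top row and right column
print(count_threats('XO--X-XOO', 'X'))          # 2: the position just before X's winning move
```

On an empty board the strict version finds no fork, whereas `find_fork_move` as written returns square 0, so the two versions would produce somewhat different training data.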

Let's retrain our model on our new training data.

# Hyperparameters
batch_size = 128  # How many independent sequences will we process in parallel?
block_size = 9  # The size of the tic-tac-toe board
max_iters = 10000
eval_interval = 500
learning_rate = 1e-3
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
eval_iters = 100
n_embd = 32  # Reduced the embedding size
n_head = 2  # Reduced the number of heads
n_layer = 2  # Reduced the number of layers
dropout = 0.1

print(f'Training on {device}')

# Initialize random seed
torch.manual_seed(1337)

# Split into training and validation sets
n = int(0.90 * len(input_tensor))  # 90% for training
train_input = input_tensor[:n]
train_output = output_tensor[:n]
val_input = input_tensor[n:]
val_output = output_tensor[n:]

# Updated data loading function
def get_batch(split):
    input_data = train_input if split == 'train' else val_input
    output_data = train_output if split == 'train' else val_output
    # Sample batch_size random indices; each row is an independent (board, move) pair,
    # so every index from 0 to len(input_data) - 1 is a valid choice
    ix = torch.randint(len(input_data), (batch_size,))
    # Get the input and output sequences
    x = input_data[ix]
    y = output_data[ix]
    x, y = x.to(device), y.to(device)
    return x, y

# Initialize the model
model = Transformer()
model = model.to(device)

max_iters = 5000

# Print the number of parameters in the model
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')

# Create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
train_loss_history = []
val_loss_history = []

# Training loop
for iter in tqdm(range(max_iters)):
    # Evaluate the loss on train and val sets occasionally
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        val_loss_history.append(losses['val'])

    # Sample a batch of data
    xb, yb = get_batch('train')

    # Evaluate the loss
    logits, loss = model(xb, yb)
    train_loss_history.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Training on mps
0.025961 M parameters
  0%|          | 6/5000 [00:00<10:15,  8.11it/s]  
step 0: train loss 2.2078, val loss 2.2166
 10%|█         | 505/5000 [00:13<05:23, 13.89it/s]
step 500: train loss 0.3063, val loss 0.6145
 20%|██        | 1005/5000 [00:26<04:40, 14.24it/s]
step 1000: train loss 0.0741, val loss 0.2259
 30%|███       | 1505/5000 [00:38<04:08, 14.05it/s]
step 1500: train loss 0.0368, val loss 0.1799
 40%|████      | 2005/5000 [00:51<03:36, 13.83it/s]
step 2000: train loss 0.0134, val loss 0.1589
 50%|█████     | 2504/5000 [01:04<02:57, 14.06it/s]
step 2500: train loss 0.0081, val loss 0.0884
 60%|██████    | 3008/5000 [01:17<01:56, 17.06it/s]
step 3000: train loss 0.0041, val loss 0.0521
 70%|███████   | 3505/5000 [01:29<01:46, 14.09it/s]
step 3500: train loss 0.0028, val loss 0.0855
 80%|████████  | 4005/5000 [01:42<01:10, 14.06it/s]
step 4000: train loss 0.0036, val loss 0.1125
 90%|█████████ | 4506/5000 [01:56<00:29, 16.68it/s]
step 4500: train loss 0.0014, val loss 0.0892
100%|██████████| 5000/5000 [02:08<00:00, 38.79it/s]
step 4999: train loss 0.0026, val loss 0.0721
New loss with the optimal training data (image by author).

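So much better! Not only has the model learned the strategy, but it also generalises to the validation set: by the end of training the train loss is about 0.003 and the val loss about 0.07. My guess is that this is due to the inherent symmetries of the board: many positions are rotations or reflections of one another, and the transformer may have learned something like modular arithmetic over the indices of the board string (a cell's column is its index mod 3) that makes it invariant to these transformations.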
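To make the symmetry argument concrete, the eight symmetries of the square (four rotations, each optionally mirrored) can be applied to a board string as index permutations. This is a sketch with hypothetical helper names; it is not part of the training pipeline, and it says nothing about what the transformer actually learned internally.

```python
# new_board[k] = board[perm[k]] for each permutation below
ROTATE_90 = [6, 3, 0, 7, 4, 1, 8, 5, 2]  # quarter turn clockwise
REFLECT_H = [2, 1, 0, 5, 4, 3, 8, 7, 6]  # mirror left <-> right

def apply_perm(board, perm):
    return ''.join(board[p] for p in perm)

def symmetries(board):
    """The 8 images of a 9-character board under rotations and reflections."""
    images = []
    current = board
    for _ in range(4):
        images.append(current)
        images.append(apply_perm(current, REFLECT_H))
        current = apply_perm(current, ROTATE_90)
    return images

def canonical(board):
    """A representative shared by all boards that are symmetric to each other."""
    return min(symmetries(board))

print(len(set(symmetries('X--------'))))  # 4: a lone X can sit in any of the 4 corners
print(canonical('X--------') == canonical('--------X'))  # True
```

Mapping every board to `canonical(board)` is one way to check how many validation positions are simply rotated or mirrored copies of training positions.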

Let's try playing against the new transformer:

O's turn.
- | - | -
- - - - - 
- | - | -
- - - - - 
- | - | -
Your turn!
Legal moves: [0, 1, 2, 3, 4, 5, 6, 7, 8]
X's turn.
- | O | -
- - - - - 
- | - | -
- - - - - 
- | - | -
Model's turn...
O's turn.
X | O | -
- - - - - 
- | - | -
- - - - - 
- | - | -
Your turn!
Legal moves: [2, 3, 4, 5, 6, 7, 8]
X's turn.
X | O | -
- - - - - 
- | - | -
- - - - - 
- | O | -
Model's turn...
O's turn.
X | O | -
- - - - - 
- | X | -
- - - - - 
- | O | -
Your turn!
Legal moves: [2, 3, 5, 6, 8]
X's turn.
X | O | -
- - - - - 
- | X | -
- - - - - 
- | O | O
Model's turn...
O's turn.
X | O | -
- - - - - 
- | X | -
- - - - - 
X | O | O
Your turn!
Legal moves: [2, 3, 5]
X's turn.
X | O | O
- - - - - 
- | X | -
- - - - - 
X | O | O
Model's turn...
X wins!
X | O | O
- - - - - 
X | X | -
- - - - - 
X | O | O

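It beat me, using a fantastic fork: with X on squares 0, 4 and 6, it threatened both the left column and the anti-diagonal, and I could only block one of them. It seems our transformer has learned the optimal strategy, or at least the rule-based version of it that it was trained on.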

Conclusion

I think the main takeaway here is that transformers are more than capable of learning optimal strategies for games. Whilst a vanilla neural net could likely have learned the same optimal strategy, the dynamic nature of attention means that a transformer might be able to attend to longer sequences representing the game over time. These ideas naturally lead to applying transformers in reinforcement learning settings. For instance, Janner et al. (2021) used a transformer to model distributions over trajectories, with beam search as the planning algorithm.

Another thing I learned from this project is that hand-holding the transformer through the optimal strategy obviously doesn't scale, particularly as games get more complex. For instance, Go isn't a "solved" game, so we couldn't do what we did above and provide the optimal strategy as training data. Instead, we'd have to use something like self-play to select good sequences to then train the transformer on. I hope to experiment with ideas like this in future.

Finally, hierarchies of prediction and planning in transformers are ripe for exploration. As Ba et al. (2016) noted, deep learning has typically focused on methods that maintain temporary state in activation dynamics, whereas our brain seems to regulate temporary state information via medium-term synaptic plasticity. In other words, there should be some form of working memory or prediction that sits between the next-token level and long-term states, such as those in LSTMs. The autoregressive decoder-only transformer architecture that forms the backbone of GPT-3 is a powerful model that can generate realistic-looking text simply by predicting one token into the future. However, if we anthropomorphise intelligence, we know that quick intuitive look-ahead (by one token) does not a genius make. In my mind, it would thus be interesting to provide a model with multiple hierarchies of prediction, where it can learn to predict multiple tokens into the future. This would be akin to the human ability to plan ahead, and would likely be a useful skill for transformers to learn.

What does it mean to look ahead? There are a few different avenues to explore here:

  • Look ahead in time: Can the transformer learn to predict not the next token in the sequence but the token two, three, or more steps ahead? Is predicting two tokens at once just equivalent to predicting one token, then predicting the next token, or is there some strategic benefit to predicting two tokens at a time? Does it force the transformer to think for longer?
  • Look ahead in space: There is evidence that humans produce coarse-grained semantic representations of tasks they are performing, and then use hierarchical modules to "fill in the gaps" in these coarse-grained representations. You might think of this as writing an essay: first you create a skeleton of points, then you write the topic sentence for each paragraph, and finally you flesh out the details. Is it possible that transformers can learn to do the same thing?
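The look-ahead-in-time question can at least be set up with nothing more exotic than a change of training targets. A minimal sketch, assuming GPT-style (input, target) windows like those in Karpathy's video: shifting the targets by k positions instead of 1 turns next-token prediction into k-step-ahead prediction. `k_ahead_pairs` is a hypothetical helper, and it only builds the data; whether a model trained this way plans any better is exactly the open question.

```python
def k_ahead_pairs(tokens, block_size, k):
    """(input window, target window) pairs where each target is k tokens ahead."""
    pairs = []
    for start in range(len(tokens) - block_size - k + 1):
        x = tokens[start:start + block_size]
        y = tokens[start + k:start + k + block_size]
        pairs.append((x, y))
    return pairs

# k = 1 is ordinary next-token prediction; k = 2 asks for the token after that
print(k_ahead_pairs(list(range(6)), block_size=2, k=1)[0])  # ([0, 1], [1, 2])
print(k_ahead_pairs(list(range(6)), block_size=2, k=2)[0])  # ([0, 1], [2, 3])
```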

Regardless of whether these are useful questions or not, I hope this post gave some clarity and insight into exactly how we can reshape a problem to be suitable for a transformer. Happy fine-tuning!

References

  1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention is all you need. Advances in Neural Information Processing Systems, 30.
  2. Janner, M., Li, Q., & Levine, S. (2021). Offline reinforcement learning as one big sequence modeling problem. Advances in Neural Information Processing Systems, 34, 1273–1286.
  3. Ba, J., Hinton, G. E., Mnih, V., Leibo, J. Z., & Ionescu, C. (2016). Using fast weights to attend to the recent past. Advances in Neural Information Processing Systems, 29.
  4. Karpathy, A. Let's build GPT: from scratch, in code, spelled out [Video]. YouTube. https://www.youtube.com/watch?v=kCc8FmEb1nY&t=5076s
