Can Transformers Learn to Strategize?
Although most board games suggest convolutional neural nets or other geometrically inspired architectures, the fact that we can represent board states as strings raises the question of whether transformers can be applied to such games directly. Here, we'll see if we can answer this in the context of a simple game of tic-tac-toe. Whilst this may not seem very practical (almost everyone knows there is a simple, closed-form Nash equilibrium strategy for this game), it is a useful testbed for our questions: the game is simple enough that we can easily train a transformer on it, but rich enough that it is not immediately obvious whether the transformer will actually pick up a good strategy.
Implementing the game
We'll start by implementing a TicTacToe class. This is fairly straightforward: we represent the board as a string of 9 characters, one per square, using X for the first player, O for the second player, and - for an empty square. We'll also keep track of whose turn it is to play next, whether the game is over, and the winner, if there is one. Finally, we'll include a method to pretty-print the board so we don't have to stare at raw strings when debugging.
class TicTacToe:
def __init__(self):
# Initialise an empty board
self.board = ['-' for _ in range(9)]
self.current_player = 'X' # X will start
def make_move(self, position):
"""Make a move on the board."""
if self.board[position] == '-':
self.board[position] = self.current_player
self.switch_player()
return True
else: return False # illegal move
def switch_player(self):
"""Switch the current player."""
self.current_player = 'O' if self.current_player == 'X' else 'X'
def check_winner(self):
"""Check if there is a winner."""
# Rows, columns, diagonals
winning_positions = [
[0, 1, 2], [3, 4, 5], [6, 7, 8], # Rows
[0, 3, 6], [1, 4, 7], [2, 5, 8], # Columns
[0, 4, 8], [2, 4, 6] # Diagonals
]
for positions in winning_positions:
values = [self.board[pos] for pos in positions]
if values[0] == values[1] == values[2] and values[0] != '-':
return values[0]
return None # No winner yet
def is_draw(self):
"""Check if the game is a draw."""
return all(cell != '-' for cell in self.board)
def get_board_string(self):
"""Get the current board state as a string."""
return ''.join(self.board)
def get_legal_moves(self):
"""Get the positions of all legal moves."""
return [i for i, cell in enumerate(self.board) if cell == '-']
def pretty_print_board(self):
"""Pretty-print the board."""
for i in range(0, 9, 3):
print(f"{self.board[i]} | {self.board[i+1]} | {self.board[i+2]}")
if i < 6:
print("- "*5)
# Test the pretty_print_board method
tic_tac_toe = TicTacToe()
print("Initial board:")
tic_tac_toe.pretty_print_board()
# Make some moves
tic_tac_toe.make_move(0)
tic_tac_toe.make_move(4)
tic_tac_toe.make_move(8)
print("nBoard after some moves:")
tic_tac_toe.pretty_print_board()
Initial board:
- | - | -
- - - - -
- | - | -
- - - - -
- | - | -
Board after some moves:
X | - | -
- - - - -
- | O | -
- - - - -
- | - | X
Creating our training data
We want our transformer to take in a given board state and output a move, that move being an integer from 0 to 8 representing the square it wants to place its piece in. To do so, we'll create a dataset of board states and moves. We'll do this by simulating all the possible winning positions of our player, and then iterating over all the combinations of games that could have got us there. This means the transformer will learn what constitutes a good move in any given board state.
To implement this, the simulate_all_games function generates both the training and validation data. Specifically, the function simulates all possible tic-tac-toe games, exploring every permutation of moves for both players ('X' and 'O'). This exhaustive simulation ensures that the model is trained on a comprehensive dataset, encompassing all conceivable game scenarios. During each simulated game, the function records not just the winning or drawing outcomes but also the sequence of board states and the moves that led to those outcomes. These board states and moves are then transformed into numerical representations suitable for training our transformer. This ensures that the model learns not just to win, but to output an appropriate winning move from any given board state. You can think of this as akin to the ability of a regular language transformer to output an appropriate token given any length of context, from just one token (i.e. our starting board state) to the EOS token (i.e. our winning move yielding the final board state).
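To make the data format concrete, here is what a single (board state, move) training pair looks like under the character-to-integer mapping defined below. The particular values are purely illustrative, not drawn from the generated dataset:
# One illustrative training pair (hypothetical, for exposition only):
# O occupies the top-left corner, X occupies the centre,
# and the label is the square X plays next.
board_str = "O---X----"                      # board state as a 9-character string
encoded_board = [2, 0, 0, 0, 1, 0, 0, 0, 0]  # '-' -> 0, 'X' -> 1, 'O' -> 2
target_move = 8                              # index of the square where X places its piece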
from copy import deepcopy
from itertools import product
import numpy as np
import torch
# Define character to integer mapping
stoi = {'-': 0, 'X': 1, 'O': 2}
itos = {i: ch for ch, i in stoi.items()}  # inverse of stoi, kept consistent
vocab_size = len(stoi)
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
input_sequences = []
output_sequences = []
# Function to simulate all possible games recursively
def simulate_all_games(game, x_moves, o_moves):
global input_sequences, output_sequences
# Check if the game has reached a terminal state
winner = game.check_winner()
if winner == 'X' or game.is_draw():
# Add the sequence of board states and moves leading to this win for 'X' or draw
board = ['-' for _ in range(9)]
for i, x_move in enumerate(x_moves):
input_sequences.append(encode(''.join(board)))
output_sequences.append(x_move)
board[x_move] = 'X'
if i < len(o_moves):
board[o_moves[i]] = 'O'
return
elif winner == 'O':
return # We don't add these to our training data
# Otherwise, continue simulating the game
legal_moves = game.get_legal_moves()
for move in legal_moves:
# Create a copy of the game to simulate the move
new_game = deepcopy(game)
was_legal = new_game.make_move(move)
# If the move was legal, continue simulating
if was_legal:
            if game.current_player == 'X':  # game still holds the player who just moved
simulate_all_games(new_game, x_moves + [move], o_moves)
else:
simulate_all_games(new_game, x_moves, o_moves + [move])
# Create an initial empty game
initial_game = TicTacToe()
# Simulate all possible games starting with 'X'
simulate_all_games(initial_game, [], [])
# Convert to PyTorch tensors
input_tensor = torch.tensor(input_sequences, dtype=torch.long)
output_tensor = torch.tensor(output_sequences, dtype=torch.long)
# Show some sample input-output pairs
print(input_tensor[:10], output_tensor[:10])
print("Number of input-output pairs:", len(input_sequences))
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 2, 0, 0, 0, 0, 0, 0, 0],
[1, 2, 1, 2, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 2, 0, 0, 0, 0, 0, 0, 0],
[1, 2, 1, 2, 0, 0, 0, 0, 0],
[1, 2, 1, 2, 1, 2, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 2, 0, 0, 0, 0, 0, 0, 0],
[1, 2, 1, 2, 0, 0, 0, 0, 0]])
tensor([1, 3, 5, 1, 3, 5, 6, 1, 3, 5])
Number of input-output pairs: 658224
This gives us about 650,000 examples to train on. These tensors look roughly correct, but it's hard to tell without seeing the board visually. Let's recycle our pretty-printing logic to see what some random board states, and the next board state given the move, look like:
def pretty_print_board(board: str):
"""Pretty-print the board."""
for i in range(0, 9, 3):
print(f"{board[i]} | {board[i+1]} | {board[i+2]}")
if i < 6:
print("- "*5)
rand_idx = torch.randint(len(input_tensor), (1,))[0]
random_game = input_tensor[rand_idx].tolist()
print("Current game state:")
decoded_game = decode(random_game)
pretty_print_board(decoded_game)
print( )
move = output_tensor[rand_idx].item()
decoded_game = decoded_game[:move] + 'X' + decoded_game[move+1:]
print("New game state:")
pretty_print_board(decoded_game)
Current game state:
X | - | -
- - - - -
- | X | O
- - - - -
- | O | -
New game state:
X | - | X
- - - - -
- | X | O
- - - - -
- | O | -
This seems plausible, but I've noticed that some games have a winning move available and yet the simulation makes a different move (that still ends up in a win); this occurred in the example above. Let's change the simulate_all_games function so that, whenever a winning move is available, it plays that move and stops searching further down that branch.
input_sequences = []
output_sequences = []
def simulate_all_games(game, x_moves, o_moves):
global input_sequences, output_sequences
# Check if the game has reached a terminal state
winner = game.check_winner()
if winner == 'X' or game.is_draw():
# Add the sequence of board states and moves leading to this win for 'X' or draw
board = ['-' for _ in range(9)]
for i, x_move in enumerate(x_moves):
input_sequences.append(encode(''.join(board)))
output_sequences.append(x_move)
board[x_move] = 'X'
if i < len(o_moves):
board[o_moves[i]] = 'O'
return
elif winner == 'O':
return # We don't add these to our training data
# Before simulating further moves, check if a winning move is available
legal_moves = game.get_legal_moves()
for move in legal_moves:
test_game = deepcopy(game)
test_game.make_move(move)
if test_game.check_winner() == game.current_player:
# This move is a winning move, so we make it and end further simulation
            if game.current_player == 'X':  # attribute the move to the player who made it
simulate_all_games(test_game, x_moves + [move], o_moves)
else:
simulate_all_games(test_game, x_moves, o_moves + [move])
return # End further exploration for this branch
# If no immediate winning move is found, continue simulating the game
for move in legal_moves:
# Create a copy of the game to simulate the move
new_game = deepcopy(game)
was_legal = new_game.make_move(move)
# If the move was legal, continue simulating
if was_legal:
            if game.current_player == 'X':  # game still holds the player who just moved
simulate_all_games(new_game, x_moves + [move], o_moves)
else:
simulate_all_games(new_game, x_moves, o_moves + [move])
# Create an initial empty game
initial_game = TicTacToe()
# Simulate all possible games starting with 'X'
simulate_all_games(initial_game, [], [])
# Convert to PyTorch tensors
input_tensor = torch.tensor(input_sequences, dtype=torch.long)
output_tensor = torch.tensor(output_sequences, dtype=torch.long)
Finally, let's see how many moves we have to train on:
print("Number of input-output pairs:", len(input_sequences))
Number of input-output pairs: 147104
About 150,000 examples. This seems like a reasonable start.
Transformer architecture with multi-head attention
Attention is a mechanism whereby the model learns to focus on certain parts of the input sequence when making predictions. The transformer architecture uses multi-head self-attention, meaning the model learns to attend to different parts of the input sequence in different ways. This is useful because it lets the model capture several kinds of relationship at once: one head might focus on the first token of the context when making a prediction, while another focuses on the most recent token. Together, the heads allow the model to learn complex relationships between the input sequence and its predictions.
But how does this actually work? From the original Attention is all you need paper, attention over a query matrix Q, a key matrix K, and a value matrix V is defined as:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
where we divide by $\sqrt{d_k}$ to keep the variance of the scores feeding into the softmax at an appropriate scale. Let's break down what's actually happening here. Suppose we have an input of dimension (B, T, C), where B is the batch size, T is the sequence length, and C is the number of channels. We can think of this as a batch of B sequences of length T, each with C channels:
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)
Then, to implement a single head of self-attention, we need to create the query, the key, and the value. In practice, these are linear layers of a certain head_size, which is just how wide we want each projection to be; we don't include bias terms in these projections.
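The score computation below assumes the query, key, and value projections (and the projected tensors q and k) already exist. A minimal setup, continuing from the random tensor x above and assuming a head_size of 16, might look like this:
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)   # (B, T, 16)
q = query(x) # (B, T, 16)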
wei = q @ k.transpose(-2,-1) # (B, T, 16) @ (B, 16, T) -> (B, T, T)
However, remember what we're actually doing: predicting the next token in a sequence. Since we won't have access to future tokens during generation, we mustn't let the model attend to them during training either. We do this by creating a mask of shape (T,T), where T is the sequence length, and setting all values in the upper triangle to negative infinity. The softmax then assigns zero weight to those positions, so the model cannot attend to them.
Finally, we multiply the attention weights by the value matrix to get the output of the attention layer. This is the output of a single head of self-attention. We can then repeat this process for as many heads as we want, and then concatenate the outputs of each head to get the final output of the multi-head self-attention layer.
tril = torch.tril(torch.ones((T,T)))
wei = wei.masked_fill(tril==0, float("-inf"))
wei = F.softmax(wei, dim=-1)
v = value(x) # (B,T,16)
out = wei @ v
out.shape
And that's a single head of self-attention. To build multi-head attention, we run several such heads in parallel and concatenate their outputs; we also add residual connections to make this relatively deep model easier to optimise. For a full walkthrough of similar code and the mechanisms behind decoder-only transformers, I highly recommend Andrej Karpathy's nanoGPT lecture.
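For reference, the original paper writes the multi-head combination as:
$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^O, \qquad \mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
where each head has its own learned projections and $W^O$ mixes the concatenated head outputs back to the model dimension.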
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
# Hyperparameters
batch_size = 128 # How many independent sequences will we process in parallel?
block_size = 9 # The size of the tic-tac-toe board
max_iters = 10000
eval_interval = 500
learning_rate = 1e-3
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
eval_iters = 100
n_embd = 32 # Reduced the embedding size
n_head = 2 # Reduced the number of heads
n_layer = 2 # Reduced the number of layers
dropout = 0.1
print(f'Training on {device}')
# Initialize random seed
torch.manual_seed(1337)
# Split into training and validation sets
n = int(0.90 * len(input_tensor)) # 90% for training
train_input = input_tensor[:n]
train_output = output_tensor[:n]
val_input = input_tensor[n:]
val_output = output_tensor[n:]
# Updated data loading function
def get_batch(split):
input_data = train_input if split == 'train' else val_input
output_data = train_output if split == 'train' else val_output
    # Choose batch_size random example indices
    ix = torch.randint(len(input_data), (batch_size,))
# Get the input and output sequences
x = input_data[ix]
y = output_data[ix]
x, y = x.to(device), y.to(device)
return x, y
@torch.no_grad()
def estimate_loss():
out = {}
model.eval()
for split in ['train', 'val']:
losses = torch.zeros(eval_iters)
for k in range(eval_iters):
X, Y = get_batch(split)
logits, loss = model(X, Y)
losses[k] = loss.item()
out[split] = losses.mean()
model.train()
return out
class Head(nn.Module):
""" one head of self-attention """
def __init__(self, head_size):
super().__init__()
self.key = nn.Linear(n_embd, head_size, bias=False)
self.query = nn.Linear(n_embd, head_size, bias=False)
self.value = nn.Linear(n_embd, head_size, bias=False)
self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# input of size (batch, time-step, channels)
# output of size (batch, time-step, head size)
B,T,C = x.shape
k = self.key(x) # (B,T,hs)
q = self.query(x) # (B,T,hs)
# compute attention scores ("affinities")
wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
wei = F.softmax(wei, dim=-1) # (B, T, T)
wei = self.dropout(wei)
# perform the weighted aggregation of the values
v = self.value(x) # (B,T,hs)
out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
return out
class MultiHeadAttention(nn.Module):
""" multiple heads of self-attention in parallel """
def __init__(self, num_heads, head_size):
super().__init__()
self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
self.proj = nn.Linear(head_size * num_heads, n_embd)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
out = torch.cat([h(x) for h in self.heads], dim=-1)
out = self.dropout(self.proj(out))
return out
class FeedForward(nn.Module):
""" a simple linear layer followed by a non-linearity """
def __init__(self, n_embd):
super().__init__()
self.net = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd),
nn.ReLU(),
nn.Linear(4 * n_embd, n_embd),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class Block(nn.Module):
""" Transformer block: communication followed by computation """
def __init__(self, n_embd, n_head):
# n_embd: embedding dimension, n_head: the number of heads we'd like
super().__init__()
head_size = n_embd // n_head
self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
self.ln1 = nn.LayerNorm(n_embd)
self.ln2 = nn.LayerNorm(n_embd)
def forward(self, x):
x = x + self.sa(self.ln1(x))
x = x + self.ffwd(self.ln2(x))
return x
class Transformer(nn.Module):
def __init__(self):
super().__init__()
        # token embeddings: map each board character to a learned vector
self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
self.position_embedding_table = nn.Embedding(block_size, n_embd)
self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
self.ln_f = nn.LayerNorm(n_embd) # final layer norm
self.lm_head = nn.Linear(n_embd, 9)
        # a better weight initialisation (following nanoGPT); a small but important detail
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(self, idx, targets=None):
B, T = idx.shape
# idx and targets are both (B,T) tensor of integers
tok_emb = self.token_embedding_table(idx) # (B,T,C)
pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
x = tok_emb + pos_emb # (B,T,C)
x = self.blocks(x) # (B,T,C)
x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,9)
# Take the logits corresponding to the last time step T
logits = logits[:, -1, :] # Now logits is (B, 9)
if targets is None:
loss = None
else:
loss = F.cross_entropy(logits, targets)
return logits, loss
def generate(self, idx, max_new_tokens):
# idx is (B, T) array of indices in the current context
for _ in range(max_new_tokens):
# crop idx to the last block_size tokens
idx_cond = idx[:, -block_size:]
# get the predictions
logits, loss = self(idx_cond)
            # forward() already returns the logits for the last position, shape (B, 9)
# apply softmax to get probabilities
probs = F.softmax(logits, dim=-1) # (B, C)
# sample from the distribution
idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
# append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
return idx
To make sure our architecture works as intended, let's try passing a single batch through.
xb, yb = get_batch('train')
print(xb.shape, yb.shape)
m = Transformer().to(device)
logits, loss = m(xb, yb)
print(logits.shape)
print(f"Loss: {loss.item():.3f}")
torch.Size([128, 9]) torch.Size([128])
torch.Size([128, 9])
Loss: 2.203
A good check after this initial forward pass is whether the loss is roughly what we'd expect for random predictions. Since we have 9-dimensional logits and we're using cross-entropy loss, which is the negative log-likelihood of the correct class, we'd expect an initial loss of about:
$$-\ln\left(\frac{1}{9}\right) = \ln 9 \approx 2.197$$
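We can sanity-check that number with a one-liner (just an illustration of the arithmetic, not part of the training code):
import math
# Uniform guessing over the 9 possible moves gives a cross-entropy of -ln(1/9)
print(-math.log(1 / 9))  # ~2.197, close to the initial loss of 2.203 above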
Training the model
Using quite a small transformer (about 26,000 parameters), we achieve the following loss (note that I'm using a small amount of weight decay and dropout):
model = Transformer()
model = model.to(device)
# Print the number of parameters in the model
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
train_loss_history = []
val_loss_history = []
# Training loop
for iter in tqdm(range(max_iters)):
# Evaluate the loss on train and val sets occasionally
if iter % eval_interval == 0 or iter == max_iters - 1:
losses = estimate_loss()
print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
val_loss_history.append(losses['val'])
# Sample a batch of data
xb, yb = get_batch('train')
# Evaluate the loss
logits, loss = model(xb, yb)
train_loss_history.append(loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
0.025961 M parameters
0%| | 5/10000 [00:00<24:42, 6.74it/s]
step 0: train loss 2.2033, val loss 2.2106
5%|▌ | 504/10000 [00:14<12:01, 13.15it/s]
step 500: train loss 1.9162, val loss 2.0215
10%|█ | 1008/10000 [00:27<08:27, 17.73it/s]
step 1000: train loss 1.7846, val loss 1.8570
15%|█▌ | 1505/10000 [00:40<10:34, 13.39it/s]
step 1500: train loss 1.7370, val loss 1.7648
20%|██ | 2007/10000 [00:53<07:35, 17.55it/s]
step 2000: train loss 1.7188, val loss 1.7770
25%|██▌ | 2506/10000 [01:05<07:11, 17.36it/s]
step 2500: train loss 1.6957, val loss 1.7456
30%|███ | 3006/10000 [01:18<06:35, 17.69it/s]
step 3000: train loss 1.6965, val loss 1.7448
35%|███▌ | 3506/10000 [01:31<06:12, 17.41it/s]
step 3500: train loss 1.6961, val loss 1.7809
40%|████ | 4005/10000 [01:43<07:41, 12.98it/s]
step 4000: train loss 1.6819, val loss 1.7256
45%|████▌ | 4506/10000 [01:56<05:18, 17.24it/s]
step 4500: train loss 1.6892, val loss 1.7066
50%|█████ | 5005/10000 [02:09<05:14, 15.88it/s]
step 5000: train loss 1.6846, val loss 1.7141
55%|█████▌ | 5508/10000 [02:23<04:37, 16.19it/s]
step 5500: train loss 1.6835, val loss 1.6998
60%|██████ | 6004/10000 [02:36<05:19, 12.51it/s]
step 6000: train loss 1.6828, val loss 1.7095
65%|██████▌ | 6506/10000 [02:49<03:23, 17.13it/s]
step 6500: train loss 1.6722, val loss 1.7151
70%|███████ | 7008/10000 [03:02<03:05, 16.17it/s]
step 7000: train loss 1.6656, val loss 1.7158
75%|███████▌ | 7505/10000 [03:15<02:30, 16.54it/s]
step 7500: train loss 1.6672, val loss 1.7078
80%|████████ | 8007/10000 [03:28<02:01, 16.38it/s]
step 8000: train loss 1.6808, val loss 1.7120
85%|████████▌ | 8505/10000 [03:41<01:47, 13.94it/s]
step 8500: train loss 1.6733, val loss 1.7144
90%|█████████ | 9007/10000 [03:54<00:56, 17.54it/s]
step 9000: train loss 1.6714, val loss 1.7031
95%|█████████▌| 9506/10000 [04:07<00:28, 17.39it/s]
step 9500: train loss 1.6707, val loss 1.7073
100%|██████████| 10000/10000 [04:20<00:00, 38.43it/s]
step 9999: train loss 1.6664, val loss 1.7506
That doesn't seem very good. Let's plot it to see what's going on.
import matplotlib.pyplot as plt
def plot_transformer_loss(loss_history, val_loss_history):
# Two horizontal figures side-by-side
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
# First plot = training loss
ax1.plot(loss_history, lw=0.5)
# Plot moving average of loss
window_size = 100
ax1.plot(np.convolve(loss_history, np.ones(window_size) / window_size, mode='valid'), label='Moving average')
ax1.set_xlabel('Iteration')
ax1.set_ylabel('Cross-entropy Loss')
ax1.set_title('Training Loss')
ax1.legend()
# Second plot = validation loss
# Set marker style to be circles at each data point
indices = np.arange(0, len(val_loss_history) * eval_interval, eval_interval)
ax2.plot(indices, val_loss_history, marker='o')
ax2.set_title('Validation Loss')
ax2.set_xlabel('Iteration')
plt.show()
plot_transformer_loss(train_loss_history, val_loss_history)

We can play the transformer using this code to see if it's any good:
import random
import torch
from IPython.display import clear_output
def play_game(model, stoi, itos, device):
game = TicTacToe()
# Randomly decide who goes first
game.current_player = random.choice(['X', 'O'])
while game.check_winner() is None and not game.is_draw():
#clear_output(wait=True)
print(f"{game.current_player}'s turn.")
game.pretty_print_board()
current_board_str = game.get_board_string()
if game.current_player == 'X':
print("Model's turn...")
current_board_encoded = torch.tensor([stoi[c] for c in current_board_str], dtype=torch.long).unsqueeze(0).to(device)
logits, _ = model(current_board_encoded)
# Move logits to cpu
logits = logits.cpu()
            # Mask out illegal moves so the argmax can only select a legal square
            legal_moves = game.get_legal_moves()
            mask = torch.zeros(9, dtype=torch.bool)
            mask[legal_moves] = True
            masked_logits = logits.masked_fill(~mask, float('-inf'))
# Get the model's move
predicted_move = masked_logits.argmax(dim=-1).item()
# Make the model's move
game.make_move(predicted_move)
else:
print("Your turn!")
legal_moves = game.get_legal_moves()
print("Legal moves:", legal_moves)
user_move = int(input("Enter your move: "))
if user_move in legal_moves:
game.make_move(user_move)
else:
print("Illegal move. Try again.")
continue
winner = game.check_winner()
if winner is not None:
#clear_output(wait=True)
print(f"{winner} wins!")
game.pretty_print_board()
break
elif game.is_draw():
#clear_output(wait=True)
print("It's a draw!")
game.pretty_print_board()
break
O's turn.
- | - | -
- - - - -
- | - | -
- - - - -
- | - | -
Your turn!
Legal moves: [0, 1, 2, 3, 4, 5, 6, 7, 8]
X's turn.
- | - | -
- - - - -
- | - | -
- - - - -
- | - | O
Model's turn...
O's turn.
- | - | X
- - - - -
- | - | -
- - - - -
- | - | O
Your turn!
Legal moves: [0, 1, 3, 4, 5, 6, 7]
X's turn.
O | - | X
- - - - -
- | - | -
- - - - -
- | - | O
Model's turn...
O's turn.
O | - | X
- - - - -
X | - | -
- - - - -
- | - | O
Your turn!
Legal moves: [1, 4, 5, 6, 7]
O wins!
O | - | X
- - - - -
X | O | -
- - - - -
- | - | O
Okay, so I easily beat the model. Something is going wrong.
Improving the transformer
So at the moment, the transformer can't even reliably learn the simple winning moves from any given position. I can think of a few reasons for this:
- The transformer is only training on the winning moves, so maybe it's not learning how to play when I play a decent strategy that means no winning moves are available. Theoretically, to counter this we should allow it to train on moves when the game is pre-destined to be a draw.
- The transformer is overparametrised. Trying to get tens of thousands of parameters to coordinate on a simple strategy may take a long time to train, and may rely on grokking and other phenomena to reach a generalisable part of the optimisation landscape.
- The transformer is underparametrised. Maybe it needs more parameters to learn a good strategy. That said, I very much doubt that scaling up to a few million parameters will help if tens of thousands don't do the job.
State-space analysis
Before I move on, I want to analyse the number of neurons theoretically required to encode the entire winning strategy for tic-tac-toe. To do that, we need to consider the state space of the game and the complexity of the decision-making process.
In tic-tac-toe, the game board is a 3 x 3 grid, and each cell can be in one of three states: 'X', 'O', or empty ('-'). So, the total number of possible board states can be calculated as 3⁹ = 19683. However, not all of these are valid states in a real game; some of them are unreachable or illegal (e.g., a board with all 'X's). The number of legal states is actually around 5,478, but for the purpose of this analysis, we'll consider the upper limit, i.e., 3⁹.
Each state requires a decision: where to place the next ‘X' (since we're considering a winning strategy for ‘X'). There are 9 possible positions, but the number of legal moves is often less than 9, depending on how many cells are already occupied. A neural network would have to map each possible board state to a correct move. One way to achieve this mapping is through a fully connected layer that takes the board state as input and outputs a probability distribution over the 9 possible moves. The input layer would have 3×3=9 neurons (one for each cell), and the output layer would have 9 neurons (one for each possible move). The hidden layers in between would perform the complex task of learning the winning strategy.
Given that the input layer has 9 neurons and the output layer has 9 neurons, the number of neurons in the hidden layers is what we're interested in. Theoretically, we could use a single hidden layer with 3⁹ neurons to map each possible state to a winning move. This would be an upper bound and is likely far more than what would actually be necessary due to the presence of unreachable/illegal states and the inherent symmetries in tic-tac-toe that reduce the actual number of unique states. So, in this upper-bound scenario, the total number of neurons would be: 9 + 19683 + 9 = 19701. This is a theoretical upper bound and the actual number could be much lower due to the factors mentioned earlier. Let's try implementing a simple feed-forward vanilla neural network to see how it does on our task. Rather than having one layer with thousands of neurons, we'll use three hidden layers.
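As a rough sanity check on what that upper bound implies in parameters rather than neurons, here is a back-of-the-envelope count for a hypothetical single-hidden-layer network of that size (we never train this; it's purely illustrative):
# Hypothetical upper-bound MLP: 9 inputs -> 3**9 hidden units -> 9 outputs
hidden = 3 ** 9                                    # 19,683 hidden neurons
params = (9 * hidden + hidden) + (hidden * 9 + 9)  # weights + biases for both layers
print(params)                                      # 373,986 parameters, far more than the small nets we use here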
import torch.nn as nn
import torch.nn.functional as F
device = 'cpu'
class TicTacToeNN(nn.Module):
def __init__(self):
super(TicTacToeNN, self).__init__()
self.fc1 = nn.Linear(9, 16) # Input layer to hidden layer 1
self.fc2 = nn.Linear(16, 32) # Hidden layer 1 to hidden layer 2
        self.fc3 = nn.Linear(32, 16) # Hidden layer 2 to hidden layer 3
self.fc4 = nn.Linear(16, 9) # Hidden layer 3 to output layer
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = self.fc4(x)
return x
def get_batch(split):
input_data = train_input if split == 'train' else val_input
output_data = train_output if split == 'train' else val_output
    # Choose batch_size random example indices
    ix = torch.randint(len(input_data), (batch_size,))
# Get the input and output sequences
x = input_data[ix].float()
y = output_data[ix]
x, y = x.to(device), y.to(device)
return x, y
# Reset the data lists and re-generate the training data
input_sequences = []
output_sequences = []
# Create an initial empty game and simulate all possible games starting with 'X'
initial_game = TicTacToe()
simulate_all_games(initial_game, [], [])
# Convert to PyTorch tensors
input_tensor = torch.tensor(input_sequences, dtype=torch.long)
output_tensor = torch.tensor(output_sequences, dtype=torch.long)
nn_model = TicTacToeNN()
nn_model.to(device)
# Print the number of parameters in the model
print(sum(p.numel() for p in nn_model.parameters()), 'parameters')
# Create a PyTorch optimizer
optimizer = torch.optim.AdamW(nn_model.parameters(), lr=learning_rate, weight_decay=1e-4)
train_loss_history = []
val_loss_history = []
# Training loop
max_iters = 1000000
for iter in tqdm(range(max_iters)):
# Evaluate the loss on train and val sets occasionally
# Sample a batch of data
xb, yb = get_batch('train')
# Evaluate the loss
logits = nn_model(xb)
# Calculate cross-entropy loss
loss = F.cross_entropy(logits, yb)
train_loss_history.append(loss.item())
# Get the validation loss
xb, yb = get_batch('val')
logits = nn_model(xb)
val_loss = F.cross_entropy(logits, yb)
val_loss_history.append(val_loss.item())
# Backpropagate and update the weights
optimizer.zero_grad()
loss.backward()
optimizer.step()
1385 parameters
100%|██████████| 1000000/1000000 [08:08<00:00, 2048.42it/s]
And let's see what the loss looks like:

Clearly, performance is saturating. There's something about the task and how we've set it up that prevents these models from learning an appropriate strategy. To change tack, I'm going to try giving the model training data that consists only of moves from the optimal strategy.
Optimal strategy training data
Newell and Simon's 1972 tic-tac-toe program encodes the perfect strategy (to win, or at least draw): always choose the first applicable move from the following ordered preferences:
- Win: If you have two in a row, play the third to get three in a row.
- Block: If the opponent has two in a row, play the third to block them.
- Fork: Create an opportunity where you can win in two ways.
- Block Opponent's Fork: we either create two in a row to force a defence (if this doesn't result in a fork for them), or we block their potential fork.
- Centre: Play the centre.
- Opposite Corner: If the opponent is in the corner, play the opposite corner.
- Empty Corner: Play an empty corner.
- Empty Side: Play an empty side.
Let's rewrite our data generator to obtain all possible moves according to this strategy. We'll also simulate all games for both possible players going first.
from copy import deepcopy
# Helper function to find if there's a winning move or a move that blocks the opponent from winning
def find_winning_or_blocking_move(board, player):
winning_positions = [
[0, 1, 2], [3, 4, 5], [6, 7, 8], # Rows
[0, 3, 6], [1, 4, 7], [2, 5, 8], # Columns
[0, 4, 8], [2, 4, 6] # Diagonals
]
for positions in winning_positions:
values = [board[pos] for pos in positions]
if values.count(player) == 2 and values.count('-') == 1:
return positions[values.index('-')]
return None
# Helper function for checking for fork opportunities
def find_fork_move(board, player):
fork_move = None
for i in range(9):
if board[i] == '-':
temp_board = board[:]
temp_board[i] = player
winning_moves = 0
for j in range(9):
if temp_board[j] == '-':
temp_board_2 = temp_board[:]
temp_board_2[j] = player
if find_winning_or_blocking_move(temp_board_2, player) is not None:
winning_moves += 1
if winning_moves >= 2:
fork_move = i
break
return fork_move
# Helper function to find the optimal move according to a defined strategy
def optimal_strategy(board, player):
opponent = 'O' if player == 'X' else 'X'
# 1. Win: If you have two in a row, play the third to get three in a row.
win_move = find_winning_or_blocking_move(board, player)
if win_move is not None:
return win_move
# 2. Block: If the opponent has two in a row, play the third to block them.
block_move = find_winning_or_blocking_move(board, opponent)
if block_move is not None:
return block_move
# 3. Fork: Create an opportunity where you can win in two ways.
fork_move = find_fork_move(board, player)
if fork_move is not None:
return fork_move
# 4. Block Opponent's Fork
opponent_fork_move = find_fork_move(board, opponent)
if opponent_fork_move is not None:
return opponent_fork_move
# 5. Center: Play the center.
if board[4] == '-':
return 4
# 6. Opposite Corner: If the opponent is in the corner, play the opposite corner.
corners = [(0, 8), (2, 6), (8, 0), (6, 2)]
for corner1, corner2 in corners:
if board[corner1] == opponent and board[corner2] == '-':
return corner2
# 7. Empty Corner: Play an empty corner.
for corner in [0, 2, 6, 8]:
if board[corner] == '-':
return corner
# 8. Empty Side: Play an empty side.
for side in [1, 3, 5, 7]:
if board[side] == '-':
return side
# Function to simulate all games according to the optimal strategy
def simulate_all_games_optimal_v2(game, x_starts=True):
global input_sequences, output_sequences
# Check for terminal state
winner = game.check_winner()
if winner or game.is_draw():
return
# If it's X's turn, apply the optimal strategy and save the board state and move
if game.current_player == 'X':
move = optimal_strategy(game.board, 'X')
if move is None:
move = game.get_legal_moves()[0] # fallback
input_sequences.append(encode(''.join(game.board)))
output_sequences.append(move)
new_game = deepcopy(game)
new_game.make_move(move)
simulate_all_games_optimal_v2(new_game, x_starts)
else:
# If it's O's turn, explore all possible legal moves
for move in game.get_legal_moves():
new_game = deepcopy(game)
new_game.make_move(move)
simulate_all_games_optimal_v2(new_game, x_starts)
# Character to integer mapping
stoi = {'-': 0, 'X': 1, 'O': 2}
itos = {i: ch for ch, i in stoi.items()}  # inverse of stoi, kept consistent
vocab_size = len(stoi)
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
# Reset and re-simulate
input_sequences = []
output_sequences = []
# 'X' starts
initial_game = TicTacToe()
simulate_all_games_optimal_v2(initial_game, True)
# 'O' starts
initial_game = TicTacToe()
initial_game.current_player = 'O'
simulate_all_games_optimal_v2(initial_game, False)
# Convert to Pytorch tensors
input_tensor = torch.tensor(input_sequences, dtype=torch.long)
output_tensor = torch.tensor(output_sequences, dtype=torch.long)
print("Number of input-output pairs:", len(input_sequences))
Number of input-output pairs: 1017
Let's retrain our model on our new training data.
# Hyperparameters
batch_size = 128 # How many independent sequences will we process in parallel?
block_size = 9 # The size of the tic-tac-toe board
max_iters = 10000
eval_interval = 500
learning_rate = 1e-3
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
eval_iters = 100
n_embd = 32 # Reduced the embedding size
n_head = 2 # Reduced the number of heads
n_layer = 2 # Reduced the number of layers
dropout = 0.1
print(f'Training on {device}')
# Initialize random seed
torch.manual_seed(1337)
# Split into training and validation sets
n = int(0.90 * len(input_tensor)) # 90% for training
train_input = input_tensor[:n]
train_output = output_tensor[:n]
val_input = input_tensor[n:]
val_output = output_tensor[n:]
# Updated data loading function
def get_batch(split):
input_data = train_input if split == 'train' else val_input
output_data = train_output if split == 'train' else val_output
    # Choose batch_size random example indices
    ix = torch.randint(len(input_data), (batch_size,))
# Get the input and output sequences
x = input_data[ix]
y = output_data[ix]
x, y = x.to(device), y.to(device)
return x, y
# Initialize the model
model = Transformer()
model = model.to(device)
max_iters = 5000
# Print the number of parameters in the model
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')
# Create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
train_loss_history = []
val_loss_history = []
# Training loop
for iter in tqdm(range(max_iters)):
# Evaluate the loss on train and val sets occasionally
if iter % eval_interval == 0 or iter == max_iters - 1:
losses = estimate_loss()
print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
val_loss_history.append(losses['val'])
# Sample a batch of data
xb, yb = get_batch('train')
# Evaluate the loss
logits, loss = model(xb, yb)
train_loss_history.append(loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
Training on mps
0.025961 M parameters
0%| | 6/5000 [00:00<10:15, 8.11it/s]
step 0: train loss 2.2078, val loss 2.2166
10%|█ | 505/5000 [00:13<05:23, 13.89it/s]
step 500: train loss 0.3063, val loss 0.6145
20%|██ | 1005/5000 [00:26<04:40, 14.24it/s]
step 1000: train loss 0.0741, val loss 0.2259
30%|███ | 1505/5000 [00:38<04:08, 14.05it/s]
step 1500: train loss 0.0368, val loss 0.1799
40%|████ | 2005/5000 [00:51<03:36, 13.83it/s]
step 2000: train loss 0.0134, val loss 0.1589
50%|█████ | 2504/5000 [01:04<02:57, 14.06it/s]
step 2500: train loss 0.0081, val loss 0.0884
60%|██████ | 3008/5000 [01:17<01:56, 17.06it/s]
step 3000: train loss 0.0041, val loss 0.0521
70%|███████ | 3505/5000 [01:29<01:46, 14.09it/s]
step 3500: train loss 0.0028, val loss 0.0855
80%|████████ | 4005/5000 [01:42<01:10, 14.06it/s]
step 4000: train loss 0.0036, val loss 0.1125
90%|█████████ | 4506/5000 [01:56<00:29, 16.68it/s]
step 4500: train loss 0.0014, val loss 0.0892
100%|██████████| 5000/5000 [02:08<00:00, 38.79it/s]
step 4999: train loss 0.0026, val loss 0.0721

So much better! We've not only learned the strategy but it generalises to the validation dataset (we get close to 0 loss on both train and val). I'm guessing this is due to the inherent symmetries between board states: the transformer may have learned something like a form of modular arithmetic over the board string that makes it invariant to equivalent positions.
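To make the symmetry point concrete, here is a small sketch (my own illustration; the model is never given this explicitly) of how a 90-degree clockwise rotation permutes the nine board indices. Every position has up to eight symmetric variants under rotation and reflection:
# New position i takes the piece from old position ROT90[i]
ROT90 = [6, 3, 0, 7, 4, 1, 8, 5, 2]

def rotate_board(board: str) -> str:
    """Rotate a 9-character board string 90 degrees clockwise."""
    return ''.join(board[i] for i in ROT90)

print(rotate_board("X-O------"))  # the X in the top-left corner ends up in the top-right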
Let's try playing the new transformer:
O's turn.
- | - | -
- - - - -
- | - | -
- - - - -
- | - | -
Your turn!
Legal moves: [0, 1, 2, 3, 4, 5, 6, 7, 8]
X's turn.
- | O | -
- - - - -
- | - | -
- - - - -
- | - | -
Model's turn...
O's turn.
X | O | -
- - - - -
- | - | -
- - - - -
- | - | -
Your turn!
Legal moves: [2, 3, 4, 5, 6, 7, 8]
X's turn.
X | O | -
- - - - -
- | - | -
- - - - -
- | O | -
Model's turn...
O's turn.
X | O | -
- - - - -
- | X | -
- - - - -
- | O | -
Your turn!
Legal moves: [2, 3, 5, 6, 8]
X's turn.
X | O | -
- - - - -
- | X | -
- - - - -
- | O | O
Model's turn...
O's turn.
X | O | -
- - - - -
- | X | -
- - - - -
X | O | O
Your turn!
Legal moves: [2, 3, 5]
X's turn.
X | O | O
- - - - -
- | X | -
- - - - -
X | O | O
Model's turn...
X wins!
X | O | O
- - - - -
X | X | -
- - - - -
X | O | O
It beat me! Used a fantastic fork. It seems our transformer has learned the optimal strategy.
Conclusion
I think the main takeaway here is that transformers are more than capable of learning optimal strategies for games. Whilst a vanilla neural net could likely have learned the same optimal strategy, the dynamic nature of attention means a transformer may cope better with the longer sequences that represent a game unfolding over time. These ideas naturally lead to applying transformers in reinforcement learning settings; for instance, Janner et al. (2021) used transformers to model distributions over trajectories, with beam search as the planning algorithm.
Another thing I learned from this project is that hand-holding the transformer through the optimal strategy obviously doesn't scale, particularly as games get more complex. For instance, Go isn't a "solved" game, so we couldn't do what we did above and provide the model with the optimal strategy to train on. Instead, we'd have to use something like self-play to select good sequences to train the transformer on. I hope to experiment with ideas like this in future.
Finally, there's a ripe field to be explored by looking at hierarchies of prediction and planning in transformers. As Ba et al. (2016) noted, deep learning has typically focused on methods that maintain temporary state in activation dynamics, whereas our brain seems to regulate temporary state information via medium-term synaptic plasticity. In other words, there should be some form of working memory/prediction that operates between the next-token level and long-term state (as maintained in LSTMs, for instance). The autoregressive decoder-only transformer architecture that forms the backbone of GPT-3 is a powerful model that can generate realistic-looking text simply by predicting one token into the future. However, if we anthropomorphise intelligence, we know that quick intuitive look-ahead (by one token) does not a genius make. It would thus be interesting to provide a model with multiple hierarchies of prediction, where the model learns to predict multiple tokens into the future. This would be akin to the human ability to plan ahead, and would likely be a useful skill for transformers to learn.
What does it mean to look ahead? There are a few different avenues to explore here:
- Look ahead in time: Can the transformer learn to predict not the next token in the sequence but the token two, three, or more steps ahead? Is predicting two tokens at once just equivalent to predicting one token, then predicting the next token, or is there some strategic benefit to predicting two tokens at a time? Does it force the transformer to think for longer?
- Look ahead in space: There is evidence that humans produce coarse-grained semantic representations of tasks they are performing, and then use hierarchical modules to "fill in the gaps" in these coarse-grained representations. You might think of this as writing an essay: first you create a skeleton of points, then you fill in the thesis sentences for each paragraph, and finally you flesh out the details. Is it possible that Transformers can learn to do the same thing?
Regardless of whether these are useful questions or not, I hope this post gave some clarity and insight into exactly how we can reshape a problem to be suitable for a transformer. Happy fine-tuning!
References
- Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention is all you need. Advances in neural information processing systems, 30.
- Janner, M., Li, Q., & Levine, S. (2021). Offline reinforcement learning as one big sequence modeling problem. Advances in neural information processing systems, 34, 1273–1286.
- Ba, J., Hinton, G. E., Mnih, V., Leibo, J. Z., & Ionescu, C. (2016). Using fast weights to attend to the recent past. Advances in neural information processing systems, 29.
- Andrej Karpathy. Let's build GPT: from scratch, in code, spelled out. https://www.youtube.com/watch?v=kCc8FmEb1nY&t=5076s