Andrej Karpathy Video

Code

Pulling the dataset we will be working on:

  curl https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -o input.txt

Reading it into Python:

  with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

Data inspection

  print("length of dataset in characters: ", len(text))
length of dataset in characters:  1115394
  print(text[:1000])
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.
  chars = sorted(list(set(text)))
  vocab_size = len(chars)
  print(''.join(chars))
  print(vocab_size)

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65

Tokeniser

  stoi = { ch:i for i,ch in enumerate(chars) }
  itos = { i:ch for i,ch in enumerate(chars) }
  encode = lambda s: [stoi[c] for c in s]  # encoder: takes a string, outputs a list of integers
  decode = lambda l: ''.join([itos[i] for i in l])  # decoder: takes a list of integers, outputs a string

  print(encode("hello world"))
  print(decode(encode("hello world")))
[46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
hello world
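The tokeniser only works in both directions because every character gets exactly one index. A minimal plain-Python sketch (no PyTorch) of the same idea, built from a short sample string rather than the full dataset, so the ids will differ from those printed above:

```python
# Character-level tokeniser sketch on a small sample string (the notes build
# the vocabulary from all of input.txt instead, giving 65 characters).
sample = "First Citizen:\nBefore we proceed any further, hear me speak."

chars = sorted(set(sample))  # unique characters in a fixed, sorted order
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Map a string to a list of integer token ids."""
    return [stoi[c] for c in s]


def decode(ids):
    """Map a list of integer token ids back to a string."""
    return "".join(itos[i] for i in ids)


ids = encode("hear me speak")
print(ids)
print(decode(ids))

# A character that never appeared when building the vocabulary has no id.
try:
    encode("Z")
except KeyError as err:
    print("unknown character:", err)
```

Because the vocabulary is fixed when it is built, any character outside it raises a KeyError; this is why the vocabulary is built from the whole text before splitting into train and validation.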
  import torch
  data = torch.tensor(encode(text), dtype=torch.long)
  print(data.shape, data.dtype)
  print(data[:1000])
torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
        47, 59, 57,  1, 47, 57,  1, 41, 46, 47, 43, 44,  1, 43, 52, 43, 51, 63,
         1, 58, 53,  1, 58, 46, 43,  1, 54, 43, 53, 54, 50, 43,  8,  0,  0, 13,
        50, 50, 10,  0, 35, 43,  1, 49, 52, 53, 61,  5, 58,  6,  1, 61, 43,  1,
        49, 52, 53, 61,  5, 58,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47, 58,
        47, 64, 43, 52, 10,  0, 24, 43, 58,  1, 59, 57,  1, 49, 47, 50, 50,  1,
        46, 47, 51,  6,  1, 39, 52, 42,  1, 61, 43,  5, 50, 50,  1, 46, 39, 60,
        43,  1, 41, 53, 56, 52,  1, 39, 58,  1, 53, 59, 56,  1, 53, 61, 52,  1,
        54, 56, 47, 41, 43,  8,  0, 21, 57,  5, 58,  1, 39,  1, 60, 43, 56, 42,
        47, 41, 58, 12,  0,  0, 13, 50, 50, 10,  0, 26, 53,  1, 51, 53, 56, 43,
         1, 58, 39, 50, 49, 47, 52, 45,  1, 53, 52,  5, 58, 11,  1, 50, 43, 58,
         1, 47, 58,  1, 40, 43,  1, 42, 53, 52, 43, 10,  1, 39, 61, 39, 63,  6,
         1, 39, 61, 39, 63,  2,  0,  0, 31, 43, 41, 53, 52, 42,  1, 15, 47, 58,
        47, 64, 43, 52, 10,  0, 27, 52, 43,  1, 61, 53, 56, 42,  6,  1, 45, 53,
        53, 42,  1, 41, 47, 58, 47, 64, 43, 52, 57,  8,  0,  0, 18, 47, 56, 57,
        58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 35, 43,  1, 39, 56, 43,  1,
        39, 41, 41, 53, 59, 52, 58, 43, 42,  1, 54, 53, 53, 56,  1, 41, 47, 58,
        47, 64, 43, 52, 57,  6,  1, 58, 46, 43,  1, 54, 39, 58, 56, 47, 41, 47,
        39, 52, 57,  1, 45, 53, 53, 42,  8,  0, 35, 46, 39, 58,  1, 39, 59, 58,
        46, 53, 56, 47, 58, 63,  1, 57, 59, 56, 44, 43, 47, 58, 57,  1, 53, 52,
         1, 61, 53, 59, 50, 42,  1, 56, 43, 50, 47, 43, 60, 43,  1, 59, 57, 10,
         1, 47, 44,  1, 58, 46, 43, 63,  0, 61, 53, 59, 50, 42,  1, 63, 47, 43,
        50, 42,  1, 59, 57,  1, 40, 59, 58,  1, 58, 46, 43,  1, 57, 59, 54, 43,
        56, 44, 50, 59, 47, 58, 63,  6,  1, 61, 46, 47, 50, 43,  1, 47, 58,  1,
        61, 43, 56, 43,  0, 61, 46, 53, 50, 43, 57, 53, 51, 43,  6,  1, 61, 43,
         1, 51, 47, 45, 46, 58,  1, 45, 59, 43, 57, 57,  1, 58, 46, 43, 63,  1,
        56, 43, 50, 47, 43, 60, 43, 42,  1, 59, 57,  1, 46, 59, 51, 39, 52, 43,
        50, 63, 11,  0, 40, 59, 58,  1, 58, 46, 43, 63,  1, 58, 46, 47, 52, 49,
         1, 61, 43,  1, 39, 56, 43,  1, 58, 53, 53,  1, 42, 43, 39, 56, 10,  1,
        58, 46, 43,  1, 50, 43, 39, 52, 52, 43, 57, 57,  1, 58, 46, 39, 58,  0,
        39, 44, 44, 50, 47, 41, 58, 57,  1, 59, 57,  6,  1, 58, 46, 43,  1, 53,
        40, 48, 43, 41, 58,  1, 53, 44,  1, 53, 59, 56,  1, 51, 47, 57, 43, 56,
        63,  6,  1, 47, 57,  1, 39, 57,  1, 39, 52,  0, 47, 52, 60, 43, 52, 58,
        53, 56, 63,  1, 58, 53,  1, 54, 39, 56, 58, 47, 41, 59, 50, 39, 56, 47,
        57, 43,  1, 58, 46, 43, 47, 56,  1, 39, 40, 59, 52, 42, 39, 52, 41, 43,
        11,  1, 53, 59, 56,  0, 57, 59, 44, 44, 43, 56, 39, 52, 41, 43,  1, 47,
        57,  1, 39,  1, 45, 39, 47, 52,  1, 58, 53,  1, 58, 46, 43, 51,  1, 24,
        43, 58,  1, 59, 57,  1, 56, 43, 60, 43, 52, 45, 43,  1, 58, 46, 47, 57,
         1, 61, 47, 58, 46,  0, 53, 59, 56,  1, 54, 47, 49, 43, 57,  6,  1, 43,
        56, 43,  1, 61, 43,  1, 40, 43, 41, 53, 51, 43,  1, 56, 39, 49, 43, 57,
        10,  1, 44, 53, 56,  1, 58, 46, 43,  1, 45, 53, 42, 57,  1, 49, 52, 53,
        61,  1, 21,  0, 57, 54, 43, 39, 49,  1, 58, 46, 47, 57,  1, 47, 52,  1,
        46, 59, 52, 45, 43, 56,  1, 44, 53, 56,  1, 40, 56, 43, 39, 42,  6,  1,
        52, 53, 58,  1, 47, 52,  1, 58, 46, 47, 56, 57, 58,  1, 44, 53, 56,  1,
        56, 43, 60, 43, 52, 45, 43,  8,  0,  0])
  n = int(0.9*len(data))  # first 90% for training, the remaining 10% for validation
  train_data = data[:n]
  val_data = data[n:]
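The split is contiguous rather than shuffled, so the validation set is the final 10% of the text. A quick check of the resulting sizes, using the dataset length reported above (plain Python, just the arithmetic):

```python
# Split sizes for the 1115394-character dataset.
n_total = 1115394
n_train = int(0.9 * n_total)  # 0.9 * 1115394 = 1003854.6, truncated by int()
n_val = n_total - n_train
print(n_train, n_val)  # 1003854 111540
```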

Understanding how the context predicts the (n+1)th token

  block_size = 8
  print(train_data[:block_size])
  x = train_data[:block_size]
  y = train_data[1:block_size+1]
  for t in range(block_size):
      context = x[:t+1]
      target = y[t]
      print(f"at input {context}\n" +
            f"target {target}")
tensor([18, 47, 56, 57, 58,  1, 15, 47])
at input tensor([18])
target 47
at input tensor([18, 47])
target 56
at input tensor([18, 47, 56])
target 57
at input tensor([18, 47, 56, 57])
target 58
at input tensor([18, 47, 56, 57, 58])
target 1
at input tensor([18, 47, 56, 57, 58,  1])
target 15
at input tensor([18, 47, 56, 57, 58,  1, 15])
target 47
at input tensor([18, 47, 56, 57, 58,  1, 15, 47])
target 58

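Note that a single chunk of block_size + 1 = 9 characters contains 8 training examples, one for each context length from 1 to block_size. Training on all of them teaches the model to predict the next character from as little as one character of context, which is exactly the situation at the start of generation.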
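A plain-Python sketch of the same enumeration, using the nine token ids shown above (the eight inputs plus the final target 58). The helper name `context_target_pairs` is hypothetical:

```python
def context_target_pairs(chunk):
    """Return every (context, target) pair in one chunk of block_size + 1
    tokens: the model must predict token t+1 from tokens 0..t."""
    x, y = chunk[:-1], chunk[1:]  # inputs, and the same tokens shifted by one
    return [(x[:t + 1], y[t]) for t in range(len(x))]


chunk = [18, 47, 56, 57, 58, 1, 15, 47, 58]
pairs = context_target_pairs(chunk)
print(len(pairs))   # 8
print(pairs[0])     # ([18], 47)
print(pairs[-1])    # ([18, 47, 56, 57, 58, 1, 15, 47], 58)
```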

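Now we sample batch_size chunks at random offsets and stack them as rows, giving a 4 by 8 tensor of inputs and a matching 4 by 8 tensor of targets. Each row is processed independently: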

  torch.manual_seed(1337)
  batch_size = 4
  block_size = 8 # as above

  def get_batch(split):
      data = train_data if split == 'train' else val_data
      ix = torch.randint(len(data) - block_size, (batch_size,))
      x = torch.stack([data[i:i+block_size] for i in ix])
      y = torch.stack([data[i+1:i+block_size+1] for i in ix])
      return x,y

  xb, yb = get_batch('train')
  print('inputs:')
  print(xb.shape)
  print(xb)
  print('targets')
  print(yb.shape)
  print(yb)
inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
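A plain-Python version of get_batch, shown on a hypothetical stand-in list of token ids (0 to 99) so that the one-position shift between inputs and targets is easy to see. Here `random.randrange` plays the role of `torch.randint`, and `get_batch_sketch` is an illustrative name:

```python
import random


def get_batch_sketch(data, batch_size, block_size, rng):
    """Pick batch_size random offsets and return (inputs, targets) as lists of
    rows; each target row is its input row shifted one position to the right."""
    # Offsets stop before len(data) - block_size so the target slice, which
    # ends one token later than the input slice, stays inside the data.
    ix = [rng.randrange(len(data) - block_size) for _ in range(batch_size)]
    x = [data[i:i + block_size] for i in ix]
    y = [data[i + 1:i + block_size + 1] for i in ix]
    return x, y


rng = random.Random(1337)
data = list(range(100))  # stand-in token ids
xb, yb = get_batch_sketch(data, batch_size=4, block_size=8, rng=rng)
for row_x, row_y in zip(xb, yb):
    print(row_x, "->", row_y)
```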

To make the relationship between the inputs and their target labels explicit, we can unroll both loops over the batch and time dimensions. A single 4 by 8 batch contains 32 independent examples:

for b in range(batch_size):      # batch dimension
    for t in range(block_size):  # time dimension
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"at input {context}\n" +
              f"target {target}")
at input tensor([24])
target 43
at input tensor([24, 43])
target 58
at input tensor([24, 43, 58])
target 5
at input tensor([24, 43, 58,  5])
target 57
at input tensor([24, 43, 58,  5, 57])
target 1
at input tensor([24, 43, 58,  5, 57,  1])
target 46
at input tensor([24, 43, 58,  5, 57,  1, 46])
target 43
at input tensor([24, 43, 58,  5, 57,  1, 46, 43])
target 39
at input tensor([44])
target 53
at input tensor([44, 53])
target 56
at input tensor([44, 53, 56])
target 1
at input tensor([44, 53, 56,  1])
target 58
at input tensor([44, 53, 56,  1, 58])
target 46
at input tensor([44, 53, 56,  1, 58, 46])
target 39
at input tensor([44, 53, 56,  1, 58, 46, 39])
target 58
at input tensor([44, 53, 56,  1, 58, 46, 39, 58])
target 1
at input tensor([52])
target 58
at input tensor([52, 58])
target 1
at input tensor([52, 58,  1])
target 58
at input tensor([52, 58,  1, 58])
target 46
at input tensor([52, 58,  1, 58, 46])
target 39
at input tensor([52, 58,  1, 58, 46, 39])
target 58
at input tensor([52, 58,  1, 58, 46, 39, 58])
target 1
at input tensor([52, 58,  1, 58, 46, 39, 58,  1])
target 46
at input tensor([25])
target 17
at input tensor([25, 17])
target 27
at input tensor([25, 17, 27])
target 10
at input tensor([25, 17, 27, 10])
target 0
at input tensor([25, 17, 27, 10,  0])
target 21
at input tensor([25, 17, 27, 10,  0, 21])
target 1
at input tensor([25, 17, 27, 10,  0, 21,  1])
target 54
at input tensor([25, 17, 27, 10,  0, 21,  1, 54])
target 39

Bigram Language Model

import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):

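  def __init__(self, vocab_size):
    super().__init__()  # set up nn.Module's internals (parameter registration, etc.)
    # lookup table: each token id directly reads off the logits for the next token
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

  def forward(self, idx, targets=None):
    # idx and targets are both (B, T) tensors of token ids
    logits = self.token_embedding_table(idx)  # (B, T, C): Batch, Time, Channels

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)  # F.cross_entropy expects (N, C) logits
      targets = targets.view(B*T)   # and (N,) class indices
      loss = F.cross_entropy(logits, targets)  # log-softmax followed by negative log likelihood

    return logits, loss

  def generate(self, idx, max_new_tokens):
    # idx is a (B, T) tensor of token ids in the current context
    for _ in range(max_new_tokens):
      logits, loss = self(idx)           # predictions for every position
      logits = logits[:, -1, :]          # keep only the last time step: (B, C)
      probs = F.softmax(logits, dim=-1)  # logits -> probabilities: (B, C)
      idx_next = torch.multinomial(probs, num_samples=1)  # sample one id per row: (B, 1)
      idx = torch.cat((idx, idx_next), dim=1)             # append to the context: (B, T+1)
    return idx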
  
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))
torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)

SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ
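A useful sanity check on the initial loss: a model that gives all 65 characters equal probability would score -ln(1/65) = ln(65) ≈ 4.17. The 4.8786 above is higher, so the randomly initialised table is doing somewhat worse than uniform guessing. A plain-Python sketch of the per-position loss that `F.cross_entropy` averages over the B*T positions (the helper is illustrative, not PyTorch's implementation):

```python
import math


def cross_entropy(logits, target):
    """Negative log-probability of `target` under softmax(logits),
    computed with the log-sum-exp trick for numerical stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


vocab_size = 65
uniform_logits = [0.0] * vocab_size  # every character equally likely
print(cross_entropy(uniform_logits, target=0))  # ln(65) ≈ 4.1744
print(-math.log(1 / vocab_size))
```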

Training the Model

optimiser = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32
for steps in range(10000):
  xb, yb = get_batch('train')
  logits, loss = m(xb, yb)
  optimiser.zero_grad(set_to_none=True)
  loss.backward()
  optimiser.step()

print(loss.item())
2.455132484436035
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))
I yod mse LIFrsay, ERE:
Theath f; Cleicut be thatus,
YCKIth NI aced wietouthel l INGo hbor ot anovKENORoomerevely fafay haprye.
AMa d, f fflet min bestok awir-miqgoun
An
Sodilie gelds hink'stithy herirs y, idses tour zer veswowat is ber tisme!
NUCENGond ber the.
BHY:
TRONors he thasistindr irshathirot h,
LA thak!

AGLI pat,
Liut ber tho;
Fry; ous tho thy My
BELorn. w'I d sio,
T:

Gl, ng e!
OKINEO,
PUCasthace, tho.
Faue.
KIf tho minonthn t he te
yofsts h, ptincofive?'do athen psh peer ts lm vitor
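The generate method samples each next character from the softmax of the last position's logits instead of always taking the most likely one, which is why two runs produce different text. A plain-Python sketch of that loop, using a made-up 3-character probability table in place of the trained bigram model; `random.choices` stands in for `torch.multinomial`:

```python
import random

# Hypothetical next-character probabilities, indexed by the current character
# (what a bigram model's softmax would give for a 3-character vocabulary).
probs_table = {
    0: [0.1, 0.6, 0.3],
    1: [0.5, 0.0, 0.5],
    2: [1.0, 0.0, 0.0],
}


def generate(start, max_new_tokens, rng):
    """Append max_new_tokens ids, each sampled from the distribution for the
    last id, mirroring the sampling loop in the model's generate()."""
    idx = [start]
    for _ in range(max_new_tokens):
        probs = probs_table[idx[-1]]  # a bigram model only looks at the last id
        idx_next = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        idx.append(idx_next)
    return idx


rng = random.Random(1337)
print(generate(start=0, max_new_tokens=10, rng=rng))
```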

Matrix Algebra for Attention

B, T, C = 4, 8, 2
x = torch.randn(B,T, C)
x.shape

# we want x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C))
for b in range(B):
  for t in range(T):
    xprev = x[b,:t+1]
    xbow[b,t] = torch.mean(xprev, 0)

torch.Size([4, 8, 2])

Note that averaging throws away the order of the previous tokens (it is a "bag of words"), but it is still an improvement on the single-character context of the Bigram Model above.

torch.tril(torch.ones(3,3))
torch.triu(torch.ones(3,3))
x[0]
xbow[0]
tensor([[0.7849, 1.3279],
        [0.8636, 1.2060],
        [0.7551, 0.8198],
        [0.3438, 0.5563],
        [0.2827, 0.4735],
        [0.1200, 0.1613],
        [0.1433, 0.1783],
        [0.1595, 0.0065]])
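Written as a formula (indices from 0, as in the code), the loop computes a causal running mean, which is the same as multiplying each sequence by a fixed lower-triangular weight matrix W:

\[
\mathrm{xbow}[b,t] = \frac{1}{t+1}\sum_{i=0}^{t} x[b,i] = \sum_{i=0}^{T-1} W_{ti}\,x[b,i],
\qquad
W_{ti} = \begin{cases} \dfrac{1}{t+1} & \text{if } i \le t,\\[4pt] 0 & \text{if } i > t.\end{cases}
\]

In matrix form, \(\mathrm{xbow}[b] = W\,x[b]\) for every batch element b, with W of shape (T, T) and x[b] of shape (T, C).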

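The result is correct, but the nested Python loops are slow. The matrix version below does the same \(O(T^2)\) amount of work per sequence; the speed-up comes from replacing the Python loops with a single batched matrix multiplication, which PyTorch runs efficiently (especially on a GPU). The trick is a lower-triangular matrix whose rows are normalised to sum to 1: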

a= torch.tril(torch.ones(3,3))
a = a / torch.sum(a, 1, keepdim = True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)
a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[1., 5.],
        [5., 1.],
        [4., 2.]])
--
c=
tensor([[1.0000, 5.0000],
        [3.0000, 3.0000],
        [3.3333, 2.6667]])

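Thus each row of c is the average of the rows of b up to and including that position: for example, the second row is ((1+5)/2, (5+1)/2) = (3, 3).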
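The same computation in plain Python, reusing the b printed above, confirms that multiplying by the row-normalised triangular matrix reproduces an explicit running mean (the helper `matmul` is illustrative):

```python
def matmul(a, b):
    """Plain nested-list matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


T = 3
# Lower-triangular ones with each row divided by its sum (1, 2, 3).
a = [[1.0 / (t + 1) if i <= t else 0.0 for i in range(T)] for t in range(T)]
b = [[1.0, 5.0], [5.0, 1.0], [4.0, 2.0]]  # the b printed above

c = matmul(a, b)

# The same result from an explicit running mean over the rows of b.
running_mean = [[sum(row[j] for row in b[:t + 1]) / (t + 1) for j in range(2)]
                for t in range(T)]
print(c)
print(running_mean)
```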

Vectorised xbow2

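wei = torch.tril(torch.ones(T, T))  # wei stands for weights
wei = wei / wei.sum(1, keepdim=True)  # normalise each row to sum to 1, as above
xbow2 = wei @ x  # (T, T) is broadcast to (B, T, T); (B, T, T) @ (B, T, C) -> (B, T, C)
wei
# torch.allclose(xbow, xbow2) can be used to check that xbow2 matches xbow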
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

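Andrej refactors this once more into a softmax version, and the result is identical (torch.allclose returns True below). Setting the masked positions to -inf makes softmax give them exactly zero weight, and the remaining equal scores become a uniform average, so nothing is added to the zeros. The point of this form is that wei no longer has to start at zero: in self-attention the scores become data-dependent affinities between tokens, and masked_fill followed by softmax turns any such scores into weights that sum to 1 while still stopping tokens from attending to the future.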

#SOFTMAX
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)
True
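In self-attention the all-zero wei is replaced by scores q·k scaled by 1/sqrt(head_size), as in the Head class of the script that follows. A plain-Python sketch of one causal head for a single sequence (illustrative names, no PyTorch, no dropout); with all-zero queries and keys every score is equal and it reduces to the running average, while with other scores the future positions still get exactly zero weight:

```python
import math


def causal_attention(q, k, v):
    """One causal attention head for a single sequence, in plain Python.

    q, k: T rows of length head_size; v: T rows of length d.
    Returns (out, weights) where weights[t][i] is how much position t
    attends to position i (zero for every future position i > t).
    """
    T, head_size = len(q), len(k[0])
    scale = head_size ** -0.5  # same scaling as k.shape[-1]**-0.5 in Head
    weights = []
    for t in range(T):
        # Scaled dot-product affinities; future positions are masked with -inf.
        scores = [
            sum(qa * ka for qa, ka in zip(q[t], k[i])) * scale if i <= t else float("-inf")
            for i in range(T)
        ]
        m = max(scores)  # finite, because position t can always see itself
        exps = [math.exp(s - m) for s in scores]  # exp(-inf) is exactly 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    # Weighted aggregation of the values (wei @ v).
    out = [
        [sum(weights[t][i] * v[i][j] for i in range(T)) for j in range(len(v[0]))]
        for t in range(T)
    ]
    return out, weights


v = [[1.0, 5.0], [5.0, 1.0], [4.0, 2.0]]
zeros = [[0.0, 0.0]] * 3

# Equal (all-zero) scores: the weights become the uniform running average.
out, w = causal_attention(zeros, zeros, v)
print(out)

# Data-dependent scores: still a weighted average of past positions only.
q = [[1.0, 0.0], [0.5, 2.0], [-1.0, 1.0]]
k = [[0.0, 1.0], [1.0, 1.0], [2.0, -0.5]]
out2, w2 = causal_attention(q, k, v)
print(w2)
```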

Entering an IDE

The rest of the code lives in a standalone .py script written in an IDE, with the goal of reaching a validation loss below 2.

import torch
import torch.nn as nn
from torch.nn import functional as F

#hyperparameters
batch_size = 64
block_size = 256
max_iters = 5000
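# how often (in steps) to estimate the loss; each estimate runs 2 * eval_iters
# extra forward passes, so a value of 1 slows training considerably
eval_interval = 1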
learning_rate = 3e-4
# prefer a CUDA GPU (e.g. on an HPC cluster), then Apple Silicon (Metal/MPS), then the CPU
if torch.cuda.is_available():
  device = torch.device("cuda")
elif torch.backends.mps.is_available():
  device = torch.device("mps")
else:
  device = torch.device("cpu")
eval_iters = 200
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2
# -----------

torch.manual_seed(1337)

with open('input.txt', 'r', encoding='utf-8') as f:
  text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
# defines function taking in string, outputs list of ints
decode = lambda l: ''.join([itos[i] for i in l])
# input: list of integers, outputs string

# splitting training and validation
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
  data = train_data if split == 'train' else val_data
  ix = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  x, y = x.to(device), y.to(device)
  return x,y

# @torch.no_grad() disables gradient tracking inside this function, so PyTorch
# does not store the intermediate activations needed for backpropagation
@torch.no_grad()
def estimate_loss():
  out = {}
  model.eval()  # evaluation mode: disables dropout
  for split in ['train', 'val']:
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
      X, Y = get_batch(split)
      logits, loss = model(X, Y)
      losses[k] = loss.item()
    out[split] = losses.mean()  # average over many batches for a less noisy estimate
  model.train()  # back to training mode
  return out

class Head(nn.Module):
  """ one head of self-attention """
  def __init__(self, head_size):
    super().__init__()
    self.key = nn.Linear(n_embd, head_size, bias=False)
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)
    self.register_buffer('tril',torch.tril(torch.ones(block_size, block_size)))

    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    # input:  (batch, time-step, channels)
    # output: (batch, time-step, head size)
    B, T, C = x.shape
    k = self.key(x)
    q = self.query(x)

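    # affinities: (B, T, hs) @ (B, hs, T) -> (B, T, T), scaled by 1/sqrt(head_size)
    # so the softmax does not become too peaked when head_size is large
    wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5
    wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # causal mask: no looking ahead
    wei = F.softmax(wei, dim=-1)  # each row now sums to 1
    wei = self.dropout(wei)  # randomly drop attention weights during training
    # weighted aggregation of the values:
    v = self.value(x)  # (B, T, hs)
    out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)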
    return out

class MultiHeadAttention(nn.Module):
  """ multiple heads of self-attention in parallel """
  def __init__(self, num_heads, head_size):
    super().__init__()
    self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
    self.proj = nn.Linear(head_size * num_heads, n_embd)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    out = torch.cat([h(x) for h in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    return out

class FeedForward(nn.Module):
  """ a small MLP: expand to 4 * n_embd, apply ReLU, project back to n_embd """

  def __init__(self, n_embd):
    super().__init__()
    self.net = nn.Sequential(
      nn.Linear(n_embd, 4 * n_embd),
      nn.ReLU(),
      nn.Linear(4*n_embd, n_embd),
      nn.Dropout(dropout)
    )

  def forward(self, x):
    return self.net(x)

class Block(nn.Module):
  """ Transformer block: communication and computation """

  def __init__(self, n_embd, n_head):
    # n_embd = embedding dimension,
    # n_head = number of heads we want
    super().__init__()
    head_size = n_embd // n_head
    self.sa = MultiHeadAttention(n_head, head_size)
    self.ffwd = FeedForward(n_embd)
    self.ln1 = nn.LayerNorm(n_embd)
    self.ln2 = nn.LayerNorm(n_embd)

  def forward(self, x):
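    x = x + self.sa(self.ln1(x))    # communication: residual connection around pre-norm self-attention
    x = x + self.ffwd(self.ln2(x))  # computation: residual connection around the pre-norm MLP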
    return x

class GPTLanguageModel(nn.Module):

  def __init__(self):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
    self.position_embedding_table = nn.Embedding(block_size, n_embd)
    self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
    self.ln_f = nn.LayerNorm(n_embd)
    self.lm_head = nn.Linear(n_embd, vocab_size)

    # initialise weights from a small normal distribution (Andrej's follow-up improvement)
    self.apply(self._init_weights)

  def _init_weights(self, module):
    if isinstance(module, nn.Linear):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
      if module.bias is not None:
        torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def forward(self, idx, targets=None):
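    B, T = idx.shape
    tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd)
    pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T, n_embd)
    x = tok_emb + pos_emb  # broadcast add: (B, T, n_embd)
    x = self.blocks(x)     # (B, T, n_embd)
    x = self.ln_f(x)       # final layer norm
    logits = self.lm_head(x)  # (B, T, vocab_size)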

    if targets is None:
      loss = None
    else:
      B,T,C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)
    return logits, loss

  def generate(self, idx, max_new_tokens):
    for _ in range(max_new_tokens):
      idx_cond = idx[:, -block_size:]  # crop to the last block_size tokens; the position table has only block_size rows
      logits, loss = self(idx_cond)
      logits = logits[:,-1,:]
      probs = F.softmax(logits, dim=-1)
      idx_next = torch.multinomial(probs, num_samples=1)
      idx = torch.cat((idx, idx_next), dim=1)
    return idx

model = GPTLanguageModel()
m = model.to(device)
# number of parameters in the model:
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

optimiser = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
  # periodically evaluate loss on train and val sets:
  if iter % eval_interval == 0 or iter == max_iters - 1:
    losses = estimate_loss()
    print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

  # sample a batch of training data (every step, not only on evaluation steps)
  xb, yb = get_batch('train')

  # forward pass, backpropagation and parameter update
  logits, loss = model(xb, yb)
  optimiser.zero_grad(set_to_none=True)
  loss.backward()
  optimiser.step()

context = torch.zeros((1,1),dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))
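As a cross-check on the parameter count the script prints, the parameters can be tallied layer by layer. A sketch assuming vocab_size = 65 (the character vocabulary above) and the script's hyperparameters; `gpt_param_count` is a hypothetical helper, and the total should match what `sum(p.numel() for p in m.parameters())` reports:

```python
def gpt_param_count(vocab_size, block_size, n_embd, n_head, n_layer):
    """Count trainable parameters of GPTLanguageModel as defined in the script."""
    head_size = n_embd // n_head
    embeddings = vocab_size * n_embd + block_size * n_embd  # token + position tables
    attention = n_head * 3 * n_embd * head_size             # key, query, value (no bias)
    proj = (head_size * n_head) * n_embd + n_embd           # output projection
    ffwd = (n_embd * 4 * n_embd + 4 * n_embd) + (4 * n_embd * n_embd + n_embd)
    layer_norms = 2 * (2 * n_embd)                          # ln1 and ln2: weight + bias
    per_block = attention + proj + ffwd + layer_norms
    final = 2 * n_embd + (n_embd * vocab_size + vocab_size)  # ln_f + lm_head
    return embeddings + n_layer * per_block + final


total = gpt_param_count(vocab_size=65, block_size=256, n_embd=384, n_head=6, n_layer=6)
print(total)  # 10788929, i.e. about 10.79M
```

The tril buffer registered in Head is not a parameter, and Dropout layers have none, so neither contributes; almost all of the count sits in the six transformer blocks.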

Results

After training on UNSW's Katana HPC cluster with 1 GPU and 120 GB of RAM for 12 hours, the run completed 2000 training iterations of the above code and produced the following output:

MERDIA:
At visage a tritors:
Reignation comes it Juliet,
For Edward alived Menta's those percisicious
That laid go my kings what not:
If not be heads with say the horst house lothes,
As if no land paints her.
Let in there come such
From Keep'd, I am must be, and neath
For my shurseix'd exatrent's a base or a ride,
As I am gnown?

GLOUCESTER:
Stir, whetchet will death I hal,
Heaven no.

KING EDWARD IV:
What spenity these friends,
Henry to speak shall be my breast
Wherein so rody to life thyself are were glorn
Leys not vengeant with end gommends, and Marcius
Upon rice his cold, and his hand; and a boind
I hoping rof in perrow nothings
Of if honest on our sovereign, prot commonhes,
Be the of less of wills, I this dirt by the oten,
Which art I seeps a land connice;
I thout our pales will obert
Cause intreward's pastiencey's call.
He hateful vantage,
Corsal affect'd Richard's light: he is fourness,
So trad out is come to delive,
For a burse light's to diely panion.

AUTICARES:
He shall, thirs, my lord?

KING RICHARD III:
Why, thou, darest that I say
Here before thee, sirs, this lays at hoPt me.
Post, like his lass'd find hard;
And him eye goes must grown am I came.

NORTHUrsman:
Can we thou, I was not leave him befort the lame?

KING RICHARD II:
Wespert it he hath sheep-death'd!

ISABELLA:
Ah, my lord, I swarrant mind.

GREY:
Sir, I halft to laid time on upen
A lift his troal brawl'd more with lives:
This is tove to which to worstham in two;
And ther no more comes in Ladne opast
To eve, givew us gold, ere must of what made juguefy
All out of a company as Right!
Captain! God, I lose good near go;
Cold mising dies and unto my slay,
In hop now I see grows but a kiss,
Not, sir, I see hitheld; On, best your
Not no craft the base of Lord Sorrances. You are your
Come own and hor hand!

Third Servent:
Sure her! how despects, the good wof the again,
The encounter of blood atter long of,
The paint put of ciusin; an you see, and satisted,
I propants notting, blling on C?

HENRY BOLINNE:
O leave to back to Bencam and too.

NORTANL:
My royal,
Shether English beson, thou cry to thee,
When 'twicing as Plasant a purtian?'

MAGSONUS:
I have we with every them tore,
Pale, my earthing be't it lively to pail a land,
Where not destrove that one a minds,
Being partue, and then exenement
My noble that, sir.

BENVOLIO:
I will be provedent not.

WARWICK:
Ay, sick, my sover's daughter away.