NanoGPT - Min with Teeth
Andrej Karpathy Video
Code
Pulling the dataset we will be working on:
curl https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -o input.txt
Reading it into Python:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
Data inspection
print("length of dataset in characters: ", len(text))
print("length of data: ", len(data))
length of dataset in characters: 1115394 length of data: 1115394
print(text[:1000])
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they would yield us but the superfluity, while it were wholesome, we might guess they relieved us humanely; but they think we are too dear: the leanness that afflicts us, the object of our misery, is as an inventory to particularise their abundance; our sufferance is a gain to them Let us revenge this with our pikes, ere we become rakes: for the gods know I speak this in hunger for bread, not in thirst for revenge.
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)
!$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz 65
Tokeniser
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
# defines function taking in string, outputs list of ints
decode = lambda l: ''.join([itos[i] for i in l])
# input: list of integers, outputs string
print(encode("hello world"))
print(decode(encode("hello world")))
[46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
hello world
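As a quick sanity check (my own addition, not from the video), the encode/decode pair should round-trip the entire corpus losslessly:
# sketch: encoding then decoding the full text reproduces it exactly
assert decode(encode(text)) == text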
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])
torch.Size([1115394]) torch.int64 tensor([18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44, 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63, 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, 56, 1, 51, 43, 1, 57, 54, 43, 39, 49, 8, 0, 0, 13, 50, 50, 10, 0, 31, 54, 43, 39, 49, 6, 1, 57, 54, 43, 39, 49, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 37, 53, 59, 1, 39, 56, 43, 1, 39, 50, 50, 1, 56, 43, 57, 53, 50, 60, 43, 42, 1, 56, 39, 58, 46, 43, 56, 1, 58, 53, 1, 42, 47, 43, 1, 58, 46, 39, 52, 1, 58, 53, 1, 44, 39, 51, 47, 57, 46, 12, 0, 0, 13, 50, 50, 10, 0, 30, 43, 57, 53, 50, 60, 43, 42, 8, 1, 56, 43, 57, 53, 50, 60, 43, 42, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 18, 47, 56, 57, 58, 6, 1, 63, 53, 59, 1, 49, 52, 53, 61, 1, 15, 39, 47, 59, 57, 1, 25, 39, 56, 41, 47, 59, 57, 1, 47, 57, 1, 41, 46, 47, 43, 44, 1, 43, 52, 43, 51, 63, 1, 58, 53, 1, 58, 46, 43, 1, 54, 43, 53, 54, 50, 43, 8, 0, 0, 13, 50, 50, 10, 0, 35, 43, 1, 49, 52, 53, 61, 5, 58, 6, 1, 61, 43, 1, 49, 52, 53, 61, 5, 58, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 24, 43, 58, 1, 59, 57, 1, 49, 47, 50, 50, 1, 46, 47, 51, 6, 1, 39, 52, 42, 1, 61, 43, 5, 50, 50, 1, 46, 39, 60, 43, 1, 41, 53, 56, 52, 1, 39, 58, 1, 53, 59, 56, 1, 53, 61, 52, 1, 54, 56, 47, 41, 43, 8, 0, 21, 57, 5, 58, 1, 39, 1, 60, 43, 56, 42, 47, 41, 58, 12, 0, 0, 13, 50, 50, 10, 0, 26, 53, 1, 51, 53, 56, 43, 1, 58, 39, 50, 49, 47, 52, 45, 1, 53, 52, 5, 58, 11, 1, 50, 43, 58, 1, 47, 58, 1, 40, 43, 1, 42, 53, 52, 43, 10, 1, 39, 61, 39, 63, 6, 1, 39, 61, 39, 63, 2, 0, 0, 31, 43, 41, 53, 52, 42, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 27, 52, 43, 1, 61, 53, 56, 42, 6, 1, 45, 53, 53, 42, 1, 41, 47, 58, 47, 64, 43, 52, 57, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 35, 43, 1, 39, 56, 43, 1, 39, 41, 41, 53, 59, 52, 58, 43, 42, 1, 54, 53, 53, 56, 1, 41, 47, 58, 47, 64, 43, 52, 57, 6, 1, 58, 46, 43, 1, 54, 39, 58, 56, 47, 41, 47, 39, 52, 57, 1, 45, 53, 53, 42, 8, 0, 35, 46, 39, 58, 1, 39, 59, 58, 46, 53, 56, 47, 58, 63, 1, 57, 59, 56, 44, 43, 47, 58, 57, 1, 53, 52, 1, 61, 53, 59, 50, 42, 1, 56, 43, 50, 47, 43, 60, 43, 1, 59, 57, 10, 1, 47, 44, 1, 58, 46, 43, 63, 0, 61, 53, 59, 50, 42, 1, 63, 47, 43, 50, 42, 1, 59, 57, 1, 40, 59, 58, 1, 58, 46, 43, 1, 57, 59, 54, 43, 56, 44, 50, 59, 47, 58, 63, 6, 1, 61, 46, 47, 50, 43, 1, 47, 58, 1, 61, 43, 56, 43, 0, 61, 46, 53, 50, 43, 57, 53, 51, 43, 6, 1, 61, 43, 1, 51, 47, 45, 46, 58, 1, 45, 59, 43, 57, 57, 1, 58, 46, 43, 63, 1, 56, 43, 50, 47, 43, 60, 43, 42, 1, 59, 57, 1, 46, 59, 51, 39, 52, 43, 50, 63, 11, 0, 40, 59, 58, 1, 58, 46, 43, 63, 1, 58, 46, 47, 52, 49, 1, 61, 43, 1, 39, 56, 43, 1, 58, 53, 53, 1, 42, 43, 39, 56, 10, 1, 58, 46, 43, 1, 50, 43, 39, 52, 52, 43, 57, 57, 1, 58, 46, 39, 58, 0, 39, 44, 44, 50, 47, 41, 58, 57, 1, 59, 57, 6, 1, 58, 46, 43, 1, 53, 40, 48, 43, 41, 58, 1, 53, 44, 1, 53, 59, 56, 1, 51, 47, 57, 43, 56, 63, 6, 1, 47, 57, 1, 39, 57, 1, 39, 52, 0, 47, 52, 60, 43, 52, 58, 53, 56, 63, 1, 58, 53, 1, 54, 39, 56, 58, 47, 41, 59, 50, 39, 56, 47, 57, 43, 1, 58, 46, 43, 47, 56, 1, 39, 40, 59, 52, 42, 39, 52, 41, 43, 11, 1, 53, 59, 56, 0, 57, 59, 44, 44, 43, 56, 39, 52, 41, 43, 1, 47, 57, 1, 39, 1, 45, 39, 47, 52, 1, 58, 53, 1, 58, 46, 43, 51, 1, 24, 43, 58, 1, 59, 57, 1, 56, 43, 60, 43, 52, 45, 43, 1, 58, 46, 47, 57, 1, 61, 47, 58, 46, 0, 53, 59, 56, 1, 54, 47, 49, 43, 57, 6, 1, 43, 56, 43, 1, 61, 43, 1, 40, 43, 41, 53, 51, 43, 1, 56, 39, 49, 43, 57, 10, 1, 44, 53, 56, 1, 58, 46, 43, 
1, 45, 53, 42, 57, 1, 49, 52, 53, 61, 1, 21, 0, 57, 54, 43, 39, 49, 1, 58, 46, 47, 57, 1, 47, 52, 1, 46, 59, 52, 45, 43, 56, 1, 44, 53, 56, 1, 40, 56, 43, 39, 42, 6, 1, 52, 53, 58, 1, 47, 52, 1, 58, 46, 47, 56, 57, 58, 1, 44, 53, 56, 1, 56, 43, 60, 43, 52, 45, 43, 8, 0, 0])
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]
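A quick look at the split sizes (my own sketch; the numbers just follow from 0.9 × 1115394):
print(len(train_data), len(val_data))
# 1003854 111540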
Understanding how the preceding context predicts the next token
block_size = 8
print(train_data[:block_size])
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"at input {context}\n" +
          f"target {target}")
tensor([18, 47, 56, 57, 58, 1, 15, 47])
at input tensor([18])
target 47
at input tensor([18, 47])
target 56
at input tensor([18, 47, 56])
target 57
at input tensor([18, 47, 56, 57])
target 58
at input tensor([18, 47, 56, 57, 58])
target 1
at input tensor([18, 47, 56, 57, 58, 1])
target 15
at input tensor([18, 47, 56, 57, 58, 1, 15])
target 47
at input tensor([18, 47, 56, 57, 58, 1, 15, 47])
target 58
Note that within the block_size of 8, there are 8 total examples.
Now we stack several such chunks as rows of a 4 by 8 tensor, so they can be processed as one batch:
torch.manual_seed(1337)
batch_size = 4
block_size = 8 # as above
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y
xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets')
print(yb.shape)
print(yb)
inputs:
torch.Size([4, 8])
tensor([[24, 43, 58, 5, 57, 1, 46, 43],
        [44, 53, 56, 1, 58, 46, 39, 58],
        [52, 58, 1, 58, 46, 39, 58, 1],
        [25, 17, 27, 10, 0, 21, 1, 54]])
targets
torch.Size([4, 8])
tensor([[43, 58, 5, 57, 1, 46, 43, 39],
        [53, 56, 1, 58, 46, 39, 58, 1],
        [58, 1, 58, 46, 39, 58, 1, 46],
        [17, 27, 10, 0, 21, 1, 54, 39]])
To make the relationship between the inputs and the expected output labels explicit, we can unroll the loops:
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"at input {context}\n" +
              f"target {target}")
at input tensor([24])
target 43
at input tensor([24, 43])
target 58
at input tensor([24, 43, 58])
target 5
at input tensor([24, 43, 58, 5])
target 57
at input tensor([24, 43, 58, 5, 57])
target 1
at input tensor([24, 43, 58, 5, 57, 1])
target 46
at input tensor([24, 43, 58, 5, 57, 1, 46])
target 43
at input tensor([24, 43, 58, 5, 57, 1, 46, 43])
target 39
at input tensor([44])
target 53
at input tensor([44, 53])
target 56
at input tensor([44, 53, 56])
target 1
at input tensor([44, 53, 56, 1])
target 58
at input tensor([44, 53, 56, 1, 58])
target 46
at input tensor([44, 53, 56, 1, 58, 46])
target 39
at input tensor([44, 53, 56, 1, 58, 46, 39])
target 58
at input tensor([44, 53, 56, 1, 58, 46, 39, 58])
target 1
at input tensor([52])
target 58
at input tensor([52, 58])
target 1
at input tensor([52, 58, 1])
target 58
at input tensor([52, 58, 1, 58])
target 46
at input tensor([52, 58, 1, 58, 46])
target 39
at input tensor([52, 58, 1, 58, 46, 39])
target 58
at input tensor([52, 58, 1, 58, 46, 39, 58])
target 1
at input tensor([52, 58, 1, 58, 46, 39, 58, 1])
target 46
at input tensor([25])
target 17
at input tensor([25, 17])
target 27
at input tensor([25, 17, 27])
target 10
at input tensor([25, 17, 27, 10])
target 0
at input tensor([25, 17, 27, 10, 0])
target 21
at input tensor([25, 17, 27, 10, 0, 21])
target 1
at input tensor([25, 17, 27, 10, 0, 21, 1])
target 54
at input tensor([25, 17, 27, 10, 0, 21, 1, 54])
target 39
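So one batch of shape (batch_size, block_size) yields batch_size × block_size independent (context, target) pairs; a quick check (my own sketch):
print(batch_size * block_size)
# 32, which is why the bigram logits below flatten to 32 rows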
Bigram Language Model
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()  # set up the nn.Module machinery
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)  # a small wrapper on a (vocab_size, vocab_size) table of logits

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)  # (Batch, Time, Channels)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)  # reshape so cross_entropy gets (N, C) logits and (N,) targets
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)  # cross-entropy = softmax + negative log likelihood
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is a (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:, -1, :]  # keep only the last time step -> (B, C)
            probs = F.softmax(logits, dim=-1)  # logits -> probabilities
            idx_next = torch.multinomial(probs, num_samples=1)  # sample one token per sequence
            idx = torch.cat((idx, idx_next), dim=1)  # append to the running context
        return idx
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))
torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)

SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp wnYWmnxKWWev-tDqXErVKLgJ
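For reference (my own back-of-envelope check, not from the notebook): with 65 characters and a perfectly uniform prediction, the expected cross-entropy is -ln(1/65), so the untrained loss of 4.8786 is only a little worse than chance:
import math
print(-math.log(1/65))
# ≈ 4.17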
Training the Model
optimiser = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32
for steps in range(10000):
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)
    optimiser.zero_grad(set_to_none=True)
    loss.backward()
    optimiser.step()
print(loss.item())
2.455132484436035
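The value printed above is just the loss on the final training batch; a rough validation-set check (my own sketch, reusing get_batch and a hypothetical helper name) could look like this:
@torch.no_grad()
def quick_val_loss(num_batches=200):
    # average the loss over a number of random validation batches
    losses = []
    for _ in range(num_batches):
        xvb, yvb = get_batch('val')
        _, loss = m(xvb, yvb)
        losses.append(loss.item())
    return sum(losses) / len(losses)

print(quick_val_loss())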
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))
I yod mse LIFrsay, ERE: Theath f; Cleicut be thatus, YCKIth NI aced wietouthel l INGo hbor ot anovKENORoomerevely fafay haprye. AMa d, f fflet min bestok awir-miqgoun An Sodilie gelds hink'stithy herirs y, idses tour zer veswowat is ber tisme! NUCENGond ber the. BHY: TRONors he thasistindr irshathirot h, LA thak! AGLI pat, Liut ber tho; Fry; ous tho thy My BELorn. w'I d sio, T: Gl, ng e! OKINEO, PUCasthace, tho. Faue. KIf tho minonthn t he te yofsts h, ptincofive?'do athen psh peer ts lm vitor
Matrix Algebra for Attention
B, T, C = 4, 8, 2
x = torch.randn(B,T, C)
x.shape
# we want x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]
        xbow[b, t] = torch.mean(xprev, 0)
Note that this averaging throws away the ordering of the previous tokens (hence xbow, "bag of words"), but it is still an improvement on the single-character context of the Bigram model above.
torch.Size([4, 8, 2])
torch.tril(torch.ones(3,3))
torch.triu(torch.ones(3,3))
x[0]
xbow[0]
tensor([[0.7849, 1.3279],
        [0.8636, 1.2060],
        [0.7551, 0.8198],
        [0.3438, 0.5563],
        [0.2827, 0.4735],
        [0.1200, 0.1613],
        [0.1433, 0.1783],
        [0.1595, 0.0065]])
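A quick way to confirm what xbow holds (my own check): row t is the mean of rows 0..t of x.
t = 3
print(torch.allclose(xbow[0, t], x[0, :t+1].mean(dim=0)))
# True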
The result is correct, but the explicit Python loops are slow (and still do \(O(T^2)\) work per sequence). We can compute the same running averages far more efficiently with a single matrix multiply against a lower-triangular matrix:
a= torch.tril(torch.ones(3,3))
a = a / torch.sum(a, 1, keepdim = True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)
a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[1., 5.],
        [5., 1.],
        [4., 2.]])
--
c=
tensor([[1.0000, 5.0000],
        [3.0000, 3.0000],
        [3.3333, 2.6667]])
Thus each row of c is a running average of the rows of b up to that point!
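As a small check (my own), the last row of c should equal the plain mean of all the rows of b:
print(torch.allclose(c[-1], b.mean(dim=0)))
# True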
Vectorised Xbow2
wei = torch.tril(torch.ones(T, T)) # wei stands for weights
wei = wei / wei.sum(1, keepdim=True) # as above
xbow2 = wei @ x # (B, T, T) @ (B, T, C) = (B, T, C)
wei
#torch.allclose(xbow, xbow2)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
Andrej refactors this once more into a softmax version. The result is identical (torch.allclose confirms it below): masked positions are set to -inf, so the softmax assigns them exactly zero weight. The point of this form is that in self-attention the weights will not start as zeros but as data-dependent affinities between tokens; the softmax then normalises each row into a probability distribution, while the mask stops a token from attending to the future.
#SOFTMAX
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)
True
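To see why this form matters, here is a small sketch of my own that replaces the zeros with random numbers standing in for the data-dependent query-key affinities used later; the mask plus softmax still gives exactly zero weight to future positions:
wei_aff = torch.randn(T, T)                              # stand-in for query-key affinities
wei_aff = wei_aff.masked_fill(tril == 0, float('-inf'))  # block attention to the future
wei_aff = F.softmax(wei_aff, dim=-1)                     # each row becomes a distribution over the past
print(wei_aff[2])
# only the first three entries are non-zero, and they sum to 1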
Entering an IDE
We now move into an IDE to finish the .py file and obtain a val loss of less than 2.
import torch
import torch.nn as nn
from torch.nn import functional as F
#hyperparameters
batch_size = 64
block_size = 256
max_iters = 5000
eval_interval = 1 # note: this runs estimate_loss() on every step, adding eval_iters batches per split (400 extra forward passes) each iteration; 500 is a more typical value
learning_rate = 3e-4
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu") #metal m1 mac
# usually you use:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2
# -----------
torch.manual_seed(1337)
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
# defines function taking in string, outputs list of ints
decode = lambda l: ''.join([itos[i] for i in l])
# input: list of integers, outputs string
# splitting training and validation
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]
# data loading
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y
# the no_grad decorator below tells torch not to build a computation
# graph during evaluation, since we never call backward() here
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input: (batch, time-step, channels)
        # output: (batch, time-step, head size)
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # affinities, scaled by 1/sqrt(head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        # weighted aggregation of the values:
        v = self.value(x)
        out = wei @ v
        return out
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
class FeedForward(nn.Module):
    """ a simple linear layer followed by non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
class Block(nn.Module):
    """ Transformer block: communication and computation """

    def __init__(self, n_embd, n_head):
        # n_embd = embedding dimension,
        # n_head = number of heads we want
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
class GPTLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        # better init, andrej followup
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T, n_embd)
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # crop the context to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
model = GPTLanguageModel()
m = model.to(device)
# number of parameters in the model:
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')
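# note (my own check, not in the original script): with these hyperparameters
# the line above should report roughly 10.79 M parameters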
optimiser = torch.optim.AdamW(model.parameters(), lr=learning_rate)
for iter in range(max_iters):

    # periodically evaluate loss on train and val sets:
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch
    xb, yb = get_batch('train')

    # evaluate loss
    logits, loss = model(xb, yb)
    optimiser.zero_grad(set_to_none=True)
    loss.backward()
    optimiser.step()
context = torch.zeros((1,1),dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))
Results
After training on the UNSW HPC cluster Katana with 1 GPU and 120 GB of RAM for 12 hours, we got through 2000 training iterations of the above code and produced the following output:
MERDIA:
At visage a tritors:
Reignation comes it Juliet,
For Edward alived Menta's those percisicious
That laid go my kings what not:
If not be heads with say the horst house lothes,
As if no land paints her.
Let in there come such
From Keep'd, I am must be, and neath
For my shurseix'd exatrent's a base or a ride,
As I am gnown?
GLOUCESTER:
Stir, whetchet will death I hal,
Heaven no.
KING EDWARD IV:
What spenity these friends,
Henry to speak shall be my breast
Wherein so rody to life thyself are were glorn
Leys not vengeant with end gommends, and Marcius
Upon rice his cold, and his hand; and a boind
I hoping rof in perrow nothings
Of if honest on our sovereign, prot commonhes,
Be the of less of wills, I this dirt by the oten,
Which art I seeps a land connice;
I thout our pales will obert
Cause intreward's pastiencey's call.
He hateful vantage,
Corsal affect'd Richard's light: he is fourness,
So trad out is come to delive,
For a burse light's to diely panion.
AUTICARES:
He shall, thirs, my lord?
KING RICHARD III:
Why, thou, darest that I say
Here before thee, sirs, this lays at hoPt me.
Post, like his lass'd find hard;
And him eye goes must grown am I came.
NORTHUrsman:
Can we thou, I was not leave him befort the lame?
KING RICHARD II:
Wespert it he hath sheep-death'd!
ISABELLA:
Ah, my lord, I swarrant mind.
GREY:
Sir, I halft to laid time on upen
A lift his troal brawl'd more with lives:
This is tove to which to worstham in two;
And ther no more comes in Ladne opast
To eve, givew us gold, ere must of what made juguefy
All out of a company as Right!
Captain! God, I lose good near go;
Cold mising dies and unto my slay,
In hop now I see grows but a kiss,
Not, sir, I see hitheld; On, best your
Not no craft the base of Lord Sorrances. You are your
Come own and hor hand!
Third Servent:
Sure her! how despects, the good wof the again,
The encounter of blood atter long of,
The paint put of ciusin; an you see, and satisted,
I propants notting, blling on C?
HENRY BOLINNE:
O leave to back to Bencam and too.
NORTANL:
My royal,
Shether English beson, thou cry to thee,
When 'twicing as Plasant a purtian?'
MAGSONUS:
I have we with every them tore,
Pale, my earthing be't it lively to pail a land,
Where not destrove that one a minds,
Being partue, and then exenement
My noble that, sir.
BENVOLIO:
I will be provedent not.
WARWICK:
Ay, sick, my sover's daughter away.