{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {"id": "jGHnHFav4lGi"},
   "source": [
    "# Projection, Layernorm, Dropout\n",
    "\n",
    "Character-level transformer language model trained on Tiny Shakespeare. On top of multi-head causal self-attention, each block uses:\n",
    "\n",
    "- an output **projection** that mixes the concatenated attention heads back into the residual stream,\n",
    "- pre-**LayerNorm** residual connections: `x = x + sa(ln1(x))`, then `x = x + ffwd(ln2(x))`,\n",
    "- **dropout** on the attention weights, the attention projection and the feed-forward output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "TzvEm56B4cd1"},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {"id": "rhgC0PRC5a8Q"},
   "source": ["### Hyperparameters"]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "U-hPaOJH5aFU"},
   "outputs": [],
   "source": [
    "'''Hyperparameters for smaller model'''\n",
    "\n",
    "B = 32  # B: how many independent sequences will we process in parallel?\n",
    "T = 8   # T: what is the maximum context length for predictions?\n",
    "C = 32  # C: number of features per token (embedding dims, also called D)\n",
    "H = 4   # H: number of attention heads (head_size = C // H)\n",
    "L = 4   # L: number of transformer blocks (layers)\n",
    "learning_rate = 1e-3\n",
    "\n",
    "'''Final Hyperparameters'''\n",
    "\n",
    "# B = 64 # B: how many independent sequences will we process in parallel?\n",
    "# T = 256 # T: what is the maximum context length for predictions?\n",
    "# H = 6\n",
    "# C = 64*H\n",
    "# L = 6\n",
    "# learning_rate = 1e-4\n",
    "\n",
    "# Common Hyperparameters\n",
    "max_iters = 5000\n",
    "eval_interval = 500\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "eval_iters = 200\n",
    "dropout = 0.2\n",
    "\n",
    "# seed for reproducible init and batches; the trailing ';' hides the returned Generator repr\n",
    "torch.manual_seed(1337);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {"id": "yLFQvyFf4taN"},
   "source": ["### Data"]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "NwgGJ0-K4rx2"},
   "outputs": [],
   "source": [
    "# -nc: skip the download when input.txt already exists (otherwise re-runs create input.txt.1, .2, ...)\n",
    "!wget -nc https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "Hbw16RqT4wFg"},
   "outputs": [],
   "source": [
    "with open('input.txt', 'r', encoding='utf-8') as f:\n",
    "    text = f.read()\n",
    "\n",
    "# here are all the unique characters that occur in this text\n",
    "chars = sorted(set(text))\n",
    "vocab_size = len(chars)\n",
    "# create a mapping from characters to integers\n",
    "stoi = {ch: i for i, ch in enumerate(chars)}\n",
    "itos = {i: ch for i, ch in enumerate(chars)}\n",
    "encode = lambda s: [stoi[c] for c in s]           # encoder: take a string, output a list of integers\n",
    "decode = lambda l: ''.join([itos[i] for i in l])  # decoder: take a list of integers, output a string\n",
    "\n",
    "chars_str = ''.join(chars)\n",
    "print(f'vocab_size: {vocab_size}')\n",
    "print(f'vocabulary: {chars_str}')\n",
    "\n",
    "# Train and test splits\n",
    "data = torch.tensor(encode(text), dtype=torch.long)\n",
    "n = int(0.9 * len(data))  # first 90% will be train, rest val\n",
    "train_data = data[:n]\n",
    "val_data = data[n:]\n",
    "\n",
    "def get_batch(split):\n",
    "    \"\"\"Sample B random windows of length T from the split; targets y are inputs x shifted by one token.\"\"\"\n",
    "    source = train_data if split == 'train' else val_data\n",
    "    ix = torch.randint(len(source) - T, (B,))\n",
    "    x = torch.stack([source[i:i+T] for i in ix])\n",
    "    y = torch.stack([source[i+1:i+T+1] for i in ix])\n",
    "    x, y = x.to(device), y.to(device)\n",
    "    return x, y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {"id": "djRweqFN43In"},
   "source": ["### Head, MHSA"]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "84SgLLFc4xt6"},
   "outputs": [],
   "source": [
    "class Head(nn.Module):\n",
    "    \"\"\"One head of causal (masked) self-attention: Ci input channels -> Co (head_size) output channels.\"\"\"\n",
    "\n",
    "    def __init__(self, Ci, Co):\n",
    "        super().__init__()\n",
    "        self.key = nn.Linear(Ci, Co, bias=False)\n",
    "        self.query = nn.Linear(Ci, Co, bias=False)\n",
    "        self.value = nn.Linear(Ci, Co, bias=False)\n",
    "        # lower-triangular causal mask, sized for the maximum context length T (global hyperparameter)\n",
    "        self.register_buffer('tril', torch.tril(torch.ones(T, T)))\n",
    "        self.dropout = nn.Dropout(dropout)  # randomly drops attention weights during training\n",
    "\n",
    "    def forward(self, x):\n",
    "        '''\n",
    "        x: (B,T,Ci) -> returns (B,T,Co)\n",
    "        B  - batch    # of independent sequences processed\n",
    "        T  - time     # of tokens in the context (may be shorter than the global T)\n",
    "        Ci - channels # of features in the input\n",
    "        '''\n",
    "        t = x.shape[1]  # current context length\n",
    "\n",
    "        k = self.key(x)    # (B,T,Co)\n",
    "        q = self.query(x)  # (B,T,Co)\n",
    "\n",
    "        # attention scores / affinities, scaled by 1/sqrt(head_size) to bring their variance to ~1\n",
    "        # so the softmax does not saturate (the scale is this head's Co, not the model width C)\n",
    "        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5            # (B,T,Co) @ (B,Co,T) -> (B,T,T)\n",
    "        wei = wei.masked_fill(self.tril[:t, :t] == 0, float('-inf'))  # (B,T,T) hide future positions\n",
    "        wei = F.softmax(wei, dim=-1)  # (B,T,T) -inf -> 0, each row sums to 1\n",
    "        wei = self.dropout(wei)\n",
    "\n",
    "        v = self.value(x)  # (B,T,Co)\n",
    "        out = wei @ v      # (B,T,T) @ (B,T,Co) -> (B,T,Co)\n",
    "        return out\n",
    "\n",
    "\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    \"\"\"H causal self-attention heads run in parallel; outputs are concatenated and projected back to Ci.\"\"\"\n",
    "\n",
    "    def __init__(self, Ci, H, head_size):\n",
    "        super().__init__()\n",
    "        # e.g. 4 heads of 8-dimensional self-attention for C=32, like a group convolution\n",
    "        self.heads = nn.ModuleList([Head(Ci=Ci, Co=head_size) for _ in range(H)])\n",
    "        self.proj = nn.Linear(H * head_size, Ci)  # mixes the heads' outputs back into the residual stream\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.cat([h(x) for h in self.heads], dim=-1)  # (B,T,H*head_size)\n",
    "        x = self.dropout(self.proj(x))                     # (B,T,Ci)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {"id": "Om6pUr4N463t"},
   "source": ["### Transformer Block"]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "2m3aC0tt44sv"},
   "outputs": [],
   "source": [
    "class Block(nn.Module):\n",
    "    ''' Transformer block: communication (attention) followed by computation (feed-forward) '''\n",
    "\n",
    "    def __init__(self, C, H):  # C: embedding dimension, H: number of heads\n",
    "        super().__init__()\n",
    "        # LayerNorm along channels (batch & time are batch dims): y = beta + gamma * (x - E[x]) / sqrt(Var[x] + eps)\n",
    "        self.ln1 = nn.LayerNorm(C)\n",
    "        self.sa = MultiHeadAttention(Ci=C, H=H, head_size=C // H)\n",
    "        self.ln2 = nn.LayerNorm(C)\n",
    "        self.ffwd = nn.Sequential(  # feed-forward network, so the tokens can \"think about\" what they found in attention\n",
    "            nn.Linear(C, C * 4),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(C * 4, C),\n",
    "            nn.Dropout(dropout),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Pre-LN residual connections around MHSA & FF, to help training.\n",
    "        # Only each branch's input is layer-normed; the un-normed x is carried along and each\n",
    "        # branch's output is added to it. The FF residual adds to the post-attention x,\n",
    "        # not to the block input.\n",
    "        x = x + self.sa(self.ln1(x))    # (B,T,C) multi-head self-attention\n",
    "        x = x + self.ffwd(self.ln2(x))  # (B,T,C) per token; B,T act as batch dims\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {"id": "rEZ6CZrm5ADW"},
   "source": ["### Model"]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "8B84EGMO49AA"},
   "outputs": [],
   "source": [
    "class BigramLanguageModel(nn.Module):\n",
    "    \"\"\"Decoder-only transformer language model (despite the class name, it attends over the whole context).\n",
    "\n",
    "    B, T, C, H, L: batch size, max context length, embedding dims, attention heads, number of blocks.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, B, T, C, H, L):\n",
    "        super().__init__()\n",
    "        self.B, self.T, self.C, self.H, self.L = B, T, C, H, L\n",
    "        self.token_embedding_table = nn.Embedding(vocab_size, C)  # token id -> C-dim embedding\n",
    "        self.position_embedding_table = nn.Embedding(T, C)        # position 0..T-1 -> C-dim embedding\n",
    "\n",
    "        self.blocks = nn.Sequential(*[Block(C, H) for _ in range(L)])\n",
    "        self.ln_final = nn.LayerNorm(C)\n",
    "        self.lm_head = nn.Linear(C, vocab_size)\n",
    "\n",
    "    def forward(self, idx, targets=None):\n",
    "        \"\"\"idx: (B,t) token ids with t <= self.T; targets: optional (B,t) next-token ids.\n",
    "\n",
    "        Returns (logits, loss). logits are (B,t,vocab_size), flattened to (B*t,vocab_size)\n",
    "        when targets are given; loss is None without targets.\n",
    "        \"\"\"\n",
    "        t = idx.shape[1]  # actual context length; may be shorter than self.T (e.g. a short seed)\n",
    "\n",
    "        tok_emb = self.token_embedding_table(idx)                                    # (B,t,C)\n",
    "        pos_emb = self.position_embedding_table(torch.arange(t, device=idx.device))  # (t,C): [0,1,..,t-1]\n",
    "\n",
    "        x = tok_emb + pos_emb     # (B,t,C)\n",
    "        x = self.blocks(x)\n",
    "        x = self.ln_final(x)      # final LayerNorm before the LM head\n",
    "        logits = self.lm_head(x)  # (B,t,vocab_size)\n",
    "\n",
    "        if targets is None:\n",
    "            loss = None\n",
    "        else:\n",
    "            Bl, Tl, V = logits.shape\n",
    "            logits = logits.view(Bl * Tl, V)\n",
    "            targets = targets.view(Bl * Tl)\n",
    "            loss = F.cross_entropy(logits, targets)\n",
    "\n",
    "        return logits, loss\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def generate(self, idx, max_new_tokens):\n",
    "        \"\"\"Append max_new_tokens sampled tokens to idx (B,t) one at a time and return the extended tensor.\n",
    "\n",
    "        Call model.eval() first so dropout is disabled while sampling.\n",
    "        \"\"\"\n",
    "        for _ in range(max_new_tokens):\n",
    "            idx_cond = idx[:, -self.T:]                         # crop to the last T tokens\n",
    "            logits, _ = self(idx_cond)                          # (B,t,vocab_size)\n",
    "            logits = logits[:, -1, :]                           # last time step -> (B,vocab_size)\n",
    "            probs = F.softmax(logits, dim=-1)                   # (B,vocab_size)\n",
    "            idx_next = torch.multinomial(probs, num_samples=1)  # sample from the distribution (B,1)\n",
    "            idx = torch.cat((idx, idx_next), dim=1)             # (B,t+1)\n",
    "        return idx\n",
    "\n",
    "model = BigramLanguageModel(B, T, C, H, L)\n",
    "m = model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {"id": "GqlT4Bwa5Dfo"},
   "source": ["### Training"]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "9pDDJiTV5BPa"},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def estimate_loss():\n",
    "    \"\"\"Mean loss over eval_iters random batches of each split, with dropout disabled; restores train mode.\"\"\"\n",
    "    out = {}\n",
    "    model.eval()\n",
    "    for split in ['train', 'val']:\n",
    "        losses = torch.zeros(eval_iters)\n",
    "        for k in range(eval_iters):\n",
    "            X, Y = get_batch(split)\n",
    "            _, loss = model(X, Y)\n",
    "            losses[k] = loss.item()\n",
    "        out[split] = losses.mean()\n",
    "    model.train()\n",
    "    return out\n",
    "\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
    "\n",
    "for step in range(max_iters):\n",
    "    # every once in a while (and on the last step) evaluate the loss on train and val sets\n",
    "    if step % eval_interval == 0 or step == max_iters - 1:\n",
    "        losses = estimate_loss()\n",
    "        print(f\"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n",
    "\n",
    "    xb, yb = get_batch('train')  # sample a batch of data\n",
    "\n",
    "    # evaluate the loss and take an optimisation step\n",
    "    _, loss = model(xb, yb)\n",
    "    optimizer.zero_grad(set_to_none=True)\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {"id": "tLhWVvJR5Hmv"},
   "source": ["### Inference"]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"id": "oArZB_RT5FHn"},
   "outputs": [],
   "source": [
    "model.eval()  # disable dropout while sampling (estimate_loss leaves the model in train mode)\n",
    "# seed with a single newline token; forward accepts contexts shorter than T\n",
    "context = torch.full((1, 1), stoi['\\n'], dtype=torch.long, device=device)\n",
    "out_ints = m.generate(context, max_new_tokens=2000)[0].tolist()  # output list of ints\n",
    "print(decode(out_ints))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyN3lgNBz/+JS9jKeXYyYuPq",
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}