{"cells":[{"cell_type":"markdown","metadata":{"id":"YHMlh4BG265g"},"source":["# Multiple Heads, Feedforward Layer\n","\n","\n"," \n","
\n"," \"Open\n","
\n","\n","
\n","
"]},{"cell_type":"code","execution_count":1,"metadata":{"executionInfo":{"elapsed":2177,"status":"ok","timestamp":1717898457278,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"_W4_1tox12Nd"},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","from torch.nn import functional as F"]},{"cell_type":"markdown","metadata":{"id":"WsCZ7yRJ2P4a"},"source":["### Hyperparameters"]},{"cell_type":"code","execution_count":2,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":8,"status":"ok","timestamp":1717898457278,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"Pl4urkMq2Jt4","outputId":"7ff0ca0c-4586-41da-c5a4-852584afcb38"},"outputs":[],"source":["B = 32  # B: how many independent sequences will we process in parallel?\n","T = 8   # T: what is the maximum context length for predictions?\n","C = 32  # C: number of features per token, i.e. the embedding size (also D = dims)\n","H = 4   # H: number of attention heads; each head has size C // H\n","assert C % H == 0, 'C must be divisible by H so the concatenated heads give back C features'\n","\n","max_iters = 5000\n","eval_interval = 500\n","learning_rate = 1e-3\n","device = 'cuda' if torch.cuda.is_available() else 'cpu'\n","eval_iters = 200\n","\n","# fixed seed for reproducible runs; the trailing ';' hides the returned Generator repr\n","torch.manual_seed(1337);"]},{"cell_type":"markdown","metadata":{"id":"9IQCbe5c2T8_"},"source":["### Data Loading"]},{"cell_type":"code","execution_count":3,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":483,"status":"ok","timestamp":1717898457757,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"8yq7Fp822Z20","outputId":"aae1e486-7b14-4473-f28c-46016db1d1cb"},"outputs":[],"source":["# -nc (no-clobber): skip the download when input.txt already exists, so re-running\n","# this cell does not create input.txt.1, input.txt.2, ... copies\n","!wget -nc https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"]},{"cell_type":"code","execution_count":4,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":521,"status":"ok","timestamp":1717898458272,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"K4o3cGHs2R1b","outputId":"c1d1dc20-4a37-4198-e077-99efe365b27d"},"outputs":[{"name":"stdout","output_type":"stream","text":["vocab_size: 65\n","vocabulary: \n"," !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n"]}],"source":["with open('input.txt', 'r', encoding='utf-8') as f:\n","    text = f.read()\n","\n","# here are all the unique characters that occur in this text\n","chars = sorted(set(text))\n","vocab_size = len(chars)\n","# create a mapping from characters to integers\n","stoi = {ch: i for i, ch in enumerate(chars)}\n","itos = {i: ch for i, ch in enumerate(chars)}\n","encode = lambda s: [stoi[c] for c in s]  # encoder: take a string, output a list of integers\n","decode = lambda l: ''.join([itos[i] for i in l])  # decoder: take a list of integers, output a string\n","\n","chars_str = ''.join(chars)\n","print(f'vocab_size: {vocab_size}')\n","print(f'vocabulary: {chars_str}')\n","\n","# train and validation splits\n","data = torch.tensor(encode(text), dtype=torch.long)\n","n = int(0.9*len(data))  # first 90% will be train, rest val\n","train_data = data[:n]\n","val_data = 
data[n:]\n","\n","def get_batch(split):\n","    \"\"\"Sample a random batch of inputs x and targets y from the 'train' or val split.\n","\n","    Returns two (B, T) tensors on `device`; y[b, t] is the token that follows x[b, t].\n","    \"\"\"\n","    data = train_data if split == 'train' else val_data\n","    ix = torch.randint(len(data) - T, (B,))  # window starts; each window needs T+1 tokens\n","    x = torch.stack([data[i:i+T] for i in ix])\n","    y = torch.stack([data[i+1:i+T+1] for i in ix])\n","    x, y = x.to(device), y.to(device)\n","    return x, y\n"]},{"cell_type":"markdown","metadata":{"id":"Kw2L86zl2kDM"},"source":["### Attention Head"]},{"cell_type":"code","execution_count":5,"metadata":{"executionInfo":{"elapsed":6,"status":"ok","timestamp":1717898458273,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"Xl_-Mm5J2eRe"},"outputs":[],"source":["class Head(nn.Module):\n","    \"\"\"One head of causal (masked) self-attention.\n","\n","    Args:\n","        Ci: number of input features per token.\n","        Co: head size, i.e. number of output features per token.\n","        block_size: longest sequence the causal mask supports; defaults to\n","            the global context length T, read when the head is constructed.\n","    \"\"\"\n","\n","    def __init__(self, Ci, Co, block_size=None):\n","        super().__init__()\n","        if block_size is None:\n","            block_size = T\n","        self.key = nn.Linear(Ci, Co, bias=False)\n","        self.query = nn.Linear(Ci, Co, bias=False)\n","        self.value = nn.Linear(Ci, Co, bias=False)\n","        # lower-triangular causal mask, registered as a buffer: moves with .to(device) but is not trained\n","        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n","\n","    def forward(self, x):\n","        \"\"\"Map x of shape (B, T, Ci) to (B, T, Co); T must be at most block_size.\n","\n","        B  - batch: number of independent sequences processed together\n","        T  - time: number of tokens in each context\n","        Ci - channels: number of input features per token\n","        \"\"\"\n","        B, T, Ci = x.shape\n","\n","        k = self.key(x)    # (B,T,Co)\n","        q = self.query(x)  # (B,T,Co)\n","        Co = k.shape[-1]   # head size\n","\n","        # compute attention scores / affinities\n","        wei = q @ k.transpose(-2, -1)  # (B,T,Co) @ (B,Co,T) -> (B,T,T)\n","        # scale by 1/sqrt(head size) -- not the model width C -- so the scores keep\n","        # roughly unit variance and the softmax does not collapse towards one-hot\n","        wei = wei / Co**0.5\n","        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B,T,T) hide future tokens\n","        wei = F.softmax(wei, dim=-1)  # (B,T,T) -inf -> 0, each row sums to 1\n","\n","        v = self.value(x)  # (B,T,Co)\n","        out = wei @ v      # (B,T,T) @ (B,T,Co) -> (B,T,Co)\n","        return out"]},{"cell_type":"markdown","metadata":{"id":"qIyVuM562lqz"},"source":["### 
Model"]},{"cell_type":"code","execution_count":6,"metadata":{"executionInfo":{"elapsed":5,"status":"ok","timestamp":1717898458273,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"UywxTAaO2pmA"},"outputs":[],"source":["class BigramLanguageModel(nn.Module):\n","    \"\"\"Character-level language model: token + position embeddings, multi-head causal\n","    self-attention, a per-token feedforward layer and a linear head giving next-token logits.\n","\n","    Args:\n","        B: batch size (stored as self.B; not used by the layers).\n","        T: context length, i.e. the size of the position-embedding table.\n","        C: number of embedding features per token.\n","        H: number of attention heads, each of size C // H.\n","    \"\"\"\n","\n","    def __init__(self, B,T,C,H):\n","        super().__init__()\n","        self.B, self.T, self.C, self.H = B,T,C,H\n","        self.token_embedding_table = nn.Embedding(vocab_size, C)  # token id -> C features\n","        self.position_embedding_table = nn.Embedding(T, C)  # position 0..T-1 -> C features\n","\n","        # H heads of size C//H whose outputs are concatenated back to C features\n","        # (e.g. 4 heads of 8 dims for C=32), much like a group convolution\n","        self.heads = nn.ModuleList([Head(Ci=C, Co=C//H) for _ in range(H)])\n","        self.ffwd = nn.Sequential(  # feedforward, so each token can \"think about\" what it found in attention\n","            nn.Linear(C, C),\n","            nn.ReLU(),\n","        )\n","        self.lm_head = nn.Linear(C, vocab_size)\n","\n","    def forward(self, idx, targets=None):\n","        \"\"\"Compute next-token logits, and the loss when targets are given.\n","\n","        idx: (B, T) token ids with T at most self.T; targets: (B, T) ids or None.\n","        Returns (logits, loss). Without targets, logits is (B, T, vocab_size) and\n","        loss is None; with targets, logits is flattened to (B*T, vocab_size) and\n","        loss is the mean cross-entropy.\n","        \"\"\"\n","        B, T = idx.shape\n","        tok_emb = self.token_embedding_table(idx)  # (B,T,C)\n","        # positions follow the actual input length, so contexts shorter than self.T work too\n","        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T,C)\n","        x = tok_emb + pos_emb  # (B,T,C), pos_emb broadcasts over the batch\n","\n","        x = torch.cat([h(x) for h in self.heads], dim=-1)  # (B,T,C), multi-head self-attention\n","        x = self.ffwd(x)  # (B,T,C), applied per token; B,T act as batch dimensions\n","        logits = self.lm_head(x)  # (B,T,vocab_size)\n","\n","        if targets is None:\n","            loss = None\n","        else:\n","            B, T, V = logits.shape\n","            logits = logits.view(B*T, V)\n","            targets = targets.view(B*T)\n","            loss = F.cross_entropy(logits, targets)\n","\n","        return logits, loss\n","\n","    @torch.no_grad()\n","    def generate(self, idx, max_new_tokens):\n","        \"\"\"Append max_new_tokens sampled tokens to idx of shape (B, T0); returns (B, T0 + max_new_tokens).\"\"\"\n","        for _ in range(max_new_tokens):\n","            idx_cond = idx[:, -self.T:]  # crop to the last self.T tokens (size of the position table)\n","            logits, _ = self(idx_cond)  # (B,T,vocab_size)\n","            logits = logits[:, -1, :]  # keep only the last time step -> (B, vocab_size)\n","            probs = F.softmax(logits, dim=-1)  # (B, vocab_size)\n","            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1) sampled from the distribution\n","            idx = torch.cat((idx, idx_next), dim=1)  # append: (B, T0+1), (B, T0+2), ...\n","        return idx\n","\n","model = BigramLanguageModel(B,T,C,H)\n","m = model.to(device)"]},{"cell_type":"markdown","metadata":{"id":"539In3l_2x5x"},"source":["### Training"]},{"cell_type":"code","execution_count":7,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":70828,"status":"ok","timestamp":1717898529096,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"FI9KGtkg2r-m","outputId":"0066554d-167f-4562-fbfa-caf215aac3e0"},"outputs":[{"name":"stdout","output_type":"stream","text":["step 0: train loss 4.1996, val loss 4.1995\n","step 500: train loss 2.5993, val loss 2.6077\n","step 1000: train loss 2.4629, val loss 2.4651\n","step 1500: train loss 2.3974, val loss 2.3951\n","step 2000: train loss 2.3297, val loss 2.3470\n","step 2500: train loss 2.3018, val loss 2.3221\n","step 3000: train loss 2.2828, val loss 2.2936\n","step 3500: train loss 2.2495, val loss 2.2721\n","step 4000: train loss 2.2435, val loss 2.2468\n","step 4500: train loss 2.2286, val loss 2.2411\n"]}],"source":["@torch.no_grad()\n","def estimate_loss():\n","    \"\"\"Return {'train': ..., 'val': ...}: mean loss over eval_iters random batches per split.\"\"\"\n","    out = {}\n","    model.eval()\n","    for split in ['train', 'val']:\n","        losses = torch.zeros(eval_iters)\n","        for k in range(eval_iters):\n","            X, Y = get_batch(split)\n","            logits, loss = model(X, Y)\n","            losses[k] = loss.item()\n","        out[split] = losses.mean()\n","    model.train()\n","    return out\n","\n","optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n","\n","for step in range(max_iters):  # 'step' rather than 'iter', which would shadow the builtin\n","    # every once in a while, and on the last iteration, evaluate the loss on train and val sets\n","    if step % eval_interval == 0 or step == max_iters - 1:\n","        losses = estimate_loss()\n","        print(f\"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n","\n","    xb, yb = get_batch('train')  # sample a batch of data\n","\n","    # evaluate the loss and take one optimizer step\n","    logits, loss = model(xb, yb)\n","    optimizer.zero_grad(set_to_none=True)\n","    loss.backward()\n","    optimizer.step()"]},{"cell_type":"markdown","metadata":{"id":"2pXYHufc20K-"},"source":["### Inference"]},{"cell_type":"code","execution_count":8,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":1001,"status":"ok","timestamp":1717898530085,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"7GCkCuin2znP","outputId":"53178f74-4726-466d-b5e3-59c9f1c5ea5b"},"outputs":[{"name":"stdout","output_type":"stream","text":["        Ba hill ind vexe\n","Whe the hot mes fin.\n","\n","Cy hirad I four shat son yald hat lods guk- have ave lithr\n","GLOull\n","Wllld, with.\n","\n","BANTAUCHAR:\n","I ork sak's willl eger aepale ganed ay wouce\n","song thy in noduace being uliths tot upiord--mard meme the fles,\n","Tharr, leanven-ty's thy on it in weancepte, digus I the souts pof ther cork, if woth' that he the bre tolf,\n","An live:\n","Mrins,\n","What,\n","Whorend;\n","Ant to mal ry fa of thaw a tes tir, to tis argito-, setolf brntens?\n","I' sit mem, thand yought is an he frifperend tepefur\n"]}],"source":["context = torch.full((1, T), stoi['\\n'], dtype=torch.long, device=device)  # seed the context with T newline characters\n","out_ints = m.generate(context, max_new_tokens=500)[0].tolist()  # output list of ints\n","print(decode(out_ints))"]}],"metadata":{"colab":{"authorship_tag":"ABX9TyM48kxaQuJKaPVzaI+5cNX8","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}