{"cells":[{"cell_type":"markdown","metadata":{"id":"phqq9JnA3gnF"},"source":["# Residual Connections, Transformer Block\n","\n","\n"," \n","
\n"," \"Open\n","
\n","\n","
\n","
"]},
{"cell_type":"code","execution_count":1,"metadata":{"executionInfo":{"elapsed":8823,"status":"ok","timestamp":1717898474497,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"VXa2SxSb3PAr"},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","from torch.nn import functional as F"]},
{"cell_type":"markdown","metadata":{"id":"ltuvAjX53nbW"},"source":["### Hyperparameters"]},
{"cell_type":"code","execution_count":2,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":7,"status":"ok","timestamp":1717898474497,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"a4ScTF2n3mRG","outputId":"51dac004-3708-4807-ede5-598d6fa4a07d"},"outputs":[],"source":["B = 32  # B: batch size - how many independent sequences will we process in parallel?\n","T = 8  # T: block size - what is the maximum context length for predictions?\n","C = 32  # C: embedding size - number of features per token (also D = dims); must be divisible by H\n","H = 4  # H: number of attention heads, each of size C // H\n","max_iters = 5000\n","eval_interval = 500\n","learning_rate = 1e-3\n","device = 'cuda' if torch.cuda.is_available() else 'cpu'\n","eval_iters = 200\n","# seed for reproducibility; the trailing ';' hides the returned Generator repr\n","torch.manual_seed(1337);"]},
{"cell_type":"markdown","metadata":{"id":"5xvMg4Py3qRQ"},"source":["### Data"]},
{"cell_type":"code","execution_count":3,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":711,"status":"ok","timestamp":1717898475204,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"PChCZJ9o3o3c","outputId":"67c79dc0-ed36-4f60-b07a-76a2f6230dea"},"outputs":[{"name":"stdout","output_type":"stream","text":["--2024-06-09 02:01:12-- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n","Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 
185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n","Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n","HTTP request sent, awaiting response... 200 OK\n","Length: 1115394 (1.1M) [text/plain]\n","Saving to: ‘input.txt’\n","\n","input.txt 100%[===================>] 1.06M --.-KB/s in 0.09s \n","\n","2024-06-09 02:01:13 (12.0 MB/s) - ‘input.txt’ saved [1115394/1115394]\n","\n"]}],"source":["# -nc (no-clobber): skip the download if input.txt already exists, so re-running doesn't create input.txt.1\n","!wget -nc https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"]},
{"cell_type":"code","execution_count":4,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":1684,"status":"ok","timestamp":1717898482023,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"FqW3PoPe3sXU","outputId":"5b65f688-00b6-416b-9869-0b74f32dae29"},"outputs":[{"name":"stdout","output_type":"stream","text":["vocab_size: 65\n","vocabulary: \n"," !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n"]}],"source":["with open('input.txt', 'r', encoding='utf-8') as f:\n","    text = f.read()\n","\n","# here are all the unique characters that occur in this text\n","chars = sorted(set(text))\n","vocab_size = len(chars)\n","# create a mapping from characters to integers and back\n","stoi = {ch: i for i, ch in enumerate(chars)}\n","itos = {i: ch for i, ch in enumerate(chars)}\n","\n","def encode(s):\n","    ''' encoder: take a string, output a list of integers '''\n","    return [stoi[c] for c in s]\n","\n","def decode(ids):\n","    ''' decoder: take a list of integers, output a string '''\n","    return ''.join(itos[i] for i in ids)\n","\n","chars_str = ''.join(chars)\n","print(f'vocab_size: {vocab_size}')\n","print(f'vocabulary: {chars_str}')\n","\n","# Train and validation splits\n","data = torch.tensor(encode(text), dtype=torch.long)\n","n = int(0.9*len(data))  # first 90% will be train, rest val\n","train_data = data[:n]\n","val_data = data[n:]\n","\n","def get_batch(split):\n","    # generate a small batch 
of data of inputs x and targets y\n","    data = train_data if split == 'train' else val_data\n","    ix = torch.randint(len(data) - T, (B,))  # B random start offsets\n","    x = torch.stack([data[i:i+T] for i in ix])  # (B,T) inputs\n","    y = torch.stack([data[i+1:i+T+1] for i in ix])  # (B,T) targets: the inputs shifted by one token\n","    x, y = x.to(device), y.to(device)\n","    return x, y\n"]},
{"cell_type":"markdown","metadata":{"id":"h8JlDyWE31Fu"},"source":["### Head, MHSA"]},
{"cell_type":"code","execution_count":16,"metadata":{"executionInfo":{"elapsed":215,"status":"ok","timestamp":1717899661513,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"5saA9sMB3ste"},"outputs":[],"source":["class Head(nn.Module):\n","    \"\"\" One head of causal (masked) self-attention \"\"\"\n","\n","    def __init__(self, Ci, Co):\n","        super().__init__()\n","        self.key = nn.Linear(Ci, Co, bias=False)\n","        self.query = nn.Linear(Ci, Co, bias=False)\n","        self.value = nn.Linear(Ci, Co, bias=False)\n","        # lower-triangular causal mask, sized by the global context length T\n","        self.register_buffer('tril', torch.tril(torch.ones(T, T)))\n","\n","    def forward(self, x):\n","        B, T, Ci = x.shape\n","        '''\n","        B  - batch               # of independent sequences processed\n","        T  - time/block/context  # of tokens in a context (at most the T used to build tril)\n","        Ci - channels/dims input # of features in input\n","        '''\n","\n","        k = self.key(x)    # (B,T,Co)\n","        q = self.query(x)  # (B,T,Co)\n","\n","        # compute attention scores / affinities\n","        wei = q @ k.transpose(-2,-1)  # (B,T,Co) @ (B,Co,T) -> (B,T,T)\n","        # scale by 1/sqrt(head size Co) -- not the embedding size C -- to bring the variance of\n","        # the scores to ~1, so the softmax doesn't saturate towards one-hot\n","        wei = wei * k.shape[-1]**-0.5  # (B,T,T)\n","        wei = wei.masked_fill(self.tril[:T,:T]==0, float('-inf'))  # (B,T,T) Replace upper triangular of wei with -inf\n","        wei = F.softmax(wei, dim=-1)  # (B,T,T) -inf -> 0, rest normalized to 1\n","\n","        v = self.value(x)  # (B,T,Co)\n","        out = wei @ v  # (B,T,T) @ (B,T,Co) = (B,T,Co)\n","\n","        return out\n","\n","\n","class MultiHeadAttention(nn.Module):\n","    \"\"\" H independent heads of self-attention; their outputs are concatenated on the last dim \"\"\"\n","\n","    def __init__(self, Ci, H, head_size):\n","        super().__init__()\n","        # 4 heads of 8-dimensional self-attention, for n_embed=32, like a group convolution\n","        self.heads = nn.ModuleList([Head(Ci=Ci, Co=head_size) for _ in range(H)])\n","\n","    def forward(self, x):\n","        x = torch.cat([h(x) for h in self.heads], dim=-1)  # (B,T,H*head_size)\n","        return x"]},
{"cell_type":"markdown","metadata":{"id":"uxGx1U9i320N"},"source":["### Transformer Block"]},
{"cell_type":"code","execution_count":17,"metadata":{"executionInfo":{"elapsed":3,"status":"ok","timestamp":1717899662940,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"ZxzUwHVw32Cy"},"outputs":[],"source":["class Block(nn.Module):\n","    ''' Transformer block: communication (self-attention) followed by computation (feedforward) '''\n","\n","    def __init__(self, C, H):  # C: embedding dimension, H: number of heads\n","        super().__init__()\n","        # the H heads of size C//H are concatenated back to C channels for the residual add\n","        assert C % H == 0, f'C ({C}) must be divisible by H ({H})'\n","        self.sa = MultiHeadAttention(Ci=C, H=H, head_size=C//H)\n","        self.ffwd = nn.Sequential(  # Feedforward network, so the tokens can \"think about\" what they found in attention.\n","            nn.Linear(C, C),\n","            nn.ReLU(),\n","        )\n","\n","    def forward(self, x):\n","        # Residual connections around MHSA & FF, to help training\n","        x = x + self.sa(x)    # (B,T,C), Multi head self attention\n","        x = x + self.ffwd(x)  # (B,T,C), Per token level. B,T act as batch dimensions\n","        return x\n"]},
{"cell_type":"markdown","metadata":{"id":"YL3eUDOI4BWw"},"source":["### Model"]},
{"cell_type":"code","execution_count":18,"metadata":{"executionInfo":{"elapsed":231,"status":"ok","timestamp":1717899665830,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"8X9VOrEb4EuM"},"outputs":[],"source":["class BigramLanguageModel(nn.Module):\n","    ''' Transformer LM: token + position embeddings -> Blocks -> next-token logits (name kept from the bigram version) '''\n","\n","    def __init__(self, B,T,C,H):\n","        super().__init__()\n","        self.B, self.T, self.C, self.H = B,T,C,H\n","        # token id -> C-dim vector, and position 0..T-1 -> C-dim vector\n","        self.token_embedding_table = nn.Embedding(vocab_size, C)\n","        self.position_embedding_table = nn.Embedding(T, C)\n","\n","        self.blocks = nn.Sequential(\n","            Block(C, H),\n","            Block(C, H),\n","            Block(C, H),\n","        )\n","        self.lm_head = nn.Linear(C, vocab_size)  # C features -> logits over the vocabulary\n","\n","    def forward(self, idx, targets=None):\n","        # positions come from the actual length of idx, which may be shorter than self.T\n","        B, T = idx.shape\n","        assert T <= self.T, f'sequence length {T} exceeds block size {self.T}'\n","\n","        tok_emb = self.token_embedding_table(idx)  # (B,T,C)\n","        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T,C): [0,1,2..T-1]\n","\n","        x = tok_emb + pos_emb  # (B,T,C)\n","        x = self.blocks(x)\n","        logits = self.lm_head(x)  # (B,T,vocab_size)\n","\n","        if targets is None:\n","            loss = None\n","        else:\n","            B, T, C = logits.shape\n","            logits = logits.view(B*T, C)\n","            targets = targets.view(B*T)\n","            loss = F.cross_entropy(logits, targets)\n","\n","        return logits, loss\n","\n","    @torch.no_grad()  # sampling only: no need to build an autograd graph over every step\n","    def generate(self, idx, max_new_tokens):\n","        for _ in range(max_new_tokens):  # idx is (B, T) array of indices in the current context\n","            idx_cond = idx[:, -self.T:]  # crop to the last self.T (block size) tokens for input\n","            logits, _ = self(idx_cond)  # get the predictions\n","            logits = logits[:, -1, :]  # (B,T,vocab_size) -> (B,vocab_size)\n","            probs = F.softmax(logits, dim=-1)  # (B,vocab_size)\n","            idx_next = torch.multinomial(probs, num_samples=1)  # sample from the distribution acc to prob (B, 1)\n","            idx = torch.cat((idx, idx_next), dim=1)  # New idx is concat (B, T+1)\n","        return idx\n","\n","model = BigramLanguageModel(B,T,C,H)\n","m = model.to(device)"]},
{"cell_type":"markdown","metadata":{"id":"_0SaadNR4Nmp"},"source":["### Training"]},
{"cell_type":"code","execution_count":19,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":100218,"status":"ok","timestamp":1717899774525,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"Py9g8Qii4I9Q","outputId":"dbeb536d-02b3-4dc8-fe1c-2af00f6547ca"},"outputs":[{"name":"stdout","output_type":"stream","text":["step 0: train loss 4.8060, val loss 4.8283\n","step 500: train loss 2.5139, val loss 2.5350\n","step 1000: train loss 2.3694, val loss 2.3802\n","step 1500: train loss 2.2825, val loss 2.3045\n","step 2000: train loss 2.2332, val loss 2.2590\n","step 2500: train loss 2.1949, val loss 2.2323\n","step 3000: train loss 2.1624, val loss 2.2134\n","step 3500: train loss 2.1452, val loss 2.1947\n","step 4000: train loss 2.1305, val loss 2.1673\n","step 4500: train loss 2.1216, val loss 2.1589\n"]}],"source":["@torch.no_grad()\n","def estimate_loss():\n","    ''' Mean loss over eval_iters random batches of each split, computed in eval mode '''\n","    out = {}\n","    model.eval()\n","    for split in ['train', 'val']:\n","        losses = torch.zeros(eval_iters)\n","        for k in range(eval_iters):\n","            X, Y = get_batch(split)\n","            _, loss = model(X, Y)\n","            losses[k] = loss.item()\n","        out[split] = losses.mean()\n","    model.train()\n","    return out\n","\n","optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n","\n","# NOTE: the log below was recorded before Head switched to head-size scaling; re-run to refresh it\n","for step in range(max_iters):  # `step`, not `iter`, so the builtin isn't shadowed\n","    # every once in a while, and at the last step, evaluate the loss on train and val sets\n","    if step % eval_interval == 0 or step == max_iters - 1:\n","        losses = estimate_loss()\n","        print(f\"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n","\n","    xb, yb = get_batch('train')  # sample a batch of data\n","\n","    # evaluate the loss\n","    logits, loss = model(xb, yb)\n","    optimizer.zero_grad(set_to_none=True)\n","    loss.backward()\n","    optimizer.step()\n"]},
{"cell_type":"markdown","metadata":{"id":"HuwAGjBm4Qsk"},"source":["### Inference"]},
{"cell_type":"code","execution_count":20,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":1852,"status":"ok","timestamp":1717899777067,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"KzC4XfNw4OsS","outputId":"62acc058-c3ec-43fb-8f4d-c9cdf34e63d2"},"outputs":[{"name":"stdout","output_type":"stream","text":["\n","\n","\n","\n","\n","\n","\n","\n","HENRIIA:\n","Hin sher wall, so evall of Bon is their, as wery' to leals of Searswe't galied to fall; one ble eltannelles!\n","BoLIV:\n","And in gense for many of them.\n","Henges a song\n","Whim mor:\n","Cand and of ulse scausnall.\n","\n","Lourd Sourg; leave of rach damperd!\n","I me; my is of riethord soucnsew, an your mitcancase,\n","Ad the nuh you, one is hamer come,\n","And mige tis so with yeaing mild I shoull,\n","O shithe coneres deukn!\n","Why decuck or hereecase! VI more. He I'll hand hanwin my the we of a of youu bult them havince on n\n"]}],"source":["# NOTE: the sample below was generated before Head switched to head-size scaling; re-run to refresh it\n","context = torch.zeros((1, T), dtype=torch.long, device=device)  # seed: T copies of token 0, which is '\\n'\n","out_ints = m.generate(context, max_new_tokens=500)[0].tolist()  # output list of ints\n","print(decode(out_ints))"]}],
"metadata":{"colab":{"authorship_tag":"ABX9TyMftOpNBybHZ2+N24VPF6cV","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}