{"cells":[{"cell_type":"markdown","metadata":{"id":"40Z834mwz8sd"},"source":["# Self Attention\n","\n","\n"," \n","
\n"," \"Open\n","
\n","\n","
\n","
\n","\n","### Attention is a communication mechanism\n","\n","- A directed graph of T nodes, each node being a token position, and contains info as a vector of H size\n","- T nodes aggregate infromation as a weighted sum from all nodes\n","- Data dependant, the data stored in the nodes change with time\n","- For autogression,\n"," - 1-th node gets only from itself\n"," - 2-nd node gets from 1,2\n"," - T-th node gets from everyone\n","- Nodes have no notion of space / ordering. So we need to add postional embedding\n","\n","\n","### Encoder vs Decoder\n","\n","- Encoder\n"," - no triangular mask (it gathers data from both directions)\n"," - eg. translation, sentiment analysis\n","- Decoder\n"," - triangular mask (gathers data from past & present only)\n"," - predicts next word\n"," - \"autoregressive\", P(next_word) = P(this_word|past_words) * P(prev_word|words_before)...\n","\n","### Attention = V*softmax(QK.T/sqrt(H))\n","\n","- Self-Attention: Q,K,V come from X\n","- Cross-attention:\n"," - query from x, keys & values come from different place.\n"," - eg: English -> French, french (query) searches in English (key, value)\n","\n","\n","```\n","x - private information of token (B,T,C)\n","q - each token generates a query vector , [I am a vowel in position 8 looking for consonents upto position 4]\n","k - each other token generates a key vector, what information I have [I am a consonent in position 3]\n","w=qk - affinity - those two tokens find each other, affinity at the intersection will be very high (I am interested in these positions)\n","v - value vector, what information I am willing to provide\n","y=vσ(w) - accumulate all the information from interested positions to me (B,T,H)\n","\n","```"]},{"cell_type":"code","execution_count":1,"metadata":{"executionInfo":{"elapsed":2811,"status":"ok","timestamp":1717897931792,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"LguvQN1lzMZG"},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","from torch.nn import functional as F"]},{"cell_type":"markdown","metadata":{"id":"aWHlyhrEzofP"},"source":["### Hyperparameters"]},{"cell_type":"code","execution_count":2,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":7,"status":"ok","timestamp":1717897931792,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"VcBYQcNmzWOP","outputId":"c3ad6418-d9ea-4d88-b003-05a0752ea7db"},"outputs":[{"data":{"text/plain":[""]},"execution_count":2,"metadata":{},"output_type":"execute_result"}],"source":["B = 32 # B: how many independent sequences will we process in parallel?\n","T = 8 # T: what is the maximum context length for predictions?\n","C = 32 # C: numer of different features analysed (D)\n","max_iters = 5000\n","eval_interval = 500\n","learning_rate = 1e-3\n","device = 'cuda' if torch.cuda.is_available() else 'cpu'\n","eval_iters = 200\n","torch.manual_seed(1337)"]},{"cell_type":"markdown","metadata":{"id":"tB3jqsjRzqus"},"source":["### Data"]},{"cell_type":"code","execution_count":3,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":936,"status":"ok","timestamp":1717897932724,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"mleT4wJMza4T","outputId":"a8182fd7-9575-4eb4-c8f5-49cade8356f4"},"outputs":[{"name":"stdout","output_type":"stream","text":["--2024-06-09 01:52:09-- 
{"cell_type":"markdown","metadata":{"id":"aWHlyhrEzofP"},"source":["### Hyperparameters"]},{"cell_type":"code","execution_count":2,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":7,"status":"ok","timestamp":1717897931792,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"VcBYQcNmzWOP","outputId":"c3ad6418-d9ea-4d88-b003-05a0752ea7db"},"outputs":[{"data":{"text/plain":[""]},"execution_count":2,"metadata":{},"output_type":"execute_result"}],"source":["B = 32 # B: how many independent sequences will we process in parallel?\n","T = 8  # T: what is the maximum context length for predictions?\n","C = 32 # C: number of features per token (the embedding dimension, a.k.a. D)\n","max_iters = 5000\n","eval_interval = 500\n","learning_rate = 1e-3\n","device = 'cuda' if torch.cuda.is_available() else 'cpu'\n","eval_iters = 200\n","torch.manual_seed(1337)"]},{"cell_type":"markdown","metadata":{"id":"tB3jqsjRzqus"},"source":["### Data"]},{"cell_type":"code","execution_count":3,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":936,"status":"ok","timestamp":1717897932724,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"mleT4wJMza4T","outputId":"a8182fd7-9575-4eb4-c8f5-49cade8356f4"},"outputs":[{"name":"stdout","output_type":"stream","text":["--2024-06-09 01:52:09--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n","Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n","Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n","HTTP request sent, awaiting response... 200 OK\n","Length: 1115394 (1.1M) [text/plain]\n","Saving to: ‘input.txt.1’\n","\n","input.txt.1 100%[===================>] 1.06M 5.18MB/s in 0.2s \n","\n","2024-06-09 01:52:10 (5.18 MB/s) - ‘input.txt.1’ saved [1115394/1115394]\n","\n"]}],"source":["!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"]},{"cell_type":"code","execution_count":4,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":6,"status":"ok","timestamp":1717897932724,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"vac9PsfhzX1O","outputId":"ec6aba82-0df4-4186-9b18-4ad9179b134e"},"outputs":[{"name":"stdout","output_type":"stream","text":["vocab_size: 65\n","vocabulary: \n"," !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n"]}],"source":["with open('input.txt', 'r', encoding='utf-8') as f:\n","    text = f.read()\n","\n","# here are all the unique characters that occur in this text\n","chars = sorted(list(set(text)))\n","vocab_size = len(chars)\n","# create a mapping from characters to integers\n","stoi = { ch:i for i,ch in enumerate(chars) }\n","itos = { i:ch for i,ch in enumerate(chars) }\n","encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n","decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n","\n","chars_str = ''.join(chars)\n","print(f'vocab_size: {vocab_size}')\n","print(f'vocabulary: {chars_str}')\n","\n","# Train and test splits\n","data = torch.tensor(encode(text), dtype=torch.long)\n","n = int(0.9*len(data)) # first 90% will be train, rest val\n","train_data = data[:n]\n","val_data = data[n:]"]},{"cell_type":"code","execution_count":5,"metadata":{"executionInfo":{"elapsed":4,"status":"ok","timestamp":1717897932724,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"P2TbGEzczbPj"},"outputs":[],"source":["def get_batch(split):\n","    # generate a small batch of inputs x and targets y\n","    data = train_data if split == 'train' else val_data\n","    ix = torch.randint(len(data) - T, (B,)) # B random starting offsets\n","    x = torch.stack([data[i:i+T] for i in ix])     # (B,T) inputs\n","    y = torch.stack([data[i+1:i+T+1] for i in ix]) # (B,T) targets, shifted one to the right\n","    x, y = x.to(device), y.to(device)\n","    return x, y"]},
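{"cell_type":"markdown","metadata":{},"source":["A quick sanity check, added for illustration: each batch is a (B,T) block of character indices, and `y` is `x` shifted one character to the right. The names `xb`, `yb` are throwaway names for this demo."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Shape check of get_batch (illustrative only)\n","xb, yb = get_batch('train')\n","print(xb.shape, yb.shape)           # (B,T) each: torch.Size([32, 8])\n","print(repr(decode(xb[0].tolist()))) # one T-character chunk of text\n","print(repr(decode(yb[0].tolist()))) # the same chunk, shifted right by one"]},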
{"cell_type":"markdown","metadata":{"id":"h1XWxvxVzslZ"},"source":["### Self Attention"]},{"cell_type":"code","execution_count":6,"metadata":{"executionInfo":{"elapsed":3,"status":"ok","timestamp":1717897932724,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"PHJYD8BOzdrR"},"outputs":[],"source":["class Head(nn.Module):\n","    \"\"\" One head of self-attention \"\"\"\n","\n","    def __init__(self, Ci, Co):\n","        super().__init__()\n","        self.Ci, self.Co = Ci, Co\n","        self.key = nn.Linear(Ci, Co, bias=False)\n","        self.query = nn.Linear(Ci, Co, bias=False)\n","        self.value = nn.Linear(Ci, Co, bias=False)\n","        self.register_buffer('tril', torch.tril(torch.ones(T, T)))\n","\n","    def forward(self, x):\n","        B, T, Ci = x.shape\n","        '''\n","        B  - batch              # of independent sequences processed\n","        T  - time/block/context # of tokens in a context\n","        Ci - channels/dims      # of features in the input\n","        '''\n","\n","        k = self.key(x)   # (B,T,Co)\n","        q = self.query(x) # (B,T,Co)\n","\n","        # compute attention scores / affinities\n","        wei = q @ k.transpose(-2,-1) # (B,T,Co) @ (B,Co,T) -> (B,T,T)\n","        wei /= self.Co**0.5          # (B,T,T) scale by the key dimension to keep variance ~1, so softmax does not saturate\n","        wei = wei.masked_fill(self.tril[:T,:T]==0, float('-inf')) # (B,T,T) replace upper triangle of wei with -inf\n","        wei = F.softmax(wei, dim=-1) # (B,T,T) -inf -> 0, each row normalized to sum to 1\n","\n","        v = self.value(x) # (B,T,Co)\n","        out = wei @ v     # (B,T,T) @ (B,T,Co) = (B,T,Co)\n","\n","        return out"]},
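{"cell_type":"markdown","metadata":{},"source":["A sanity check of `Head` on random data, added for illustration; `sa` and the head size of 16 are arbitrary choices for this demo and are not used by the model below."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Run one attention head on random input (illustrative only)\n","sa = Head(Ci=C, Co=16).to(device)        # Co=16: arbitrary head size for this check\n","x = torch.randn(B, T, C, device=device)  # fake embeddings\n","print(sa(x).shape)                       # expected (B,T,Co) = torch.Size([32, 8, 16])"]},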
{"cell_type":"markdown","metadata":{"id":"WBqL5hnWzvyC"},"source":["### Model: Self attention + Linear + pos/token embeddings"]},{"cell_type":"code","execution_count":7,"metadata":{"executionInfo":{"elapsed":3,"status":"ok","timestamp":1717897932724,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"lsskJhTlzf3P"},"outputs":[],"source":["class BigramLanguageModel(nn.Module):\n","\n","    def __init__(self, B,T,C):\n","        super().__init__()\n","        self.B, self.T, self.C = B,T,C\n","        # each token directly reads off the logits for the next token from a lookup table\n","        self.token_embedding_table = nn.Embedding(vocab_size, C) # for every possible token, weights for next token\n","        self.position_embedding_table = nn.Embedding(T, C)\n","\n","        self.sa_head = Head(Ci=C, Co=C)\n","        self.lm_head = nn.Linear(C, vocab_size)\n","\n","    def forward(self, idx, targets=None):\n","        B, T = idx.shape\n","\n","        tok_emb = self.token_embedding_table(idx) # (B,T,C)\n","        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C): embeddings of [0,1,2..T-1]\n","\n","        x = tok_emb + pos_emb # (B,T,C) broadcast sum\n","        x = self.sa_head(x) # (B,T,Co) apply self attention\n","        logits = self.lm_head(x) # (B,T,vocab_size)\n","\n","        if targets is None:\n","            loss = None\n","        else:\n","            B, T, C = logits.shape\n","            logits = logits.view(B*T, C)\n","            targets = targets.view(B*T)\n","            loss = F.cross_entropy(logits, targets)\n","\n","        return logits, loss\n","\n","    def generate(self, idx, max_new_tokens):\n","        for _ in range(max_new_tokens): # idx is (B, T) array of indices in the current context\n","            idx_cond = idx[:, -T:] # crop to the last T tokens for input\n","            logits, loss = self(idx_cond) # get the predictions\n","            logits = logits[:, -1, :] # (B,T,C) -> (B, C) keep only the last time step\n","            probs = F.softmax(logits, dim=-1) # (B, C)\n","            idx_next = torch.multinomial(probs, num_samples=1) # sample from the distribution (B, 1)\n","            idx = torch.cat((idx, idx_next), dim=1) # append: new idx is (B, T+1)\n","        return idx\n","\n","model = BigramLanguageModel(B,T,C)\n","m = model.to(device)"]},{"cell_type":"markdown","metadata":{"id":"Hq7Kr1jAz2Zi"},"source":["### Training"]},{"cell_type":"code","execution_count":8,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":44477,"status":"ok","timestamp":1717897977198,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"unru0WuUzhq6","outputId":"8af2f0dd-f56b-4696-8225-165ed531729b"},"outputs":[{"name":"stdout","output_type":"stream","text":["step 0: train loss 4.2000, val loss 4.2047\n","step 500: train loss 2.6911, val loss 2.7087\n","step 1000: train loss 2.5196, val loss 2.5303\n","step 1500: train loss 2.4775, val loss 2.4829\n","step 2000: train loss 2.4408, val loss 2.4523\n","step 2500: train loss 2.4272, val loss 2.4435\n","step 3000: train loss 2.4130, val loss 2.4327\n","step 3500: train loss 2.3956, val loss 2.4212\n","step 4000: train loss 2.4041, val loss 2.3992\n","step 4500: train loss 2.3980, val loss 2.4084\n"]}],"source":["@torch.no_grad()\n","def estimate_loss():\n","    out = {}\n","    model.eval()\n","    for split in ['train', 'val']:\n","        losses = torch.zeros(eval_iters)\n","        for k in range(eval_iters):\n","            X, Y = get_batch(split)\n","            logits, loss = model(X, Y)\n","            losses[k] = loss.item()\n","        out[split] = losses.mean()\n","    model.train()\n","    return out\n","\n","optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n","\n","for iter in range(max_iters):\n","    if iter % eval_interval == 0: # every once in a while evaluate the loss on train and val sets\n","        losses = estimate_loss()\n","        print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n","\n","    xb, yb = get_batch('train') # sample a batch of data\n","\n","    # evaluate the loss\n","    logits, loss = model(xb, yb)\n","    optimizer.zero_grad(set_to_none=True)\n","    loss.backward()\n","    optimizer.step()"]},{"cell_type":"markdown","metadata":{"id":"WsTMTwHWz3_a"},"source":["### Output"]},{"cell_type":"code","execution_count":9,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":394,"status":"ok","timestamp":1717897977589,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"xDbBAhiczjiw","outputId":"45b3206d-7ec3-4541-f04e-60cf6ef9afc3"},"outputs":[{"name":"stdout","output_type":"stream","text":[" pes le isen.\n","Woto teven INGO, ous into CYedd shou maithe ert thethens the the del ede cksy ow? Wlouby aicecat tisall wor\n","G'imemonou mar ee hacreancad hontrt had wousk ucavere.\n","\n","Baraghe lfousto beme,\n","S m; ten gh;\n","S:\n","Ano ice de bay alysathef beatireplim serbeais I fard\n","Sy,\n","Me hallil:\n","DWAR: us,\n","Wte hse aecathate, parrise in hr'd pat\n","ERY:\n","Bf bul walde betl'ts I yshore grest atre ciak aloo; wo fart hets atl.\n","\n","That at Wh kear ben.\n"," hend.\n","\n","KTh'd foushe d'l otacaengs p bloul blod arme foot buthes fo boe\n"]}],"source":["context = torch.ones((1, T), dtype=torch.long, device=device) # seed: a (1,T) block of token 1, i.e. eight spaces\n","out_ints = m.generate(context, max_new_tokens=500)[0].tolist() # output list of ints\n","print(decode(out_ints))"]},{"cell_type":"code","execution_count":9,"metadata":{"executionInfo":{"elapsed":3,"status":"ok","timestamp":1717897977589,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"F8_Klsp3zmg8"},"outputs":[],"source":[]}],"metadata":{"colab":{"authorship_tag":"ABX9TyMxlqaSNreESy1aP7XvdB7/","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}