{"cells":[{"cell_type":"markdown","metadata":{"id":"akR2_VyoxTYz"},"source":["# Bigram Model with Linear layer & Token + Positional embeddings\n","\n","
\n","
\n"," \n","
\n","
\n","\n"," \n"," \n","\n","* Context window (block size) = 8\n","* This model looks at 8 past tokens to predict 1 future token (note: there is no attention, so each position's prediction depends only on its own token and position)\n","* Two trainable embedding tables at input.\n","* Token embedding table maps each token into a vector of size (32), giving a (8, 32) matrix\n","* Position embedding table maps the positions 0-7 into a vector of size (32)\n","* These are added together to inform the model of the positions of input tokens"]},
{"cell_type":"code","execution_count":1,"metadata":{"executionInfo":{"elapsed":2164,"status":"ok","timestamp":1717897128421,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"gYNbxJVcvzFo"},"outputs":[],"source":["import torch\n","import torch.nn as nn\n","from torch.nn import functional as F"]},
{"cell_type":"markdown","metadata":{"id":"jkFKSiBnw0UA"},"source":["### Hyperparameters"]},
{"cell_type":"code","execution_count":2,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":285,"status":"ok","timestamp":1717897128705,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"bDhmURoVw3mR","outputId":"ef49b530-9f94-45c0-acfe-21349dea584b"},"outputs":[{"data":{"text/plain":[""]},"execution_count":2,"metadata":{},"output_type":"execute_result"}],"source":["B = 32  # B (batch size): how many independent sequences will we process in parallel?\n","T = 8   # T (block size): what is the maximum context length for predictions?\n","C = 32  # C (channels) : dimensionality, also called d\n","max_iters = 3000\n","eval_interval = 300\n","learning_rate = 1e-2\n","device = 'cuda' if torch.cuda.is_available() else 'cpu'\n","eval_iters = 200\n","torch.manual_seed(1337)"]},
{"cell_type":"markdown","metadata":{"id":"TulFKc4Kw5mQ"},"source":["### Dataset"]},
{"cell_type":"code","execution_count":3,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":5,"status":"ok","timestamp":1717897128705,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"_5We5NZJw4RN","outputId":"d5fd4948-28f8-4fc1-aeeb-09347bca0cb2"},"outputs":[{"name":"stdout","output_type":"stream","text":["--2024-06-09 01:38:46-- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n","Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...\n","Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n","HTTP request sent, awaiting response... 200 OK\n","Length: 1115394 (1.1M) [text/plain]\n","Saving to: ‘input.txt.4’\n","\n","\rinput.txt.4 0%[ ] 0 --.-KB/s \rinput.txt.4 100%[===================>] 1.06M --.-KB/s in 0.06s \n","\n","2024-06-09 01:38:46 (18.6 MB/s) - ‘input.txt.4’ saved [1115394/1115394]\n","\n"]}],"source":["# -nc (no-clobber): skip the download if input.txt already exists instead of saving input.txt.1, .2, ...\n","!wget -nc https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"]},
{"cell_type":"code","execution_count":4,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":715,"status":"ok","timestamp":1717897129418,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"24w7twrTw80Q","outputId":"7dfd5b49-c1c8-43c0-f817-35cdd62e569e"},"outputs":[{"name":"stdout","output_type":"stream","text":["vocab_size: 65\n","vocabulary: \n"," !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n"]}],"source":["with open('input.txt', 'r', encoding='utf-8') as f:\n","    text = f.read()\n","\n","# here are all the unique characters that occur in this text\n","chars = sorted(list(set(text)))\n","vocab_size = len(chars)\n","# create a mapping from characters to integers\n","stoi = { ch:i for i,ch in enumerate(chars) }\n","itos = { i:ch for i,ch in enumerate(chars) }\n","encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n","decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n","\n","chars_str = ''.join(chars)\n","print(f'vocab_size: {vocab_size}')\n","print(f'vocabulary: {chars_str}')\n","\n","# Train and test splits\n","data = torch.tensor(encode(text), dtype=torch.long)\n","n = int(0.9*len(data)) # first 90% will be train, rest val\n","train_data = data[:n]\n","val_data = data[n:]\n","\n","\n","def get_batch(split):\n","    \"\"\"Sample a random batch of (input, target) sequences from one split.\n","\n","    split: 'train' selects train_data; any other value selects val_data.\n","    Returns (x, y), both (B, T) long tensors moved to `device`; y is x shifted\n","    one token ahead, i.e. the next-token target for every position of x.\n","    \"\"\"\n","    data = train_data if split == 'train' else val_data\n","    # B random start offsets -> B sequences per batch (batch size is B, not T)\n","    ix = torch.randint(len(data) - T, (B,))\n","    x = torch.stack([data[i:i+T] for i in ix])\n","    y = torch.stack([data[i+1:i+T+1] for i in ix])\n","    x, y = x.to(device), y.to(device)\n","    return x, y"]},
{"cell_type":"markdown","metadata":{"id":"T1KRDZ05xBNf"},"source":["### Bigram Model with Linear layer & token/pos embeddings"]},
{"cell_type":"code","execution_count":5,"metadata":{"executionInfo":{"elapsed":3,"status":"ok","timestamp":1717897129419,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"QWgCkVhCw_ED"},"outputs":[],"source":["class BigramLanguageModel(nn.Module):\n","    \"\"\"Bigram-style LM: logits = lm_head(token_embedding + position_embedding).\n","\n","    There is no attention / mixing across positions, so the logits at each\n","    position depend only on that position's token id and its index.\n","    \"\"\"\n","\n","    def __init__(self, B, T, C):\n","        super().__init__()\n","        self.B, self.T, self.C = B, T, C\n","        # each token id reads off a C-dim vector from a lookup table\n","        self.token_embedding_table = nn.Embedding(vocab_size, C)\n","        # one learned C-dim vector per position 0..T-1 (sized by block size T, not batch size B)\n","        self.position_embedding_table = nn.Embedding(T, C)\n","        self.lm_head = nn.Linear(C, vocab_size)\n","\n","    def forward(self, idx, targets=None):\n","        \"\"\"Return (logits, loss) for a batch of token ids.\n","\n","        B - batch    # of independent sequences processed\n","        T - time     # of tokens in a context (at most self.T)\n","        C - channels # of features\n","\n","        idx: (B, T) long tensor of token ids; targets: same shape, or None.\n","        If targets is None, logits is (B, T, vocab_size) and loss is None.\n","        Otherwise logits is flattened to (B*T, vocab_size) and loss is the\n","        mean cross-entropy against the flattened targets.\n","        \"\"\"\n","        n_pos = idx.shape[1]  # actual context length; may be < self.T early in generation\n","        tok_emb = self.token_embedding_table(idx)  # (B,T,C)\n","        pos_emb = self.position_embedding_table(torch.arange(n_pos, device=idx.device))  # (T,C): [0,1,2..T-1]\n","        x = tok_emb + pos_emb  # (B,T,C): pos_emb is broadcast over the batch\n","        logits = self.lm_head(x)  # (B,T,vocab_size)\n","\n","        if targets is None:\n","            loss = None\n","        else:\n","            logits = logits.view(-1, logits.size(-1))  # (B*T, vocab_size)\n","            targets = targets.view(-1)  # (B*T,)\n","            loss = F.cross_entropy(logits, targets)\n","\n","        return logits, loss\n","\n","    def generate(self, idx, max_new_tokens):\n","        \"\"\"Sample max_new_tokens ids one at a time and append them to idx.\n","\n","        idx: (B, T) long tensor of ids (the current context).\n","        Returns a (B, T + max_new_tokens) long tensor.\n","        \"\"\"\n","        for _ in range(max_new_tokens):\n","            # crop to the last block_size (self.T) tokens: the position table has only T rows\n","            idx_cond = idx[:, -self.T:]\n","            logits = self(idx_cond)[0]  # (B,T,vocab_size)\n","            logits = logits[:, -1, :]  # keep only the last position -> (B, vocab_size)\n","            probs = F.softmax(logits, dim=-1)  # (B, vocab_size)\n","            idx_next = torch.multinomial(probs, num_samples=1)  # sample from the distribution acc to prob (B, 1)\n","            idx = torch.cat((idx, idx_next), dim=1)  # New idx is concat (B, T+1)\n","        return idx\n","\n","model = BigramLanguageModel(B, T, C)\n","m = model.to(device)\n"]},
{"cell_type":"markdown","metadata":{"id":"Xygt1VN3xPn4"},"source":["### Training"]},
{"cell_type":"code","execution_count":6,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":11969,"status":"ok","timestamp":1717897141386,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"uniujgg0xIK8","outputId":"c715a0c3-7cb6-4a91-c4a7-61e39a2ecab6"},"outputs":[{"name":"stdout","output_type":"stream","text":["step 0: train loss 4.3337, val loss 4.4599\n","step 300: train loss 3.2863, val loss 3.2765\n","step 600: train loss 3.2917, val loss 3.0822\n","step 900: train loss 3.1424, val loss 3.0526\n","step 1200: train loss 3.0914, val loss 3.0241\n","step 1500: train loss 3.2469, val loss 2.9623\n","step 1800: train loss 2.8173, val loss 2.8259\n","step 2100: train loss 2.8679, val loss 2.9978\n","step
2400: train loss 2.9930, val loss 3.0948\n","step 2700: train loss 2.9934, val loss 2.9553\n"]}],"source":["@torch.no_grad()\n","def estimate_loss():\n","    \"\"\"Return {'train': ..., 'val': ...}: the mean loss over eval_iters random batches per split.\n","\n","    Switches the model to eval mode while measuring and back to train mode afterwards;\n","    the @torch.no_grad() decorator keeps autograd from recording these forward passes.\n","    \"\"\"\n","    out = {}\n","    model.eval()\n","    for split in ['train', 'val']:\n","        losses = torch.zeros(eval_iters)\n","        for k in range(eval_iters):\n","            X, Y = get_batch(split)\n","            logits, loss = model(X, Y)\n","            losses[k] = loss.item()\n","        out[split] = losses.mean()\n","    model.train()\n","    return out\n","\n","optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n","\n","# `step` (not `iter`) so the builtin iter() is not shadowed\n","for step in range(max_iters):\n","    # every once in a while, and on the last iteration, evaluate the loss on train and val sets\n","    if step % eval_interval == 0 or step == max_iters - 1:\n","        losses = estimate_loss()\n","        print(f\"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n","\n","    xb, yb = get_batch('train')  # sample a batch of data\n","\n","    # evaluate the loss and take one optimizer step\n","    logits, loss = model(xb, yb)\n","    optimizer.zero_grad(set_to_none=True)\n","    loss.backward()\n","    optimizer.step()\n"]},
{"cell_type":"markdown","metadata":{"id":"4A7Rj_VIxLvE"},"source":["### Inference"]},
{"cell_type":"code","execution_count":7,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":8,"status":"ok","timestamp":1717897141387,"user":{"displayName":"Abarajithan Gnaneswaran","userId":"15065416675145289008"},"user_tz":420},"id":"Zdo1zwSaxKRQ","outputId":"4e5fdcd8-8771-4db5-c18a-1f0319d7e17d"},"outputs":[{"name":"stdout","output_type":"stream","text":[" l ouerketisis ir alutor a magay per: d;\n","O ke he tugolouce, painllo b wer oououeeffor is\n","SBe t;\n","Su Ase llur iethe'ereerd'louce h, hous he eaver d'lofmoaousethe ghe b r w!owhaisoo us r impouooreldoug nes ce heaigubusuinisouer e he\n","O!oonere hedisor ?\n","Se.\n","Y,\n","Thed g aepthas dSisit y t bofon t ldour , hofos fofout ae out forveaayo hesoooouct touy mowigr DugheUWr we ouowther\n","HAs ce wenowisoAt wer louthe ige he , tout ce tt Mhe dldfowers\n","koubofouofain her cthiner towem se oferug \n","S:\n","SThe t in he his \n"]}],"source":["# token 0 is '\\n' (the first character of the sorted vocabulary), so this seeds the context with T newlines\n","context = torch.zeros((1, T), dtype=torch.long, device=device)\n","with torch.no_grad():  # sampling only; no autograd graph needed\n","    out_ints = m.generate(context, max_new_tokens=500)[0].tolist()  # output list of ints\n","print(decode(out_ints))"]}],"metadata":{"colab":{"authorship_tag":"ABX9TyMsbp8JVpaMMBpKLb83pYbG","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}