Self Attention



Attention is a communication mechanism

  • A directed graph of T nodes: each node is a token position and holds its information as a vector of size H

  • Each node aggregates information as a weighted sum over the nodes that point to it

  • Data-dependent: the weights of that sum are computed from the node contents, so they change with the input

  • For autoregression (causal/decoder attention),

    • the 1st node receives only from itself

    • the 2nd node receives from nodes 1 and 2

    • the T-th node receives from all T nodes

  • Nodes have no notion of space/ordering, so we add a positional embedding to each token
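The aggregation above can be sketched with a lower-triangular weight matrix: row t of `wei` is nonzero only in columns 0..t, so node t reads from itself and its predecessors. This is a minimal sketch that uses uniform weights (a plain running mean) as a stand-in for the learned, data-dependent weights; the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
T, H = 4, 2                      # 4 token positions, each holding a 2-dim vector
x = torch.randn(T, H)            # node contents

# Lower-triangular mask: node t may read from nodes 0..t only
tril = torch.tril(torch.ones(T, T))
wei = tril / tril.sum(dim=1, keepdim=True)   # uniform weights, each row sums to 1

out = wei @ x                    # (T,T) @ (T,H) -> (T,H): one weighted sum per node
print(wei)                       # visible triangle: row t spreads weight over t+1 nodes
```

In real attention `wei` is produced by `softmax(q k^T)` instead of being fixed, which is what makes the sum data-dependent.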

Encoder vs Decoder

  • Encoder

    • no triangular mask (it gathers data from both directions)

    • e.g. translation, sentiment analysis

  • Decoder

    • triangular mask (gathers data from past & present only)

    • predicts the next token

    • “autoregressive”: the sequence probability factorizes by the chain rule, P(w_1, …, w_T) = P(w_1) · P(w_2 | w_1) · … · P(w_T | w_1, …, w_{T-1})
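The only mechanical difference between the two is the mask applied before the softmax. A minimal sketch, assuming random numbers as stand-ins for the raw q·k affinities:

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
T = 4
scores = torch.randn(T, T)       # stand-in for raw q.k affinities

# Encoder: every position attends to every other position (no mask)
enc_wei = F.softmax(scores, dim=-1)

# Decoder: fill the future (upper triangle) with -inf so softmax gives it weight 0
tril = torch.tril(torch.ones(T, T))
dec_wei = F.softmax(scores.masked_fill(tril == 0, float('-inf')), dim=-1)

print(enc_wei)
print(dec_wei)   # upper triangle is 0; first row is [1, 0, 0, 0]
```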

Attention = softmax(Q K^T / sqrt(H)) V

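  • Self-attention: Q, K and V are all computed from the same input x

  • Cross-attention:

    • queries come from x; keys and values come from a different source (e.g. an encoder's output)

    • e.g. English -> French translation: the French tokens (queries) search the English tokens (keys, values)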
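Cross-attention can be sketched by feeding two different tensors into the projections: queries from the decoder-side `x`, keys and values from a source sequence `ctx`. The names `ctx`, `T_dec`, `T_enc` and all sizes are illustrative assumptions; the point is that the weight matrix becomes (T_dec, T_enc) and the two lengths need not match.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)
B, T_dec, T_enc, C, H = 2, 5, 7, 16, 8

x   = torch.randn(B, T_dec, C)   # decoder side (e.g. French tokens): provides queries
ctx = torch.randn(B, T_enc, C)   # source side (e.g. English tokens): provides keys and values

query = nn.Linear(C, H, bias=False)
key   = nn.Linear(C, H, bias=False)
value = nn.Linear(C, H, bias=False)

q = query(x)                     # (B, T_dec, H)
k = key(ctx)                     # (B, T_enc, H)
v = value(ctx)                   # (B, T_enc, H)

wei = q @ k.transpose(-2, -1) * H**-0.5   # (B, T_dec, T_enc) scaled affinities
wei = F.softmax(wei, dim=-1)              # no causal mask: each query may read the whole source
out = wei @ v                             # (B, T_dec, H)
print(out.shape)
```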

x                  - private information of each token (B,T,C)
q                  - each token emits a query vector, what I am looking for [I am a vowel in position 8 looking for consonants up to position 4]
k                  - each token emits a key vector, what information I have [I am a consonant in position 3]
w = q k^T/sqrt(H)  - affinity: when a query matches a key, their dot product (the entry at that row/column intersection) is high [I am interested in these positions]
v                  - value vector, what information I am willing to provide
y = softmax(w) v   - accumulate the information from the positions I am interested in (B,T,H)
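Why divide by sqrt(H): if the entries of q and k are independent with unit variance, each q·k is a sum of H unit-variance products, so its variance is about H. A minimal sketch, assuming Gaussian random vectors as stand-ins for trained queries and keys:

```python
import torch
from torch.nn import functional as F

torch.manual_seed(1337)
H = 64                              # head size
N = 10_000                          # number of (q, k) pairs sampled

q = torch.randn(N, H)               # unit-variance entries
k = torch.randn(N, H)

raw = (q * k).sum(dim=-1)           # N dot products
scaled = raw / H**0.5

print(raw.var().item())             # roughly H
print(scaled.var().item())          # roughly 1

# Softmax over 8 positions for one query: unscaled scores give a peakier distribution
scores = torch.randn(8, H) @ torch.randn(H)   # (8,) raw affinities
p_raw = F.softmax(scores, dim=-1)
p_scaled = F.softmax(scores / H**0.5, dim=-1)
print(p_raw.max().item(), p_scaled.max().item())
```

Without the scaling, softmax tends toward a one-hot vector, so each token effectively copies a single other token and gradients through the other positions vanish.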

import torch
import torch.nn as nn
from torch.nn import functional as F

Hyperparameters

B = 32 # B: how many independent sequences will we process in parallel?
T = 8  # T: what is the maximum context length for predictions?
C = 32 # C: number of embedding channels (features) per token, also written D
max_iters = 5000
eval_interval = 500
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
torch.manual_seed(1337)

Data

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
--2024-06-09 01:52:09--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’

input.txt.1         100%[===================>]   1.06M  5.18MB/s    in 0.2s    

2024-06-09 01:52:10 (5.18 MB/s) - ‘input.txt.1’ saved [1115394/1115394]
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

chars_str = ''.join(chars)
print(f'vocab_size: {vocab_size}')
print(f'vocabulary: {chars_str}')

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]
vocab_size: 65
vocabulary: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
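The character-level tokenizer is just a pair of lookup tables built from the sorted unique characters. A round trip on a toy string (an assumption standing in for `input.txt`, so the sketch runs offline) shows that `decode` inverts `encode`:

```python
text = "hello world"                      # toy stand-in for input.txt
chars = sorted(set(text))                 # [' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w']
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

def encode(s):
    return [stoi[c] for c in s]           # string -> list of ints

def decode(ids):
    return ''.join(itos[i] for i in ids)  # list of ints -> string

ids = encode("hello")
print(ids)            # [3, 2, 4, 4, 5]
print(decode(ids))    # hello
```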
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - T, (B,))
    x = torch.stack([data[i:i+T] for i in ix])
    y = torch.stack([data[i+1:i+T+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y
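In `get_batch`, each target row is the input row shifted by one position, so for every prefix `x[b, :t+1]` the label is `y[b, t]`; one (B,T) batch therefore holds B*T next-token examples. A minimal sketch on a toy token stream (`torch.arange`, an assumption that makes the shift easy to see):

```python
import torch

torch.manual_seed(0)
data = torch.arange(100)          # toy token stream: 0, 1, 2, ..., 99
B, T = 4, 8

ix = torch.randint(len(data) - T, (B,))               # random start offsets
x = torch.stack([data[i:i + T] for i in ix])          # (B, T) inputs
y = torch.stack([data[i + 1:i + T + 1] for i in ix])  # (B, T) targets, shifted by one

print(x[0])
print(y[0])
```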

Self Attention

class Head(nn.Module):
    """ One head of self attention"""

    def __init__(self, Ci, Co):
        super().__init__()
        self.Ci, self.Co = Ci, Co
        self.key   = nn.Linear(Ci, Co, bias=False)
        self.query = nn.Linear(Ci, Co, bias=False)
        self.value = nn.Linear(Ci, Co, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(T, T)))

    def forward(self, x):
        B, T, Ci  = x.shape
        '''
        B  - batch               # of independent sequences processed in parallel
        T  - time/block/context  # of tokens in a context (at most the block size used for tril)
        Ci - input channels/dims # of features per input token
        '''

        k = self.key(x)   # (B,T,Co)
        q = self.query(x) # (B,T,Co)

        # compute attention scores / affinities
        wei = q @ k.transpose(-2,-1)                                 # (B,T,Co) @ (B,Co,T) -> (B,T,T)
        wei  = wei * self.Co**-0.5                                   # (B,T,T) scale by 1/sqrt(head size): scores get ~unit variance, so softmax does not saturate
        wei  = wei.masked_fill(self.tril[:T,:T]==0, float('-inf'))   # (B,T,T) Replace upper triangular of wei with -inf
        wei  = F.softmax(wei, dim=-1)                                # (B,T,T) -inf -> 0, rest normalized to 1

        v = self.value(x)  # (B,T,Co)
        out = wei @ v      # (B,T,T) @ (B,T,Co) = (B,T,Co)

        return out
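One way to check the mask in `Head` is to perturb only the last token and confirm that the outputs at earlier positions do not change. The sketch below re-implements a compact causal head so it runs on its own; `CausalHead` and its `block_size` argument are hypothetical names for this test, mirroring `Head` with scaling by the head size.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class CausalHead(nn.Module):
    """Hypothetical standalone copy of Head, used only for this check."""
    def __init__(self, Ci, Co, block_size):
        super().__init__()
        self.key   = nn.Linear(Ci, Co, bias=False)
        self.query = nn.Linear(Ci, Co, bias=False)
        self.value = nn.Linear(Ci, Co, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, Ci = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)           # (B,T,Co) each
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5           # (B,T,T) scaled scores
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # hide the future
        wei = F.softmax(wei, dim=-1)
        return wei @ v                                                 # (B,T,Co)

torch.manual_seed(0)
head = CausalHead(Ci=32, Co=16, block_size=8)
x = torch.randn(2, 8, 32)
x2 = x.clone()
x2[:, -1] += 1.0                     # change only the last token

out, out2 = head(x), head(x2)
print(torch.allclose(out[:, :-1], out2[:, :-1]))  # True: earlier positions ignore the change
```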

Model: Self attention + Linear + pos/token embeddings

class BigramLanguageModel(nn.Module):

    def __init__(self, B,T,C):
        super().__init__()
        self.B, self.T, self.C = B,T,C
        # each token id looks up a learned C-dim embedding vector
        self.token_embedding_table = nn.Embedding(vocab_size, C)  # (vocab_size, C)
        self.position_embedding_table = nn.Embedding(T, C)        # one learned C-dim vector per position 0..T-1

        self.sa_head = Head(Ci=C, Co=C)
        self.lm_head = nn.Linear(C, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape                                                             # T can be <= self.T
        tok_emb = self.token_embedding_table(idx)                                    # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T,C): [0,1,2..T-1]

        x = tok_emb + pos_emb     # (B,T,Ci)
        x = self.sa_head(x)       # (B,T,Co) Apply self attention
        logits = self.lm_head(x)  # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):                        # idx is (B, T) array of indices in the current context
            idx_cond = idx[:, -self.T:]                        # crop to the last T tokens (the block size)
            logits, loss = self(idx_cond)                      # get the predictions
            logits = logits[:, -1, :]                          # keep the last time step: (B,T,vocab_size) -> (B,vocab_size)
            probs = F.softmax(logits, dim=-1)                  # (B,vocab_size) next-token probabilities
            idx_next = torch.multinomial(probs, num_samples=1) # sample from the distribution acc to prob (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)            # New idx is concat (B, T+1)
        return idx

model = BigramLanguageModel(B,T,C)
m = model.to(device)

Training

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    if iter % eval_interval == 0:   # every once in a while evaluate the loss on train and val sets
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')     # sample a batch of data

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
step 0: train loss 4.2000, val loss 4.2047
step 500: train loss 2.6911, val loss 2.7087
step 1000: train loss 2.5196, val loss 2.5303
step 1500: train loss 2.4775, val loss 2.4829
step 2000: train loss 2.4408, val loss 2.4523
step 2500: train loss 2.4272, val loss 2.4435
step 3000: train loss 2.4130, val loss 2.4327
step 3500: train loss 2.3956, val loss 2.4212
step 4000: train loss 2.4041, val loss 2.3992
step 4500: train loss 2.3980, val loss 2.4084
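As a sanity check on initialization: a model that spreads probability uniformly over the 65 characters has cross-entropy ln(65), and the step-0 loss of about 4.20 is close to that value.

```python
import math

vocab_size = 65
uniform_loss = -math.log(1 / vocab_size)   # cross-entropy of a uniform guess = ln(65)
print(round(uniform_loss, 4))               # 4.1744
```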

Output

context = torch.ones((1, T), dtype=torch.long, device=device)  # seed with T copies of token 1, which is ' ' (space); token 0 is '\n'
out_ints = m.generate(context, max_new_tokens=500)[0].tolist() # output list of ints
print(decode(out_ints))
        pes le isen.
Woto teven INGO, ous into CYedd shou maithe ert thethens the the del ede cksy ow? Wlouby aicecat tisall wor
G'imemonou mar ee hacreancad hontrt had wousk ucavere.

Baraghe lfousto beme,
S m; ten gh;
S:
Ano ice de bay alysathef beatireplim serbeais I fard
Sy,
Me hallil:
DWAR: us,
Wte hse aecathate, parrise in hr'd pat
ERY:
Bf bul walde betl'ts I yshore grest atre ciak aloo; wo fart hets atl.

That at Wh kear ben.
 hend.

KTh'd foushe d'l otacaengs p bloul blod arme foot buthes fo boe
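The sample begins with spaces because of the seed: in the sorted vocabulary printed earlier, index 0 is the newline and index 1 is the space, so `torch.ones` seeds with T spaces (`torch.zeros` would seed with newlines). A quick check, rebuilding the vocabulary from that printed output:

```python
import string

# Vocabulary as printed above: newline, space, punctuation, digit 3, A-Z, a-z
vocab = "\n !$&',-.3:;?" + string.ascii_uppercase + string.ascii_lowercase
chars = sorted(vocab)
print(len(chars))                       # 65
print(repr(chars[0]), repr(chars[1]))   # '\n' ' '
```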