Transformers (Decoder-Only)
In this post, I will try to explain the decoder-only transformer architecture from scratch, so let's begin.
High level overview
- Given a prompt, it is first tokenized, and each token is embedded along with its position. The result flows into the first block.
- There are multiple blocks stacked one after another. Each block consists of multi-head attention and a feed-forward network, with layer norms and residual connections around them.
- After the blocks, a final layer norm and a linear layer produce logits, and a softmax turns them into the probability of each token coming next.
Input data
Let's come back to encoding the input later. For now, assume that after the token and position embeddings, the input has the following shape:
# inp.shape = (2, 6, 4)
"""
where,
2 is the batch size
6 is the context length
4 is the embedding dimension
"""
"""
Let this be the dataset
The sun dipped below the horizon, painting the sky with hues of orange and pink.
A gentle breeze rustled the leaves, creating a soothing melody.
In that peaceful moment, the world seemed to pause and breathe.
"""
# Let's assume the input is as follows
inp = [
[
[0.4553, 0.3277, 0.4210, 0.6628], # The
[0.0874, 0.3216, 0.2850, 0.1438], # sun
[0.8880, 0.2221, 0.6271, 0.8234], # dipped
[0.9180, 0.8070, 0.4281, 0.1977], # below
[0.4874, 0.9018, 0.4258, 0.1630], # the
[0.0441, 0.1988, 0.6751, 0.8757], # horizon
],
[
[0.9347, 0.9255, 0.6341, 0.0567], # ,
[0.7656, 0.2911, 0.6161, 0.0123], # painting
[0.4200, 0.3295, 0.1863, 0.7694], # the
[0.3197, 0.9724, 0.6066, 0.2184], # sky
[0.4521, 0.9276, 0.3951, 0.1281], # with
[0.7899, 0.7894, 0.2245, 0.2889] # hues
]
]
"""
We have a batch of 2 sequences, each with 6 rows (one per token, since 6 is
our context length) and 4 columns (the embedding dimension).
"""
This input is now fed into the first of a stack of blocks.
Block
The block consists of multiple attention heads, layer normalization layers, and a feed-forward layer.
- The input is first passed through a layer normalization layer. (The original paper applies layer normalization after each sublayer's residual addition, known as post-norm. Applying it before the sublayers, known as pre-norm and used here, has since become more common.)
- The normalized input is then fed into the individual attention heads.
- The outputs from the multi-head attention are added to the original input.
- The result of this addition is passed through another layer normalization layer.
- The normalized output is then fed into the feed-forward layer.
- Finally, the output of the feed-forward layer is added to the result of the first addition (the attention output plus the original input).
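The sketch below contrasts the pre-norm order described in the bullets with the post-norm order of the original paper. The two sublayers are hypothetical `nn.Linear` stand-ins for attention and the feed-forward network, so only the wiring matters here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_embd = 4
x = torch.randn(2, 6, n_embd)

# Hypothetical stand-ins for the two sublayers: any module mapping
# (B, T, n_embd) -> (B, T, n_embd) fits in their place
attn = nn.Linear(n_embd, n_embd)
ffwd = nn.Linear(n_embd, n_embd)
ln1, ln2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)

def pre_norm_block(x):
    # Order used in this post: normalize, run the sublayer, add the residual
    x = x + attn(ln1(x))
    x = x + ffwd(ln2(x))
    return x

def post_norm_block(x):
    # Original-paper order: run the sublayer, add the residual, then normalize
    x = ln1(x + attn(x))
    x = ln2(x + ffwd(x))
    return x

print(pre_norm_block(x).shape, post_norm_block(x).shape)  # both (2, 6, 4)
```

In the post-norm version, every block output is freshly normalized. In the pre-norm version, the residual stream itself is never normalized inside the block, which leaves a clean additive path from input to output.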
Multi-Head Attention
- The input is fed into multiple attention heads in parallel.
- The outputs from each attention head are concatenated.
- The concatenated output is then passed through a linear layer.
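Here is a minimal PyTorch sketch of these three bullets. The `Head` class is a compact version of the single-head computation walked through in the next section. The class names and the `block_size` argument (the context length, used to size the causal mask) are illustrative choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Head(nn.Module):
    """One causal self-attention head."""
    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        T = x.shape[1]
        q, k, v = self.query(x), self.key(x), self.value(x)
        w = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        w = w.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        return F.softmax(w, dim=-1) @ v

class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        head_size = n_embd // n_head
        # The heads all see the same input and are independent of each other
        self.heads = nn.ModuleList([Head(n_embd, head_size, block_size) for _ in range(n_head)])
        self.proj = nn.Linear(head_size * n_head, n_embd)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # concatenate head outputs
        return self.proj(out)                                # mix them with a linear layer

torch.manual_seed(0)
mha = MultiHeadAttention(n_embd=4, n_head=2, block_size=6)
x = torch.randn(2, 6, 4)
print(mha(x).shape)  # torch.Size([2, 6, 4])
```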
Single Attention head (Attention mechanism)
First, let's understand the goal of the attention mechanism, i.e. what it tries to achieve:
- Decide which parts of the input sequence (within the context length) to focus on when processing a token or generating the next one.
For our example, let's take the first input sequence, The sun dipped below the horizon (6 is our context length). Say we are predicting the last word: The sun dipped below the ____. To predict this word, the model needs to learn that the words sun or dipped are more important or relevant than the others, and pay more attention to them.
To accomplish this, three vectors are used: the query, the key and the value. The intuition is as follows:
Name | Intuition |
---|---|
Query (Q) | The current token in question: what it is looking for |
Key (K) | What each token offers; compared with the query to measure its relevance to the query token |
Value (V) | The actual content (representation) of each token |
Every token in the sequence has all three vectors associated with it. They are derived by projecting the token's embedding from the embedding space into query, key and value spaces using linear transformations (nn.Linear(..., bias=False)). The weights of these linear layers are learned during training. In other words, the model tries to learn good weight matrices that transform the input embeddings into useful query, key and value representations for the given dataset. Why is this needed? Consider the word apple. From the context we can tell whether it refers to the fruit or the company, yet in the embedding space apple has a single fixed representation.
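To make the apple point concrete, here is a hedged sketch with random, untrained weights, so only the qualitative effect matters. The "fruit" and "company" contexts are random placeholders. The same embedding is placed at the end of two different sequences, and its own query projection is identical in both cases. The context enters when that query is compared with the other tokens' keys and their values are mixed in, so the attention outputs differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_embd, head_size = 4, 2
query = nn.Linear(n_embd, head_size, bias=False)
key = nn.Linear(n_embd, head_size, bias=False)
value = nn.Linear(n_embd, head_size, bias=False)

def last_token_attention(seq):
    # seq: (T, n_embd). Attention output for the last token only; the last
    # token may attend to every token, so no causal mask is needed here.
    q = query(seq[-1])                                  # (head_size,)
    k, v = key(seq), value(seq)                         # (T, head_size)
    w = F.softmax(k @ q * head_size ** -0.5, dim=-1)    # (T,)
    return w @ v                                        # (head_size,)

apple = torch.randn(n_embd)                  # one fixed embedding for "apple"
context_a = torch.randn(3, n_embd)           # placeholder context, e.g. fruit
context_b = torch.randn(3, n_embd)           # placeholder context, e.g. company
seq_a = torch.cat([context_a, apple[None]])  # (4, n_embd)
seq_b = torch.cat([context_b, apple[None]])

out_a, out_b = last_token_attention(seq_a), last_token_attention(seq_b)
print(query(seq_a[-1]) - query(seq_b[-1]))   # zeros: same projection
print(out_a, out_b)                          # different, context-dependent outputs
```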
We create the query, key and value projections like so:
query = nn.Linear(n_embd, head_size, bias=False)
key = nn.Linear(n_embd, head_size, bias=False)
value = nn.Linear(n_embd, head_size, bias=False)
where n_embd is 4, and head_size is calculated as follows:
head_size = n_embd // n_head
n_head (the number of heads) is a hyperparameter we choose based on the embedding dimension, and it should divide n_embd evenly. For example, if the embedding dimension is 4, we can set n_head to 2, so that head_size is 2. Therefore,
query = nn.Linear(4, 2, bias=False) # query.weight.shape = (2, 4)
key = nn.Linear(4, 2, bias=False) # key.weight.shape = (2, 4)
value = nn.Linear(4, 2, bias=False) # value.weight.shape = (2, 4)
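A quick check of the numbers above, assuming the running example's n_embd = 4 and n_head = 2. The assert encodes the even-division requirement, and the printed shape confirms that nn.Linear stores its weight as (out_features, in_features).

```python
import torch.nn as nn

n_embd, n_head = 4, 2
assert n_embd % n_head == 0, "n_head must divide n_embd evenly"
head_size = n_embd // n_head  # 2

query = nn.Linear(n_embd, head_size, bias=False)
key = nn.Linear(n_embd, head_size, bias=False)
value = nn.Linear(n_embd, head_size, bias=False)

# nn.Linear stores its weight as (out_features, in_features)
print(head_size, query.weight.shape)  # 2 torch.Size([2, 4])
```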
Every head receives the same input, i.e. the (2, 6, 4) tensor. Let's feed it into one attention head and see what happens:
# inp.shape = (2, 6, 4)
# first, we get the respective query, key and value matrix by projecting
# them into the new space
q = query(inp) # inp @ query.weight.T = (2, 6, 4) @ (4, 2) = (2, 6, 2)
k = key(inp) # inp @ key.weight.T = (2, 6, 4) @ (4, 2) = (2, 6, 2)
v = value(inp) # inp @ value.weight.T = (2, 6, 4) @ (4, 2) = (2, 6, 2)
# second, we take the dot product (matmul) between the queries and keys
r = q @ k.transpose(-2, -1) # (2, 6, 2) @ (2, 2, 6) = (2, 6, 6)
# third, we scale it by 1 / sqrt(head_size)
# (hence the name scaled dot-product attention)
r = r * k.shape[-1]**-0.5 # (2, 6, 6)
# fourth, we apply causal mask
tril = torch.tril(torch.ones(6, 6))
# (2, 6, 6)
r = r.masked_fill(tril[:inp.shape[1], :inp.shape[1]] == 0, float('-inf'))
# fifth, we normalize it using softmax
r = F.softmax(r, dim=-1) # (2, 6, 6)
# sixth, perform weighted sum wrt the values
out = r @ v # (2, 6, 6) @ (2, 6, 2) = (2, 6, 2)
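As a sanity check on the six steps, the sketch below uses random weights and compares the manual computation with PyTorch's built-in `F.scaled_dot_product_attention`. This assumes PyTorch 2.0 or later. With `is_causal=True`, the built-in applies the same lower-triangular mask and, by default, the same 1/sqrt(head_size) scaling.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, T, n_embd, head_size = 2, 6, 4, 2
inp = torch.randn(B, T, n_embd)
query = nn.Linear(n_embd, head_size, bias=False)
key = nn.Linear(n_embd, head_size, bias=False)
value = nn.Linear(n_embd, head_size, bias=False)

# The six manual steps
q, k, v = query(inp), key(inp), value(inp)          # (2, 6, 2) each
r = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5   # (2, 6, 6), scaled
tril = torch.tril(torch.ones(T, T))
r = r.masked_fill(tril == 0, float("-inf"))         # causal mask
r = F.softmax(r, dim=-1)
out_manual = r @ v                                  # (2, 6, 2)

# PyTorch's fused equivalent (assumes PyTorch >= 2.0)
out_builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(out_manual, out_builtin, atol=1e-5))  # True
```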
Let's see what happens to a single token, say horizon:
# (2, 6, 4)
inp = [
[
[], # the
[], # sun
[], # dipped
[], # below
[], # the
[0.0441, 0.1988, 0.6751, 0.8757], # horizon
],
[
...
]
]
# q = query(inp)
# q = (2, 6, 2) -> each head
q = [
[
[], # the
[], # sun
[], # dipped
[], # below
[], # the
[0.9100, 0.3448], # horizon
],
[
...
]
]
# k = key(inp)
# k = (2, 6, 2) -> each head
k = [
[
[0.0921, 0.9907], # the
[0.5637, 0.7303], # sun
[0.1860, 0.4071], # dipped
[0.8067, 0.1776], # below
[0.7002, 0.6632], # the
[0.9094, 0.3594] # horizon
],
[
...
]
]
# v = value(inp)
# v = (2, 6, 2) -> each head
v = [
[
[0.5637, 0.4056], # the
[0.9803, 0.0100], # sun
[0.4111, 0.3980], # dipped
[0.6882, 0.9797], # below
[0.5551, 0.7583], # the
[0.3060, 0.2141], # horizon
],
[
...
]
]
The token horizon has now been projected into the query, key and value spaces. The query vector, as the name suggests, asks the other tokens in the sequence: which of you are relevant to me? The key vectors provide the answer. Remember that these are all vectors in an n-dimensional space. The dot product of two vectors is large when they point in similar directions, and it also grows with their lengths. So the dot product between the query vector and a key vector gives a scalar that tells us how much that token in the sequence (key) relates to the token in question (query).
Let's drop the batch dimension to keep things simple, so now we have:
# (6, 2)
q = [
[], # the
[], # sun
[], # dipped
[], # below
[], # the
[0.9100, 0.3448], # horizon
]
# (2, 6)
k.T = [
# the # sun # dipped # below # the # horizon
[0.0921, 0.5637, 0.1860, 0.8067, 0.7002, 0.9094],
[0.9907, 0.7303, 0.4071, 0.1776, 0.6632, 0.3594]
]
# now when we do q @ k.T, we are taking the dot product between
# the horizon token query vector and all other tokens key vectors
# r (6, 6) = q @ k.T (6, 2) @ (2, 6)
"""
the sun dipped below the horizon
the
sun
dipped
below
the
horizon 0.4254, 0.7648, 0.3096, 0.7953, 0.8659, 0.9515
"""
Each row of the (6, 6) matrix tells us how relevant every token in the sequence is to that row's query token; the last row is for horizon. In other words, it measures how closely each query vector and key vector point in the same direction.
r = r * k.shape[-1]**-0.5
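We now divide the matrix by the square root of head_size. This is done because the dot products grow in magnitude as head_size grows. For query and key components with unit variance, the variance of the dot product grows in proportion to head_size, so we scale the scores back down by √head_size. The scaling also helps in the next step. Without it, large scores would push the softmax toward an almost one-hot distribution with very small gradients. (To keep the arithmetic easy to follow, the worked numbers in the rest of this example skip this scaling step.)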
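A small plain-Python experiment illustrates both points, assuming query and key components drawn from a standard normal distribution. First, the variance of q · k grows roughly in proportion to the dimension. Second, dividing by √head_size flattens the softmax. The last lines compare the softmax of the unscaled horizon row with and without the scaling.

```python
import math
import random

random.seed(0)

def dot_product_variance(d, n=4000):
    """Sample variance of q . k for random q, k with N(0, 1) components."""
    samples = []
    for _ in range(n):
        q = [random.gauss(0, 1) for _ in range(d)]
        k = [random.gauss(0, 1) for _ in range(d)]
        samples.append(sum(a * b for a, b in zip(q, k)))
    mean = sum(samples) / n
    return sum((s - mean) ** 2 for s in samples) / n

variances = {d: dot_product_variance(d) for d in (2, 64)}
for d, var in variances.items():
    # var grows roughly like d; var / d is the variance after dividing by sqrt(d)
    print(d, round(var, 2), round(var / d, 2))

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# The unscaled horizon row, with and without the 1/sqrt(head_size) scaling
row = [0.4254, 0.7648, 0.3096, 0.7953, 0.8659, 0.9515]
head_size = 2
unscaled = softmax(row)
scaled = softmax([x / math.sqrt(head_size) for x in row])
print(round(max(unscaled), 4), round(max(scaled), 4))  # the scaled one is flatter
```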
Now we apply masking to the attention matrix: each token may only attend to itself and the tokens that come before it, not after it. This makes sense, since we are trying to predict the next token in the sequence, and a token must not see the future it is being trained to predict. This is done by setting the entries above the diagonal (the upper triangle) to -inf, which become 0 after the softmax in the next step.
"""
the sun dipped below the horizon
the -inf -inf -inf -inf -inf
sun -inf -inf -inf -inf
dipped -inf -inf -inf
below -inf -inf
the -inf
horizon 0.4254 0.7648 0.3096 0.7953 0.8659 0.9515
"""
# (2, 6, 6)
tril = torch.tril(torch.ones(6, 6))
r = r.masked_fill(tril[:inp.shape[1], :inp.shape[1]] == 0, float('-inf'))
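The effect of the mask is easy to see on a placeholder score matrix of random numbers. After `masked_fill` and softmax, everything above the diagonal is exactly 0 and every row still sums to 1.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T = 6
scores = torch.randn(T, T)              # placeholder raw attention scores
tril = torch.tril(torch.ones(T, T))     # 1 on and below the diagonal, 0 above
masked = scores.masked_fill(tril == 0, float("-inf"))
weights = F.softmax(masked, dim=-1)     # exp(-inf) = 0, so masked entries vanish

print(tril)
print(weights)  # zeros above the diagonal; the first row is [1, 0, 0, 0, 0, 0]
```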
Next, we apply softmax along the last dimension to convert the raw attention scores into a probability distribution, so that each row sums to 1. This is useful in the next step, when we weight each value vector.
r = F.softmax(r, dim=-1)
# higher attention scores get larger weights,
# and each row sums to 1
"""
the sun dipped below the horizon
the 0 0 0 0 0
sun 0 0 0 0
dipped 0 0 0
below 0 0
the 0
horizon 0.1252 0.1758 0.1115 0.1812 0.1945 0.2119
"""
This matrix tells us how much weight to give each token with respect to the query token horizon. In other words, it tells us how much attention horizon should pay to every token, itself included. To apply these weights, we simply matmul r with the value matrix. Remember that the value matrix holds the actual content of the sequence:
# (6, 2) = (6, 6) @ (6, 2)
out = r @ v
"""
the sun dipped below the horizon
the 0 0 0 0 0
sun 0 0 0 0
dipped 0 0 0
below 0 0
the 0
horizon 0.1252 0.1758 0.1115 0.1812 0.1945 0.2119
@
[
[0.5637, 0.4056], # the
[0.9803, 0.0100], # sun
[0.4111, 0.3980], # dipped
[0.6882, 0.9797], # below
[0.5551, 0.7583], # the
[0.3060, 0.2141], # horizon
]
out
[
[] # the
[] # sun
[] # dipped
[] # below
[] # the
[0.5863, 0.4673] # horizon
]
"""
The output for horizon therefore blends information from every token it is allowed to see (the tokens before it and itself), weighted by their relevance.
Remember that all of this is done for a single head. Multiple heads run in parallel, and once every head has finished, we concatenate their outputs along the last dimension:
out = torch.cat([h(x) for h in heads], dim=-1)
"""
In our example, we have 2 heads, and as we have seen in the previous step,
the output of a single head is of shape (2, 6, 2), therefore when we concatenate
2 heads along the last dim, we get the final output shape as (2, 6, 4).
"""
The concatenated output is then passed through a linear layer:
# if we have divided the embedding dimension equally among the heads,
# head_size * num_heads == n_embd, so the layer is just (n_embd, n_embd),
# here nn.Linear(2 * 2, 4)
proj = nn.Linear(head_size * num_heads, n_embd)
# (2, 6, 4) = (2, 6, 4) @ (4, 4)
out = proj(out)
Each head can learn to capture a different aspect of the data. For example, one head might focus on grammatical context and another on temporal context. Concatenation merely stacks the head outputs side by side, so a linear layer is added to mix the information across heads.
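The mixing can be seen in a sketch with placeholder head outputs (random tensors standing in for real heads). If only the second head's output changes, only its two columns change in the concatenation. After the projection, every output feature changes, because each output feature is a weighted combination of all the heads' outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_embd, n_head, head_size = 4, 2, 2
head_outputs = [torch.randn(2, 6, head_size) for _ in range(n_head)]  # placeholders

out = torch.cat(head_outputs, dim=-1)    # (2, 6, 4): heads side by side
proj = nn.Linear(head_size * n_head, n_embd)
mixed = proj(out)                        # (2, 6, 4)

# Perturb only the second head
out_b = torch.cat([head_outputs[0], head_outputs[1] + 1.0], dim=-1)
print((out_b - out)[0, 0])           # ~[0, 0, 1, 1]: only head 2's columns move
print((proj(out_b) - mixed)[0, 0])   # all four features move after projection
```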
Next, the output is added to the block's input. This is called a residual (or skip) connection. Residual connections help optimization during backpropagation. In very deep networks such as transformers, gradients can become very small as they flow backward, which hinders learning, and adding the input back to the output of a sublayer addresses this. In an autograd graph, an addition node passes the incoming gradient unchanged to both of its inputs, so the gradient has a direct path all the way back to the input.
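A tiny autograd sketch shows the effect. It uses a hypothetical sublayer `f` whose local gradient is deliberately small (1e-4). Without the skip connection, only that tiny gradient reaches `x`. With it, the addition also passes the upstream gradient of 1 straight through.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

def f(x):
    # Hypothetical sublayer whose local gradient is tiny (1e-4)
    return 1e-4 * x

# Without a skip connection: the gradient reaching x is only f'(x)
f(x).sum().backward()
grad_plain = x.grad.clone()
x.grad = None

# With a skip connection: d(x + f(x))/dx = 1 + f'(x)
(x + f(x)).sum().backward()
grad_residual = x.grad.clone()
print(grad_plain, grad_residual)  # ~1e-4 vs ~1.0001 per element
```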
# therefore we finally have
# (2, 6, 4) = (2, 6, 4) + (2, 6, 4)
x = x + self_attention(layer_norm1(x))
Next, the output goes through a second layer norm and a feed-forward network. The feed-forward network is a simple two-layer network whose hidden layer is 4 times the embedding dimension. This factor of 4 comes from the original paper (d_model = 512, d_ff = 2048).
class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),  # expand to the 4x hidden layer
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),  # project back to n_embd
        )
    def forward(self, x):
        return self.net(x)
ffwd = FeedForward(4)
# (2, 6, 4) + (((2, 6, 4) @ (4, 16)) @ (16, 4))
# (2, 6, 4) + ((2, 6, 16) @ (16, 4))
# (2, 6, 4) + ((2, 6, 4))
# (2, 6, 4)
x = x + ffwd(layer_norm2(x))
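Two properties of this layer are easy to check in a sketch, assuming n_embd = 4 as in the running example. First, the layer has 148 parameters. Second, the same network is applied to every token position independently, so unlike attention it does not mix information across tokens.

```python
import torch
import torch.nn as nn

n_embd = 4
ffwd = nn.Sequential(                 # same layers as the feed-forward network above
    nn.Linear(n_embd, 4 * n_embd),    # 4*16 weights + 16 biases = 80
    nn.ReLU(),
    nn.Linear(4 * n_embd, n_embd),    # 16*4 weights + 4 biases = 68
)
n_params = sum(p.numel() for p in ffwd.parameters())
print(n_params)  # 148

# Applying it to the whole (2, 6, 4) tensor equals applying it token by token
torch.manual_seed(0)
x = torch.randn(2, 6, n_embd)
per_token = torch.stack([ffwd(x[:, t]) for t in range(6)], dim=1)  # (2, 6, 4)
print(torch.allclose(ffwd(x), per_token, atol=1e-6))  # True
```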
That completes one block. There are n_blocks blocks that run sequentially, one after the other; the output of each block is the input to the next.
blocks = nn.Sequential(*[Block() for _ in range(n_blocks)])
# (2, 6, 4)
x = blocks(x)
Although the blocks are processed sequentially, the attention heads within a block are processed in parallel. After the last block, the output passes through a final layer norm and then a linear layer (lm_head) that projects it from the embedding space to the vocabulary space.
ln_f = nn.LayerNorm(n_embd)
lm_head = nn.Linear(n_embd, vocab_size)
x = blocks(x) # (2, 6, 4)
x = ln_f(x) # (2, 6, 4)
# (2, 6, 4) @ (4, vocab_size)
logits = lm_head(x) # (2, 6, vocab_size)
Here vocab_size is the number of unique characters that occur in the dataset (i.e. a character-level vocabulary). In our example:
text = """
The sun dipped below the horizon, painting the sky with hues of orange and pink.
A gentle breeze rustled the leaves, creating a soothing melody.
In that peaceful moment, the world seemed to pause and breathe.
"""
chars = sorted(list(set(text)))
vocab_size = len(chars) # 30
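The post does not show how characters become token ids. A common minimal approach, assumed here, maps each character to its index in `chars`. The names `stoi`, `itos`, `encode` and `decode` are illustrative.

```python
text = """
The sun dipped below the horizon, painting the sky with hues of orange and pink.
A gentle breeze rustled the leaves, creating a soothing melody.
In that peaceful moment, the world seemed to pause and breathe.
"""
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}  # character -> token id
itos = {i: ch for ch, i in stoi.items()}      # token id -> character

def encode(s):
    return [stoi[c] for c in s]

def decode(ids):
    return "".join(itos[i] for i in ids)

print(vocab_size)                 # 30
print(encode("The sun"))          # [6, 14, 11, 1, 23, 25, 19]
print(decode(encode("The sun")))  # The sun
```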
The final output is therefore the logits; applying a softmax to them gives the probability of each token in the vocabulary occurring next.
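Putting the pieces together, here is a compact end-to-end sketch of the forward pass. The weights are untrained and random, so the probabilities are meaningless. The hyperparameters match the running example (n_embd = 4, n_head = 2, context length 6, vocab_size = 30), plus an assumed n_blocks = 2. The names `Head`, `Block` and `TinyDecoder` are hypothetical. Sampling with `torch.multinomial` is one common way to pick the next token; taking the argmax is another.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Head(nn.Module):
    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        T = x.shape[1]
        q, k, v = self.query(x), self.key(x), self.value(x)
        r = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        r = r.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        return F.softmax(r, dim=-1) @ v

class Block(nn.Module):
    """Pre-norm block: x + attn(ln1(x)), then x + ffwd(ln2(x))."""
    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        head_size = n_embd // n_head
        self.heads = nn.ModuleList([Head(n_embd, head_size, block_size) for _ in range(n_head)])
        self.proj = nn.Linear(head_size * n_head, n_embd)
        self.ffwd = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd), nn.ReLU(), nn.Linear(4 * n_embd, n_embd)
        )
        self.ln1, self.ln2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)

    def forward(self, x):
        h = self.ln1(x)
        x = x + self.proj(torch.cat([head(h) for head in self.heads], dim=-1))
        x = x + self.ffwd(self.ln2(x))
        return x

class TinyDecoder(nn.Module):
    def __init__(self, vocab_size, n_embd=4, n_head=2, block_size=6, n_blocks=2):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size) for _ in range(n_blocks)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx):
        T = idx.shape[1]
        x = self.token_embedding(idx) + self.position_embedding(torch.arange(T))
        x = self.blocks(x)                    # (B, T, n_embd)
        return self.lm_head(self.ln_f(x))     # logits: (B, T, vocab_size)

torch.manual_seed(0)
model = TinyDecoder(vocab_size=30)
idx = torch.randint(0, 30, (2, 6))                 # placeholder token ids
logits = model(idx)                                # (2, 6, 30)
probs = F.softmax(logits[:, -1, :], dim=-1)        # next-token distribution, (2, 30)
next_id = torch.multinomial(probs, num_samples=1)  # one sampled id per sequence, (2, 1)
print(logits.shape, probs.sum(dim=-1), next_id.shape)
```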