GPU Programming 1: Basic Multi-head Self Attention
Overview
If anything seems confusing, skip to the Pre-Reqs section at the bottom; there might be background material you need to learn first.
Transformers are great because they 1) significantly speed up learning times for a class of problems compared to recurrent/convolutional layers, and 2) allow training models that perform better than before.
Models utilizing transformers can be trained faster because a large matrix multiplication replaces the temporal/convolutional loop from recurrent/convolutional layers. Instead of generating tokens one at a time to reach the ultimate hidden state, as in the recurrent case (or convolving over the input, as in convolutional layers), you pass in the entire sequence at once and compute a huge LEN_SEQUENCE x LEN_SEQUENCE matrix of pairwise token scores. Then, during training, you use matrix masking to zero out the weights above the diagonal, ensuring there's no "cheating": each prediction only looks at the relevant prefix of the sequence. All of this happens in one matrix multiplication, so you keep the GPU better utilized.
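As a minimal sketch of that masking idea (using a toy 4-token sequence, not real model scores): positions above the diagonal are set to -inf before the softmax, so their attention weight comes out exactly zero.

```python
import torch

seq_len = 4
# Toy attention scores for one head: [seq_len, seq_len]
scores = torch.randn(seq_len, seq_len)

# Causal mask: True strictly above the diagonal (i.e., future positions)
mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# -inf scores become exactly 0 after softmax
masked = scores.masked_fill(mask, float("-inf"))
weights = torch.softmax(masked, dim=-1)

# Row i now only attends to positions 0..i, and each row still sums to 1
```

Because the mask is applied to the scores rather than the inputs, the whole sequence can still be processed in one shot.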
Models using transformers also have better performance because, among other things, they can encode far-away relationships between tokens better and more directly than recurrent layers (which squeeze everything through a hidden state) and convolutional layers (which need a composition of convolutions).
Additional tricks are also employed on top to ensure necessary information is present and the model is stable. For example, positional encoding (using some trig) is used to make sure tokens encode positional data. Residual layers are used to prevent vanishing gradients, and several constant divisions and normalization layers are used to prevent exploding gradients. Lots of "heads" (aka projection matrices, aka the many W_k, W_q, W_v matrices; lots of weights, basically) are used to learn ideas in different spaces, and an FFN is used to aggregate everything in the end.
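As a sketch of the trig-based positional encoding from the paper (sin on even embedding dimensions, cos on odd ones; assumes an even embedding size):

```python
import torch

def sinusoidal_positional_encoding(seq_len, embedding_dims):
    # pos: [seq_len, 1]; i walks over the even dimensions
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
    i = torch.arange(0, embedding_dims, 2, dtype=torch.float32)
    angles = pos / (10000 ** (i / embedding_dims))
    pe = torch.zeros(seq_len, embedding_dims)
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles)
    return pe

pe = sinusoidal_positional_encoding(seq_len=16, embedding_dims=8)
# pe is *added* (not concatenated) to the token embeddings
# before the first attention layer
```

Each position gets a unique pattern of wavelengths, so nearby positions produce similar encodings and the model can learn relative offsets.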
In the end, a transformer is a very interesting and seemingly simple primitive that can be repeated to build a large deep network. It is not surprising that it performs so well compared to recurrent/convolutional layers, though the question of whether there's something even better that'll soon be discovered is fascinating.
Reality
This is the transformer diagram presented by the authors of the seminal paper for a language translation task:
Source: Attention Is All You Need
It follows the classic encoder-decoder pattern. However, this is actually not used for most of today's use cases. LLMs, for example, only use a strict subset of the Decoder portion:
- You have a sequence of tokens
- Place them in some embedding space
- Apply positional encoding
- Apply masked multi-head attention (this blog! though we skip the masking)
- Residual + Normalization layer (stability of gradients and weights)
- Note: This is called Post-Normalization. Most models these days make use of Pre-Normalization, with the Normalization layer applied to the input rather than the output, as it appears to work better in practice. The Residual layer is still applied on the output, as it doesn't make sense to apply it on the input.
- Feed forward network
- Residual + Normalization layer
You just repeat this over and over and you have yourself a powerful model that can be used to train an LLM.
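The steps above can be sketched as one pre-normalization block (a sketch only: it uses PyTorch's built-in nn.MultiheadAttention as a stand-in for the attention we implement below, omits the causal mask, and the names and sizes are made up):

```python
import torch
from torch import nn

class DecoderBlock(nn.Module):
    def __init__(self, embedding_dims, heads, ffn_dims):
        super().__init__()
        self.norm1 = nn.LayerNorm(embedding_dims)
        self.attn = nn.MultiheadAttention(embedding_dims, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(embedding_dims)
        self.ffn = nn.Sequential(
            nn.Linear(embedding_dims, ffn_dims),
            nn.GELU(),
            nn.Linear(ffn_dims, embedding_dims),
        )

    def forward(self, x):
        # Pre-Normalization: normalize the *input*, add the residual on the output
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x

block = DecoderBlock(embedding_dims=32, heads=4, ffn_dims=128)
out = block(torch.randn(2, 10, 32))  # [batch, seq_len, embedding_dims] in and out
```

Because the block maps [batch, seq_len, embedding_dims] to the same shape, you can stack as many copies as your budget allows.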
You can train a GPT2 for <$100 with this!
Implementation of Multi-Head Self Attention
Multi-head attention is the most important component of a transformer. Let's implement it here!
First, let's create a basic Python class that defines the Attention we want.
import torch
from torch import nn
from torch.nn.parameter import Parameter

class SelfAttention(nn.Module):
    # batch: Number of sentences/sequences as input in one go
    # seq_len: Max length of sequence
    # embedding_dims: Dimension of embedding for each token/element in sequence
    def __init__(self, heads, batch, seq_len, embedding_dims):
        super().__init__()  # needed so the Parameters register with PyTorch
        self.batch = batch
        self.seq_len = seq_len
        self.embedding_dims = embedding_dims
        self.heads = heads
        # Each head projects into a smaller space, so the concatenation
        # of all heads is embedding_dims wide again
        self.projection_dims = embedding_dims // heads
        self.W_K = Parameter(data=torch.randn(heads, self.projection_dims, embedding_dims))
        self.W_Q = Parameter(data=torch.randn(heads, self.projection_dims, embedding_dims))
        self.W_V = Parameter(data=torch.randn(heads, self.projection_dims, embedding_dims))
        self.W_O = Parameter(data=torch.randn(embedding_dims, embedding_dims))
In real life, you're going to want to process a batch of data at once, so we define a batch.
seq_len is the length of the sequence (or, context).
heads is the number of heads we want. This is what leads to "multi-head" in "multi-head attention".
The __init__ basically lays the groundwork for the multi-head attention computation pictured in the paper:
Source: Attention Is All You Need
Now, let's define our forward-pass function. PyTorch will define a computational graph under the hood that will be capable of passing back gradients and updating model weights. Pre-reqs have material that shows how this works, if you're interested.
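As a tiny sketch of what that computational graph buys us (toy tensors, not part of the attention class): any tensor with requires_grad=True gets its operations recorded, and backward() walks the graph to fill in gradients.

```python
import torch

W = torch.randn(3, 3, requires_grad=True)
x = torch.randn(3)

# Forward pass: PyTorch records every op on W into a graph
loss = (W @ x).sum()

# Backward pass: walk the graph in reverse, accumulating gradients
loss.backward()

# W.grad now holds d(loss)/dW; since loss = sum_ij W[i,j] * x[j],
# every row of the gradient is just x
```

An optimizer would then use W.grad to update W; the same machinery handles every einsum in the forward pass below.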
    def forward(self, data):
        # b s e
        assert data.shape == (self.batch, self.seq_len, self.embedding_dims)

        # First, compute the K, Q, V projections
        # shape: [Batch - b, Heads - h, SeqLen - s, HeadDim - p]
        K_data = torch.einsum('hpe,bse->bhsp', self.W_K, data)
        Q_data = torch.einsum('hpe,bse->bhsp', self.W_Q, data)
        V_data = torch.einsum('hpe,bse->bhsp', self.W_V, data)

        # Second, compute Q*K^T, then divide by sqrt(p) to keep
        # gradients stable, as recommended by the paper
        # shape: [Batch, Heads, SeqLen, SeqLen]
        Q_KT = torch.einsum('bhip,bhjp->bhij', Q_data, K_data)
        Q_KT_stable = Q_KT / (self.projection_dims ** 0.5)

        # Third, apply softmax over the key dimension
        # shape: [Batch, Heads, SeqLen, SeqLen]
        Q_KT_softmaxed = torch.softmax(Q_KT_stable, -1)

        # Fourth, apply V
        # shape: [Batch, Heads, SeqLen, HeadDim]
        token_attentions = torch.einsum('bhij,bhjp->bhip', Q_KT_softmaxed, V_data)

        # Fifth, concatenate the heads to create d_embedding again
        # shape: [Batch, SeqLen, Embedding]
        token_attentions = torch.einsum('bhsp->bshp', token_attentions)  # permute to get h adjacent to p first
        token_attentions_concatenated = torch.reshape(token_attentions, (self.batch, self.seq_len, self.embedding_dims))

        # Sixth, run thru W_O
        # shape: [Batch, SeqLen, Embedding]
        return torch.einsum('ef,bsf->bse', self.W_O, token_attentions_concatenated)
That's really all there is! It's pretty simple, really. Project your sequence via a bunch (the number of "heads" you want) of unique W_Q, W_K, W_V matrices. In every projected space, figure out the relationship between each token and every other token to create a new sequence that encodes all this meaning. Then un-project the data from all the projected spaces back to our original higher-dimensional space. Done! torch.einsum does all the heavy lifting in PyTorch here. Lots of matrix multiplications, which will be a huge focus of the remaining blogs in this GPU Programming sequence.
Source: Attention Is All You Need
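As a quick sanity check on the einsum formulation, here is a self-contained single-head version (toy sizes of my choosing, with the head dimension folded away so p == e) written both with einsum and with plain matmuls; the two should agree:

```python
import torch

b, s, e = 2, 5, 8  # batch, seq_len, embedding (single head, so head dim == e)
W_Q = torch.randn(e, e)
W_K = torch.randn(e, e)
W_V = torch.randn(e, e)
x = torch.randn(b, s, e)

# einsum version (mirrors the class above, minus the head axis)
Q = torch.einsum('pe,bse->bsp', W_Q, x)
K = torch.einsum('pe,bse->bsp', W_K, x)
V = torch.einsum('pe,bse->bsp', W_V, x)
scores = torch.einsum('bip,bjp->bij', Q, K) / (e ** 0.5)
out_einsum = torch.softmax(scores, -1) @ V

# plain matmul version of the same computation
Q2, K2, V2 = x @ W_Q.T, x @ W_K.T, x @ W_V.T
scores2 = Q2 @ K2.transpose(-1, -2) / (e ** 0.5)
out_matmul = torch.softmax(scores2, -1) @ V2
```

If the subscripts in the class ever feel opaque, rewriting them as explicit matmuls like this is a good way to convince yourself of what each einsum is doing.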
Pre-reqs
Understand linear algebra:
Understand backprop, gradient descent and basics of PyTorch:
Understand what a transformer is:
Understand einstein summation notation for tensors:
- PyTorch einsum docs (good, but not sufficient)
- Notes from Iva from Tesla Institute (better, with nice examples)
- Wolfram Mathworld's Intro (dry, formulaic)
- Faculty of Khan (NOT Khan Academy) (ok)
- Ask LLM to give you Basic, Intermediate, Advanced exercises to make sure you understand torch.einsum