Stanford CS336 Lecture Notes 3 - Architectures and Hyperparameters
Continuing with Stanford’s CS336 course, this post covers the third lecture on transformer architectures and hyperparameters. We go through normalization strategies (pre-norm vs post-norm, LayerNorm vs RMSNorm), positional embeddings (from sinusoidal to RoPE), activation functions, attention variants (MQA, GQA, sparse attention), and common hyperparameter choices.
Original transformer (along with GPT 3/2/1, OPT, GPT-J, BLOOM) uses LayerNorm:
$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$$
Most new models (e.g. LLaMA-family, PaLM, Chinchilla, T5) use RMSNorm:
$$y = \frac{x}{\sqrt{\frac{\|x\|_2^2}{n} + \epsilon}} * \gamma$$
Removes bias term ⟹ Fewer parameters to store
Does not calculate and subtract the mean ⟹ Fewer operations
Even though most of the FLOPs happen during the matrix multiplications, the bias term and normalization operations can still increase the runtime due to the required data movement:
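To make the contrast concrete, here is a minimal PyTorch sketch of RMSNorm next to the built-in LayerNorm (an illustration of the formulas above, not any particular model's implementation):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """RMSNorm: no mean subtraction and no bias, only a learned gain."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.gain = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean of squares over the hidden dimension, i.e. ||x||_2^2 / n
        ms = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(ms + self.eps) * self.gain

x = torch.randn(2, 16, 512)
print(RMSNorm(512)(x).shape)       # torch.Size([2, 16, 512])
print(nn.LayerNorm(512)(x).shape)  # LayerNorm additionally subtracts the mean and adds a bias
```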
Bias Terms
Original transformer and earlier models used bias terms:
$$\mathrm{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$$
Most newer (but non-gated) implementations get rid of the bias terms for performance gains and to store fewer parameters:
$$\mathrm{FFN}(x) = \sigma(xW_1)W_2$$
Activations
Earlier, ReLU and GeLU were the most commonly used.
ReLU:
$$\mathrm{FF}(x) = \max(0, xW_1)W_2$$
GeLU:
$$\mathrm{FF}(x) = \mathrm{GELU}(xW_1)W_2, \qquad \mathrm{GELU}(x) := x\,\Phi(x)$$
where $\Phi(\cdot)$ is the Gaussian CDF.
The new trend is to use gated activations. The idea is:
Project the input into two different vectors, let’s say g and p.
g is sent through a nonlinear function, turns into “gate weights”.
g and p are then element-wise multiplied.
So, instead of two weight matrices per FFN as in the original transformer, we now have three.
That is why gated models shrink the internal FFN dimension to about 2/3 of the usual size, so the number of FFN parameters stays roughly the same.
Examples:
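For instance, a SwiGLU-style gated FFN (the variant used by the LLaMA family) can be sketched as follows; the module and variable names are illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    """Gated FFN: three weight matrices instead of two, no bias terms."""
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)  # produces g
        self.w_up = nn.Linear(d_model, d_ff, bias=False)    # produces p
        self.w_down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        g = F.silu(self.w_gate(x))  # nonlinearity turns g into "gate weights"
        p = self.w_up(x)
        return self.w_down(g * p)   # element-wise gating, then project back down

# Gated models shrink d_ff to roughly 2/3 of 4*d_model to keep the parameter count comparable.
ffn = SwiGLUFFN(d_model=1024, d_ff=int(8 * 1024 / 3))
print(ffn(torch.randn(2, 16, 1024)).shape)  # torch.Size([2, 16, 1024])
```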
Serial and Parallel Layers
Standard transformer block applies first attention, and then the feedforward layer.
Serial implementation is still the standard practice, but there are alternative approaches for speed gain.
Serial transformer block (the MLP works on the normalized residual stream, i.e. the input plus the attention output):
$$y = x + \mathrm{MLP}(\mathrm{LayerNorm}(x + \mathrm{Attention}(\mathrm{LayerNorm}(x))))$$
Parallel transformer block (both sublayers work only on the normalized input):
$$y = x + \mathrm{MLP}(\mathrm{LayerNorm}(x)) + \mathrm{Attention}(\mathrm{LayerNorm}(x))$$
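A minimal sketch of the two layouts (illustrative only; `attn` and `mlp` stand in for whatever attention and FFN sublayers the model uses):

```python
import torch
import torch.nn as nn

d = 512
norm1, norm2 = nn.LayerNorm(d), nn.LayerNorm(d)
attn = nn.MultiheadAttention(d, num_heads=8, batch_first=True)
mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

def serial_block(x):
    # The MLP sees the residual stream *after* the attention update.
    h = x + attn(norm1(x), norm1(x), norm1(x))[0]
    return h + mlp(norm2(h))

def parallel_block(x):
    # Attention and MLP both read the same normalized input; their outputs are summed.
    nx = norm1(x)
    return x + attn(nx, nx, nx)[0] + mlp(nx)

x = torch.randn(2, 16, d)
print(serial_block(x).shape, parallel_block(x).shape)
```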
Position Embeddings
Sine - Cosine Embeddings
Use sine and cosine to keep track of the token positions. The positional embedding added to a token at position $pos$ is, for each dimension pair $i$:
$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$
Idea is to create an embedding matrix PE, where each row corresponds to a position vector. In the beginning dimensions of the vector, you have higher frequency functions, and in the ending dimensions you have lower frequency functions.
So, the dimensions near the end are better for getting the ballpark position, while the earlier dimensions determine the position precisely (within their short cycle).
Assume dmodel=1024
This means we have 512 sine waves, and 512 cosine waves, each with different frequency. Let’s focus on the sine waves at i=2, and i=500.
Now, notice that the sine wave with i=500 changes much more slowly than the wave with i=2, because it has a lower angular frequency (0.000124 vs. 0.964662). That is why that dimension changes much more slowly (∼7800 times slower).
Here, you can see how they change for the first 20 positions:
The first function completes its cycle in around 6.51 positions and starts repeating, while the second function completes its cycle in around 50645 positions:
So, the high frequency function can only determine position up to 6.51 units, and then has to start repeating.
Lower frequency function completes its cycle much later, so it can determine positions up to 50645 units. But because it changes so slowly, it cannot discriminate between e.g. pos=2 and pos=3 (See how close they are in the plot above).
That was the rationale behind using trigonometric functions.
They also thought the model could learn to exploit trigonometric identities to do relative positioning too, but this doesn't seem to be the case. The idea was that $PE_{pos+k}$ can be expressed as a linear function of $PE_{pos}$: just a matrix multiplication between some fixed weights $A_k$ and the positional embedding matrix $PE$.
Furthermore, instead of using only sine, they also used cosine at the same frequency. This helps to differentiate between e.g. position 3 and 6, where for i=2, sines will have almost exactly the same value but cosines will be different:
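A small sketch that generates the sinusoidal embedding matrix under the convention above (even dimensions hold the sines, odd dimensions the cosines at the same frequency):

```python
import torch

def sinusoidal_pe(max_len: int, d_model: int, base: float = 10000.0) -> torch.Tensor:
    """Rows are positions, columns are embedding dimensions."""
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)   # (max_len, 1)
    i = torch.arange(d_model // 2, dtype=torch.float32)             # pair index
    freqs = base ** (-2 * i / d_model)                              # angular frequencies
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos * freqs)  # high-frequency dimensions come first
    pe[:, 1::2] = torch.cos(pos * freqs)  # cosine at the same frequency
    return pe

pe = sinusoidal_pe(max_len=20, d_model=1024)
print(pe[:, 4])     # the fast-changing sine at i = 2
print(pe[:, 1000])  # the slow-changing sine at i = 500
```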
Absolute Embeddings
Used in GPT-1, 2, 3 and OPT
Instead of using fixed positional embeddings like the trigonometric ones before, turn it into a trainable layer
$$\mathrm{Embed}(x, pos) = v_x + u_{pos}$$
Relative Embeddings
Used in T5, Gopher, Chinchilla
Instead of adding a learned or fixed position embedding to the token, use a learned embedding based on the offset between the “key” and “query” in the attention layers.
A fixed number of embeddings are learned, corresponding to a range of possible key-query offsets.
Start with the classic scaled dot-product attention score between $x_i$ and $x_j$:
$$e_{ij} = \frac{(x_i W^Q)(x_j W^K)^T}{\sqrt{d_z}}$$
Define two learnable vectors for the relationship between $x_i$ and $x_j$, one on the value level and one on the key level: $a_{ij}^V, a_{ij}^K \in \mathbb{R}^{d_a}$.
These relationship vectors can be shared between different attention layers.
Then, you can simply add these relative vectors to where you originally use key and value vectors.
To calculate the context vectors:
$$\underbrace{z_i}_{\text{context vector for the token at position } i} = \sum_{j=1}^{n} \alpha_{ij} \big(\underbrace{x_j W^V}_{\text{value vector for } x_j} + \underbrace{a_{ij}^V}_{\text{value vector for positions } i \text{ and } j}\big)$$
where the parenthesized term is the modified value vector.
To calculate the attention scores:
$$e_{ij} = \frac{\overbrace{x_i W^Q}^{\text{query vector for } x_i} \big(\underbrace{x_j W^K}_{\text{key vector for } x_j} + \underbrace{a_{ij}^K}_{\text{key vector for positions } i \text{ and } j}\big)^T}{\sqrt{d_z}}$$
where the parenthesized term is the modified key vector for $x_j$.
Because it is impractical to model all possible offsets, and arguably not needed, the offset is clipped at some maximum distance $k$; offsets with $|i - j| > k$ reuse the boundary embedding. This means it is possible to model $2k+1$ unique distances ($k$ in both directions, plus zero):
While theoretically it is nicer to think in terms of “modifying” the key and value vectors, efficient implementation actually does not modify the key matrices but adds another matrix multiplication, which is mathematically equivalent:
$$e_{ij} = \frac{x_i W^Q (x_j W^K)^T + x_i W^Q (a_{ij}^K)^T}{\sqrt{d_z}}$$
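A small sketch of the key-level relative term (the $x_i W^Q (a_{ij}^K)^T$ part) with offset clipping; the function and variable names are illustrative, and it is written per-head for clarity rather than efficiency:

```python
import torch
import torch.nn as nn

def relative_key_bias(q: torch.Tensor, rel_emb: nn.Embedding, k_max: int) -> torch.Tensor:
    """Extra term q_i . a_{ij}^K for all (i, j), with offsets clipped to [-k_max, k_max].
    q: (n, d_a) query vectors; rel_emb holds 2*k_max + 1 learned vectors of size d_a."""
    n = q.shape[0]
    offsets = torch.arange(n).unsqueeze(0) - torch.arange(n).unsqueeze(1)  # j - i
    offsets = offsets.clamp(-k_max, k_max) + k_max                         # map to [0, 2k]
    a_k = rel_emb(offsets)                                                 # (n, n, d_a)
    return torch.einsum("id,ijd->ij", q, a_k)                              # added to QK^T

k_max, d_a, n = 4, 64, 10
rel_emb = nn.Embedding(2 * k_max + 1, d_a)
bias = relative_key_bias(torch.randn(n, d_a), rel_emb, k_max)
print(bias.shape)  # torch.Size([10, 10])
```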
ALiBi (Attention with Linear Biases)
Similar to the idea of relative embeddings, can be considered a special case.
We are modifying the attention score calculation, using non-learnable bias terms:
$$e_{ij} = \frac{x_i W^Q (x_j W^K)^T}{\sqrt{d_z}} + m \cdot (j - i)$$
Here, $m$ is a predetermined scalar changing the slope for the different heads, with the smallest slope being $\frac{1}{2^8}$. For 8 attention heads the slopes are $\frac{1}{2^1}, \frac{1}{2^2}, \ldots, \frac{1}{2^8}$, and for 16 attention heads they are $\frac{1}{2^{0.5}}, \frac{1}{2^1}, \ldots, \frac{1}{2^{7.5}}, \frac{1}{2^8}$. (So, basically, start from $2^{-8/n}$ and go down to $2^{-8}$ for $n$ heads.)
The idea is that, for each token, the causal attention score is calculated purely from the token representations, and the bias term then provides the position information by penalizing distant tokens: a token one position away gets $\text{score} - m$, and a token 10 positions away gets $\text{score} - 10m$. Each head “punishes” distance separately, based on its predetermined slope:
As the distance gets longer, the $m \times \text{distance}$ term starts to dominate the actual score, so the softmax produces outputs near 0 for those positions: because the far-away tokens are pushed to be much more negative while the closer tokens are only slightly penalized, the softmax favors the closer tokens.
The expectation is that this recency bias lets the model generalize to unseen context lengths. However, it could be problematic when you need long-range dependencies.
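A small sketch of the ALiBi bias matrix for causal attention, with the slopes following the geometric sequence described above (the causal mask itself is assumed to be applied separately):

```python
import torch

def alibi_bias(n_heads: int, seq_len: int) -> torch.Tensor:
    """Per-head bias m * (j - i), shape (n_heads, seq_len, seq_len)."""
    # Slopes start at 2^(-8/n_heads) and go down to 2^(-8).
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / n_heads) for h in range(n_heads)])
    i = torch.arange(seq_len).unsqueeze(1)  # query position
    j = torch.arange(seq_len).unsqueeze(0)  # key position
    distances = (j - i).clamp(max=0)        # 0 on the diagonal, negative for past tokens
    return slopes.view(-1, 1, 1) * distances

bias = alibi_bias(n_heads=8, seq_len=6)
print(bias[0])  # the head with the steepest slope (1/2): the penalty grows with distance
```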
RoPE (Rotary Positional Embeddings)
Think of the modeling of the relative distance in the attention layers:
$$\langle f_q(x_m, m),\, f_k(x_n, n) \rangle = g(x_m, x_n, m - n)$$
where $x_m$ is the query token at position $m$, and $x_n$ is the key token at position $n$. The goal is to end up with a function $g$ that calculates an attention score based on three inputs:
xm: Embedding for xm
xn: Embedding for xn
m−n: Distance between the positions of xm and xn
First, remember the rotation matrices in 2D, and how they are defined using sine and cosine of the same angle:
$$R(\theta) = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}$$
The original transformer used pairs of sines and cosines as well, but for absolute encoding. Now, the idea is to use the same sines and cosines, but instead of adding them to the token embeddings, we rotate the key and query vectors in the attention layers. The rotation matrix is defined similarly:
So, we have d/2 different rotation matrices, which are defined based on cosine and sine functions with different frequencies. Base frequency is the same one as they had in the original transformer:
$$\Theta = \left\{\theta_i = 10000^{-2(i-1)/d},\ i \in [1, 2, \ldots, d/2]\right\}$$
Then, you rotate your query and key vectors, and every pair of dimensions rotates at a different frequency (frequencies have been changed so their movement can be shown in the animation):
Then, you simply rotate both your query and key vectors when you are calculating the attention scores:
$$q'_m = R^d_{\Theta,m} W_q x_m \quad \text{(rotated query vector for the token at position } m\text{)}$$
$$k'_n = R^d_{\Theta,n} W_k x_n \quad \text{(rotated key vector for the token at position } n\text{)}$$
$$e_{mn} = \frac{(R^d_{\Theta,m} W_q x_m)^T (R^d_{\Theta,n} W_k x_n)}{\sqrt{d}} = \frac{x_m^T W_q^T \underbrace{R^d_{\Theta,(n-m)}}_{\text{relative rotation}} W_k x_n}{\sqrt{d}}$$
So, rotating the query and key vectors by their absolute positions corresponds to using a rotation matrix that depends only on the relative position, since $(R^d_{\Theta,m})^T R^d_{\Theta,n} = R^d_{\Theta,(n-m)}$.
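A compact sketch of applying RoPE to query/key vectors, rotating each consecutive pair of dimensions by its own frequency (an illustration of the idea, not an optimized implementation):

```python
import torch

def rope_rotate(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """x: (..., seq_len, d) query or key vectors; positions: (seq_len,) integer positions."""
    d = x.shape[-1]
    i = torch.arange(d // 2, dtype=torch.float32)
    theta = base ** (-2 * i / d)                      # one frequency per 2-D pair
    angles = positions.float().unsqueeze(-1) * theta  # (seq_len, d/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]               # the two coordinates of each pair
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin              # 2-D rotation of each pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

q, k = torch.randn(1, 8, 64), torch.randn(1, 8, 64)  # (batch, seq, d)
pos = torch.arange(8)
q_rot, k_rot = rope_rotate(q, pos), rope_rotate(k, pos)
# The score q_rot[m] . k_rot[n] now depends on the offset n - m (plus the content).
print((q_rot[0, 2] @ k_rot[0, 5]).item())
```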
Feedforward dimension
The feedforward hidden dimension is usually scaled up to 4 times the model dimension and then projected back down:
$$d_{ff} = 4 \cdot d_{model}$$
GLU variants: Because researchers wanted to keep the parameter count similar, they tend to scale down the hidden size of the feedforward layer:
$$d_{ff} \approx \frac{8}{3} d_{model}$$
Number of heads and head dimension
Usually, the number of heads and the dimension of each head are chosen such that their product equals the model dimension:
$$d_{model} = n_{heads} \times d_{head}$$
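For instance (quoting LLaMA-2 7B's published configuration as an outside example, not something from the lecture): $d_{model} = 4096$ with $n_{heads} = 32$ gives $d_{head} = 4096 / 32 = 128$.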
Depth (number of layers) vs. width (model dimension) debate
Earlier, there were very wide models (BLOOM, T5 v1.1) with $d_{model}/n_{layer} > 170$, and very deep, narrow models (T5, GPT-2) with $d_{model}/n_{layer} < 50$, but newer models usually have the following aspect ratio:
$$100 < \frac{d_{model}}{n_{layer}} < 160$$
The choice of depth and width is also affected by your networking constraints and the type of parallelism you can do, e.g. tensor parallelism, which lets you train wider networks, needs a fast network, while pipeline parallelism, where you can shard the model per layer, can get away with a slower network.
But empirical research (OpenAI Scaling paper) shows the sweet spot is around 70 - 150
Vocabulary sizes
Monolingual models usually have a vocabulary size of around 30,000 - 50,000 tokens
Multilingual and newer models have vocabulary sizes between 64,000 and 255,000 tokens
Training regularization
Initially, a lot of models were using dropout
Nowadays, it seems like most don’t use dropout anymore, but switched to using weight decay
However, weight decay is not really used for regularization here: it appears to have no effect on overfitting (i.e. the ratio of training loss to validation loss). Rather, it has an interesting interaction with dynamic learning-rate schedules like cosine LR decay, and ends up facilitating faster training and higher accuracy. Furthermore, weight decay stabilizes training with bfloat16.
More on this at D'Angelo, Francesco, et al. "Why do we need weight decay in modern deep learning?." Advances in Neural Information Processing Systems 37 (2024): 23191-23223.
Stability Tricks
Softmaxes ⟹ The problem child
Used in the final layer and also attention layers
Solution for the Final Layer: Z-loss (from 2014 - Devlin et al., for decoding speed, later used for stability in 2022 by PaLM)
Z-loss is an auxiliary loss on the normalization factor of the softmax. Softmax is defined as:
$$P(x)_i = \sigma(x)_i = \frac{\exp(x_i)}{\underbrace{\sum_{j=1}^{d} \exp(x_j)}_{Z(x),\ \text{i.e. the softmax normalizer}}}$$
It turns out that by encouraging the normalizer to be close to 1, we can get more stable training runs. Since the log likelihood is the log of the softmax, the loss on a softmax output can be written as:
$$L(x) = \log\left(\frac{\exp(x_i)}{Z(x)}\right) = \log(\exp(x_i)) - \log(Z(x))$$
assuming logit $i$ represents the correct logit.
Then, to encourage $Z(x)$ to be 1, you can simply push $\log(Z(x))$ towards 0 by adding an auxiliary term with a coefficient $\alpha$, where $\alpha$ determines the amount of “encouragement”:
$$L_z(x) = \alpha \cdot \log^2(Z(x))$$
So, the idea is to add this additional squared-error (MSE-style) loss on $\log(Z(x))$ on top of the usual loss. In theory, you can apply this to all softmax layers, but in practice everyone just applies it to the last layer.
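A minimal sketch of adding a z-loss term on top of the standard cross-entropy; the coefficient $\alpha = 10^{-4}$ follows the value reported in the PaLM paper, and the tensor names are illustrative:

```python
import torch
import torch.nn.functional as F

def loss_with_z_loss(logits: torch.Tensor, labels: torch.Tensor, alpha: float = 1e-4) -> torch.Tensor:
    """logits: (batch, vocab), labels: (batch,). Adds alpha * log(Z)^2 to the cross-entropy."""
    ce = F.cross_entropy(logits, labels)
    log_z = torch.logsumexp(logits, dim=-1)  # log of the softmax normalizer Z(x)
    z_loss = alpha * (log_z ** 2).mean()     # pushes log(Z) towards 0, i.e. Z towards 1
    return ce + z_loss

logits = torch.randn(4, 32000, requires_grad=True)
labels = torch.randint(0, 32000, (4,))
print(loss_with_z_loss(logits, labels))
```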
For the attention softmaxes, another trick is used, namely QK Norm.
The idea is simple: before you calculate attention as the dot product of q and k, apply LayerNorm / RMSNorm to them.
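A sketch of QK-norm for a single head (RMSNorm applied to the queries and keys right before the dot product; learned gains are omitted for brevity):

```python
import torch
import torch.nn.functional as F

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

def qk_norm_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """q, k, v: (seq_len, d_head). Normalize q and k before computing the attention scores."""
    q, k = rms_norm(q), rms_norm(k)  # QK-norm keeps the attention logits in a bounded range
    scores = q @ k.T / q.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ v

q, k, v = (torch.randn(8, 64) for _ in range(3))
print(qk_norm_attention(q, k, v).shape)  # torch.Size([8, 64])
```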
Attention Heads
Cost of the Multi-Head Attention
There are various alternatives to multi-head attention, devised in order to make best use of the GPU time. To understand why, let’s calculate the cost of attention in each head.
Every head computes three projections first:
XWQ→ Projects inputs to the query vectors
XWK→ Projects inputs to the key vectors
XWV→ Projects inputs to the value vectors
(We will drop the batch size from calculations for simplicity)
Assuming $X \in \mathbb{R}^{n \times d}$ and $W_Q, W_K, W_V \in \mathbb{R}^{d \times d_a}$:
Computational cost of the projections: $O(n \cdot d \cdot d_a)$.
Memory cost: $O(nd + d \cdot d_a)$
You could also include the output activations in the memory cost, i.e. $O(nd + d \cdot d_a + n \cdot d_a)$, but since $d > d_a$ almost always, I drop it.
Then, the resulting matrices $Q, K \in \mathbb{R}^{n \times d_a}$ are multiplied to get $QK^T$:
Computational cost: $O(n^2 \cdot d_a)$
Memory cost: $O(n \cdot d_a + n^2)$
Now, the resulting matrix $QK^T \in \mathbb{R}^{n \times n}$ will have softmax applied to it.
Computational cost: $O(n^2)$
Memory cost: $O(n^2)$
Now that we have the attention scores, we want to create the context vectors by multiplying the attention scores and the value vectors. As we are multiplying two matrices of sizes $(n \times n) \times (n \times d_a)$:
Computational cost: $O(n^2 \cdot d_a)$
Memory cost: $O(n^2 + n \cdot d_a)$
Now, here is the key insight. As we are doing this for every head, all of these costs get multiplied by $h$ (except the memory cost of the input matrix, which does not need to be re-read for each head)! So far we have:
Note: This can further simplify depending on the relationship between n and d, but with the modern models we are not sure what the sequence length is going to be compared to the dimension of the model
Memory cost:
$$O(\underbrace{nd}_{\text{input to attention, for projections}} + \underbrace{h \cdot d \cdot d_a}_{\text{output of the projections}} + \underbrace{h \cdot n \cdot d_a}_{\text{input for attention scores}} + \underbrace{h \cdot n^2}_{\text{output of attention scores}} + \underbrace{h \cdot n^2}_{\text{softmax activations}} + \underbrace{h \cdot n^2}_{\text{input to context-vector calculation}} + \underbrace{h \cdot n \cdot d_a}_{\text{output of context-vector calculation}}) \;\rightarrow\; O(\underbrace{nd}_{\text{input to attention, for projections}} + \underbrace{h \cdot n^2}_{\text{attention score outputs + softmax activations + context-vector inputs}})$$
Then, after the head outputs are combined (this is a memory operation, so you may incur another $O(nd)$ memory cost, but it doesn't change the calculations), you have the final projection $CP$, with $C \in \mathbb{R}^{n \times d}$ and $P \in \mathbb{R}^{d \times d}$ being the context matrix and the projection matrix, respectively. This final operation has:
Computational cost: $O(n \cdot d^2)$
Memory cost: $O(nd + d^2)$
Combining everything together,
Total computational cost: $O(h \cdot n \cdot d \cdot d_a + h \cdot n^2 \cdot d_a + n \cdot d^2)$
Total memory cost: $O(nd + h \cdot n^2 + d^2)$
Now, we can make some assumptions based on the general conventions. Usually, $d = h \times d_a$. Then:
Computational cost: $O(n \cdot d^2 + n^2 \cdot d)$
Memory cost: $O(nd + h \cdot n^2 + d^2)$
This computational and memory cost analysis shows us something: with the current conventions, the number of heads has a negligible effect on the computational cost, but the memory cost is highly dependent on it. Ideally, we want to improve arithmetic intensity, that is, the ratio of computational cost to memory cost. As we do not want smaller hidden dimensions or shorter sequences, we tend to play around with the number of heads to improve the arithmetic intensity.
Batching Multi-Head Attention
For simplicity, we dropped the batch from the calculations. Let’s add it back:
Computational cost: $O(b \cdot n \cdot d^2 + b \cdot n^2 \cdot d)$
Memory cost: $O(b \cdot n \cdot d + b \cdot h \cdot n^2 + d^2)$ (the projection matrices are independent of the batch size, so that term does not get multiplied by $b$)
Arithmetic intensity: $O\left(\dfrac{b n d^2 + b n^2 d}{b n d + b h n^2 + d^2}\right)$
During training, this can be accepted, because by batching you can parallelize all the operations and still utilize the GPUs fully.
However, during the inference time, each attention calculation has to wait for the memory movement. So, you get into Generate → IO Wait → Generate → IO Wait → Generate … workflow. This under-utilizes the GPU (if not enough parallel requests) and causes slow responses (due to waiting for memory movement).
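To see how the head count enters, here is a back-of-the-envelope evaluation of the arithmetic-intensity ratio derived above, using some illustrative (made-up) shapes:

```python
def arithmetic_intensity(b: int, n: int, d: int, h: int) -> float:
    """Ratio of the compute term to the memory term from the analysis above (constants dropped)."""
    compute = b * n * d**2 + b * n**2 * d
    memory = b * n * d + b * h * n**2 + d**2
    return compute / memory

# Illustrative numbers: batch 1 (single-request inference), 4096-dim model, 4096-token context.
for h in (32, 8, 1):
    print(h, round(arithmetic_intensity(b=1, n=4096, d=4096, h=h), 1))
# The memory term grows with h while the compute term does not, so intensity drops as h grows.
```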
Multi-Query Attention (MQA)
Proposed Solution: Have separate queries per head, but share $W_K$ and $W_V$ across all heads.
What does it achieve:
In MHA, the data movement for the key and value matrices across the $h$ heads is $O(h \cdot n \cdot d_a) = O(nd)$.
By having only one key and one value matrix, we decrease this memory movement by a factor of $h$, to $O(n \cdot d_a)$.
Even though this does not change the full complexity analysis, it provides significant speed-ups in reality because you significantly reduce the data movement from KV cache (this is primarily an inference optimization).
Grouped-Query Attention (GQA)
Proposed Solution: MQA can decrease model quality, so instead share key and value matrices among groups of heads, e.g. the first key and value matrices are used by heads 1-3, the second by heads 4-6, and so on.
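A sketch of how GQA shares key/value heads among groups of query heads (with a single K/V head this reduces to MQA); shapes and names are illustrative, and batching/masking are omitted:

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v, n_heads: int, n_kv_heads: int):
    """q: (seq, n_heads, d_head); k, v: (seq, n_kv_heads, d_head).
    Each group of n_heads // n_kv_heads query heads shares one K/V head."""
    group_size = n_heads // n_kv_heads
    k = k.repeat_interleave(group_size, dim=1)  # expand the shared K/V to all query heads
    v = v.repeat_interleave(group_size, dim=1)
    scores = torch.einsum("qhd,khd->hqk", q, k) / q.shape[-1] ** 0.5
    attn = F.softmax(scores, dim=-1)
    return torch.einsum("hqk,khd->qhd", attn, v)

seq, n_heads, n_kv_heads, d_head = 8, 8, 2, 64
q = torch.randn(seq, n_heads, d_head)
k = torch.randn(seq, n_kv_heads, d_head)  # only 2 K/V heads need to live in the KV cache
v = torch.randn(seq, n_kv_heads, d_head)
print(grouped_query_attention(q, k, v, n_heads, n_kv_heads).shape)  # (8, 8, 64)
```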
Sparse and Sliding Window Attentions
Basically, consider only the closest tokens:
Sliding Window attention (see the mask sketch below):
Sparse Transformers:
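To make the sliding-window pattern concrete, here is a minimal sketch of the attention mask (the window size is arbitrary): each token attends to itself and the previous window_size − 1 tokens.

```python
import torch

def sliding_window_mask(seq_len: int, window_size: int) -> torch.Tensor:
    """True where attention is allowed: causal and within the window."""
    i = torch.arange(seq_len).unsqueeze(1)  # query position
    j = torch.arange(seq_len).unsqueeze(0)  # key position
    return (j <= i) & (j > i - window_size)

print(sliding_window_mask(seq_len=6, window_size=3).int())
```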
Combining Long- and Short-Range Information
E.g. in Cohere Command A, every 4th layer is a full attention layer with no positional embeddings (NoPE).
Other attention layers are Sliding Window + Grouped Query Attention with RoPE.
References
Most of the images are taken from these resources:
Vaswani, Ashish, et al. “Attention Is All You Need.” Advances in Neural Information Processing Systems 30 (2017).
Shaw, Peter, Jakob Uszkoreit, and Ashish Vaswani. “Self-Attention with Relative Position Representations.” Proceedings of NAACL-HLT (2018).
Xiong, Ruibin, et al. “On Layer Normalization in the Transformer Architecture.” International Conference on Machine Learning, PMLR, 2020.
Ivanov, Andrei, et al. “Data Movement Is All You Need: A Case Study on Optimizing Transformers.” Proceedings of Machine Learning and Systems 3 (2021): 711-732.
Press, Ofir, Noah A. Smith, and Mike Lewis. “Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation.” arXiv preprint arXiv:2108.12409 (2021).
Su, Jianlin, et al. “RoFormer: Enhanced Transformer with Rotary Position Embedding.” Neurocomputing 568 (2024): 127063.
Devlin, Jacob, et al. “Fast and Robust Neural Network Joint Models for Statistical Machine Translation.” Proceedings of the 52nd Annual Meeting of the ACL (2014).
Chowdhery, Aakanksha, et al. “PaLM: Scaling Language Modeling with Pathways.” Journal of Machine Learning Research 24.240 (2023): 1-113.
D’Angelo, Francesco, et al. “Why Do We Need Weight Decay in Modern Deep Learning?” Advances in Neural Information Processing Systems 37 (2024): 23191-23223.
“Multi-Query Attention Is All You Need.” Fireworks AI.
Ainslie, Joshua, et al. “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.” arXiv preprint arXiv:2305.13245 (2023).
Child, Rewon, et al. “Generating Long Sequences with Sparse Transformers.” arXiv preprint arXiv:1904.10509 (2019).
Jiang, Albert Q., et al. “Mistral 7B.” arXiv preprint arXiv:2310.06825 (2023).
Other figures are mostly created with the help of Claude Sonnet 4.5.