
Attention and Transformers

Sequence processing and recurrent neural networks

Many tasks require processing sequences rather than fixed-size vectors:

  • Language modeling: predict the next word or character.

  • Machine translation: map an input sentence to an output sentence.

  • Speech recognition, time series prediction, control problems, etc.

A recurrent neural network (RNN) processes a sequence $x_1,\dots,x_T$ step by step, maintaining a hidden state $h_t$ that summarizes the past.

A simple (vanilla) RNN cell:

  • Hidden state update:

    $$h_t = \phi\big(W_{hh} h_{t-1} + W_{xh} x_t + b_h\big),$$

    where $\phi$ is a nonlinearity (e.g. $\tanh$ or ReLU).

  • Output at time $t$:

    $$y_t = g\big(W_{hy} h_t + b_y\big),$$

    where $g$ is typically a softmax for classification or the identity for regression.

The same parameters $(W_{hh}, W_{xh}, W_{hy}, b_h, b_y)$ are used at every time step.
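
As a concrete reference, here is a minimal NumPy sketch of one step of this vanilla RNN cell; the tanh/softmax choices and the toy dimensions are illustrative assumptions, not fixed by the text:

```python
import numpy as np

def rnn_step(x_t, h_prev, W_hh, W_xh, W_hy, b_h, b_y):
    """One step of a vanilla RNN: update the hidden state, then emit an output."""
    h_t = np.tanh(W_hh @ h_prev + W_xh @ x_t + b_h)      # hidden state update
    logits = W_hy @ h_t + b_y
    y_t = np.exp(logits - logits.max())                  # softmax output (classification case)
    y_t /= y_t.sum()
    return h_t, y_t

# Toy dimensions: input dim 4, hidden dim 8, output (vocabulary) dim 10.
rng = np.random.default_rng(0)
d_in, d_h, d_out = 4, 8, 10
params = (0.1 * rng.standard_normal((d_h, d_h)),    # W_hh
          0.1 * rng.standard_normal((d_h, d_in)),   # W_xh
          0.1 * rng.standard_normal((d_out, d_h)),  # W_hy
          np.zeros(d_h), np.zeros(d_out))           # b_h, b_y

h = np.zeros(d_h)
for x in rng.standard_normal((5, d_in)):            # a length-5 input sequence
    h, y = rnn_step(x, h, *params)                  # the same parameters are reused every step
```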

RNNs support different input–output patterns:

  • One-to-many: image captioning (one image → sequence of words),

  • Many-to-one: sentiment classification (sequence → one label),

  • Many-to-many: sequence labeling, translation (sequence → sequence).

Training RNNs and backpropagation through time

To train an RNN, we define a loss over the entire sequence:

  • For example, in language modeling, sum the cross-entropy over all time steps:

    $$L = \sum_{t=1}^T \ell_t(y_t, \hat{y}_t).$$

Training uses backpropagation through time (BPTT):

  1. Unroll the RNN over all time steps $t = 1,\dots,T$.

  2. Perform a forward pass to compute all hidden states and outputs.

  3. Backpropagate gradients from the final time step back to the beginning.

Because the gradient has to pass through many repeated multiplications by $W_{hh}$ (and nonlinearities), we get:

  • Vanishing gradients when the eigenvalues of $W_{hh}$ are mostly $< 1$ in magnitude.

  • Exploding gradients when the eigenvalues are $> 1$ in magnitude.

As a consequence:

  • Simple RNNs struggle to learn long-term dependencies (information far back in time).

  • They can, in principle, represent such dependencies, but learning them with gradient descent is difficult.
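
The effect of these repeated multiplications can be seen in a tiny numerical sketch (a toy 2×2 recurrent matrix, purely illustrative):

```python
import numpy as np

T = 50                                    # number of unrolled time steps
grad = np.ones(2)                         # gradient arriving at the last time step

for scale in (0.5, 1.5):                  # recurrent matrix with spectral radius below vs above 1
    W_hh = scale * np.eye(2)              # toy 2x2 recurrent matrix
    g = grad.copy()
    for _ in range(T):                    # BPTT: repeated multiplication by W_hh^T
        g = W_hh.T @ g
    print(f"scale={scale}: gradient norm after {T} steps = {np.linalg.norm(g):.3e}")
# scale=0.5 -> the norm collapses toward 0 (vanishing); scale=1.5 -> it blows up (exploding)
```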

Long-term dependencies and gated RNNs: LSTM and GRU

To address vanishing and exploding gradients, gated RNN architectures were introduced.

Long Short-Term Memory (LSTM)

An LSTM maintains:

  • A cell state $c_t$ for long-term memory,

  • A hidden state $h_t$ for short-term / working memory.

At each time step, it uses gates to control information flow:

  • Forget gate:

    $$f_t = \sigma(W_f [h_{t-1}, x_t] + b_f)$$

  • Input gate:

    $$i_t = \sigma(W_i [h_{t-1}, x_t] + b_i)$$

  • Candidate cell state:

    $$\tilde{c}_t = \tanh(W_c [h_{t-1}, x_t] + b_c)$$

  • Output gate:

    $$o_t = \sigma(W_o [h_{t-1}, x_t] + b_o)$$

Update equations:

$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t, \qquad h_t = o_t \odot \tanh(c_t),$$

where $\odot$ denotes element-wise multiplication.

Key properties:

  • The cell state $c_t$ has an additive update, which helps gradients flow over long time spans.

  • Gates learn to keep, forget, or overwrite information.
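
For reference, a minimal NumPy sketch of a single LSTM step following the equations above; the explicit concatenation $[h_{t-1}, x_t]$, the sigmoid helper, and the toy dimensions are illustrative:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W_f, W_i, W_c, W_o, b_f, b_i, b_c, b_o):
    """One LSTM step: gates control what to forget, what to write, and what to expose."""
    hx = np.concatenate([h_prev, x_t])            # [h_{t-1}, x_t]
    f = sigmoid(W_f @ hx + b_f)                   # forget gate
    i = sigmoid(W_i @ hx + b_i)                   # input gate
    c_tilde = np.tanh(W_c @ hx + b_c)             # candidate cell state
    o = sigmoid(W_o @ hx + b_o)                   # output gate
    c = f * c_prev + i * c_tilde                  # additive cell-state update
    h = o * np.tanh(c)                            # hidden state
    return h, c

d_in, d_h = 4, 8
rng = np.random.default_rng(0)
W = lambda: 0.1 * rng.standard_normal((d_h, d_h + d_in))
b = lambda: np.zeros(d_h)
params = dict(W_f=W(), W_i=W(), W_c=W(), W_o=W(),
              b_f=b(), b_i=b(), b_c=b(), b_o=b())  # shared across all time steps

h, c = np.zeros(d_h), np.zeros(d_h)
for x in rng.standard_normal((5, d_in)):
    h, c = lstm_step(x, h, c, **params)
```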

Gated Recurrent Unit (GRU)

The GRU is a simpler gated RNN without a separate cell state:

  • Reset (or relevance) gate:

    $$r_t = \sigma(W_{rh} h_{t-1} + W_{rx} x_t + b_r)$$

  • Candidate hidden state:

    $$\tilde{h}_t = \tanh\big(W_{ch}(r_t \odot h_{t-1}) + W_{cx} x_t + b_c\big)$$

  • Update gate:

    $$u_t = \sigma(W_{uh} h_{t-1} + W_{ux} x_t + b_u)$$

Update:

$$h_t = (1 - u_t) \odot h_{t-1} + u_t \odot \tilde{h}_t.$$
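
A matching NumPy sketch of one GRU step, under the same illustrative conventions as the LSTM example:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, W_rh, W_rx, b_r, W_ch, W_cx, b_c, W_uh, W_ux, b_u):
    """One GRU step: the reset gate gates the old state, the update gate interpolates."""
    r = sigmoid(W_rh @ h_prev + W_rx @ x_t + b_r)               # reset gate
    h_tilde = np.tanh(W_ch @ (r * h_prev) + W_cx @ x_t + b_c)   # candidate hidden state
    u = sigmoid(W_uh @ h_prev + W_ux @ x_t + b_u)               # update gate
    return (1 - u) * h_prev + u * h_tilde                       # interpolated update

d_in, d_h = 4, 8
rng = np.random.default_rng(0)
Wh = lambda: 0.1 * rng.standard_normal((d_h, d_h))   # hidden-to-hidden weights
Wx = lambda: 0.1 * rng.standard_normal((d_h, d_in))  # input-to-hidden weights
h = gru_step(rng.standard_normal(d_in), np.zeros(d_h),
             Wh(), Wx(), np.zeros(d_h),              # reset gate parameters
             Wh(), Wx(), np.zeros(d_h),              # candidate parameters
             Wh(), Wx(), np.zeros(d_h))              # update gate parameters
```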

Compared to LSTM:

  • Fewer gates and parameters,

  • No explicit cell state; hidden state carries both long- and short-term information.

Both LSTMs and GRUs significantly improve the ability to learn long-term dependencies compared to vanilla RNNs.

Multi-layer RNNs

RNNs can be stacked to form deep recurrent networks:

  • The hidden state of layer $\ell$ at time $t$, $h_t^{(\ell)}$, becomes the input to layer $\ell+1$ at the same time step.

For example, with $L$ layers:

$$h_t^{(1)} = f^{(1)}(x_t, h_{t-1}^{(1)}), \qquad h_t^{(2)} = f^{(2)}(h_t^{(1)}, h_{t-1}^{(2)}),$$

and so on up to $h_t^{(L)}$.
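
A minimal sketch of this stacking pattern, where an abstract per-layer step function (here just a tanh cell for illustration) is applied layer by layer at each time step:

```python
import numpy as np

def cell_step(x, h_prev, W_x, W_h, b):
    """Generic recurrent cell; in practice this could be a vanilla, LSTM, or GRU cell."""
    return np.tanh(W_x @ x + W_h @ h_prev + b)

d_in, d_h, L, T = 4, 8, 3, 6
rng = np.random.default_rng(0)
layers = [dict(W_x=0.1 * rng.standard_normal((d_h, d_in if l == 0 else d_h)),
               W_h=0.1 * rng.standard_normal((d_h, d_h)),
               b=np.zeros(d_h)) for l in range(L)]

h = [np.zeros(d_h) for _ in range(L)]          # one hidden state per layer
for x_t in rng.standard_normal((T, d_in)):
    inp = x_t
    for l in range(L):
        h[l] = cell_step(inp, h[l], **layers[l])
        inp = h[l]                             # layer l's state feeds layer l+1 at the same step
```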

Benefits:

  • Higher layers can capture more abstract features of the sequence.

  • Deep RNNs (with LSTM or GRU units) often perform better than single-layer ones.

In practice:

  • High-performing RNN-based models often use a small number of recurrent layers (e.g. 2–4),

  • Not nearly as deep as modern convolutional or transformer-based architectures.

Word embeddings and distributional semantics

Discrete words are often represented as one-hot vectors:

  • A vocabulary of size $|V|$,

  • Word $w_i$ is represented by a vector $e_i \in \mathbb{R}^{|V|}$ with a single 1 and the rest 0.

Problems with one-hot encoding:

  • No notion of similarity between words,

  • Vectors are high-dimensional and sparse.

Word embeddings map words to dense vectors:

  • Learn an embedding matrix $E \in \mathbb{R}^{|V| \times d}$,

  • Word $w$ is represented as $v_w \in \mathbb{R}^d$ (a row of $E$),

  • $d$ is typically in the hundreds or thousands (e.g. 300, 768, 1536, 3072).

Distributional semantics:

“You shall know a word by the company it keeps.”

Words are embedded so that those appearing in similar contexts have similar vectors (high dot product or cosine similarity).
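
A small sketch of the lookup-and-compare pattern, with a hypothetical toy vocabulary and random embeddings standing in for trained ones (so the printed values are arbitrary; only the mechanics are shown):

```python
import numpy as np

vocab = {"cat": 0, "dog": 1, "economy": 2}     # hypothetical toy vocabulary
V, d = len(vocab), 8
rng = np.random.default_rng(0)
E = rng.standard_normal((V, d))                # embedding matrix; learned in practice, random here

def embed(word):
    return E[vocab[word]]                      # word vector = one row of E

def cosine(u, v):
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

# With trained embeddings, words appearing in similar contexts get high cosine similarity;
# with these random vectors the numbers are arbitrary.
print(cosine(embed("cat"), embed("dog")))
print(cosine(embed("cat"), embed("economy")))
```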

Embeddings are learned by:

  • Training language models or skip-gram / CBOW models,

  • Or as part of larger architectures (e.g. seq2seq, transformers).

These embeddings serve as the input representation for RNNs and transformers.

Sequence-to-sequence models and neural machine translation

In neural machine translation (NMT), we model the conditional probability $p(y \mid x)$ of a target sentence $y = (y_1,\dots,y_T)$ given a source sentence $x = (x_1,\dots,x_{T_x})$.

An RNN encoder–decoder model works as follows.

Encoder

The encoder RNN reads the source sequence and produces hidden states:

$$h_t = f(x_t, h_{t-1}), \quad t = 1,\dots,T_x.$$

The encoder summarizes the source sequence into a context vector $c$:

$$c = q(h_1,\dots,h_{T_x}),$$

for example by taking the final hidden state $h_{T_x}$ or using a more complex aggregation.

Decoder (basic model without attention)

The decoder is another RNN that generates the target sequence word by word:

$$p(y) = \prod_{t=1}^T p(y_t \mid y_{<t}, c),$$

with each conditional modeled as

$$p(y_t \mid y_{<t}, c) = g(y_{t-1}, s_t, c),$$

where $s_t$ is the decoder hidden state, updated by

$$s_t = f(s_{t-1}, y_{t-1}, c).$$

Limitations:

  • The context vector $c$ is a fixed-size bottleneck summarizing the entire source sentence.

  • For long sentences, compressing all information into a single vector can limit performance.

Attention mechanisms were introduced to solve this bottleneck.

Encoder–decoder with attention (align and translate)

Instead of using a single context vector $c$ for all target words, attention-based models compute a separate context $c_i$ for each target position $i$.

For each target word $y_i$:

  • Decoder hidden state:

    $$s_i = f(s_{i-1}, y_{i-1}, c_i).$$

  • Conditional probability:

    $$p(y_i \mid y_{<i}, x) = g(y_{i-1}, s_i, c_i).$$

Context vector as a weighted sum of encoder states

Let $h_1,\dots,h_{T_x}$ be the encoder annotations (e.g. from a bidirectional RNN). The context vector is

$$c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j,$$

where the attention weights $\alpha_{ij}$ describe how much the decoder at position $i$ focuses on encoder position $j$.

Weights are computed as

$$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})},$$

where $e_{ij}$ is an alignment score between decoder state $s_{i-1}$ and encoder state $h_j$:

$$e_{ij} = a(s_{i-1}, h_j).$$

Here, $a(\cdot,\cdot)$ is a small neural network (e.g. a feed-forward network).
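
A minimal NumPy sketch of computing one context vector $c_i$, using a one-hidden-layer feed-forward scorer as an illustrative choice for $a(\cdot,\cdot)$:

```python
import numpy as np

rng = np.random.default_rng(0)
T_x, d_h, d_s, d_a = 6, 8, 8, 16         # encoder length, encoder/decoder/attention dims
H = rng.standard_normal((T_x, d_h))      # encoder annotations h_1 .. h_{T_x}
s_prev = rng.standard_normal(d_s)        # previous decoder state s_{i-1}

# a(s_{i-1}, h_j): a small one-hidden-layer feed-forward scorer (illustrative choice)
W_s = rng.standard_normal((d_a, d_s))
W_h = rng.standard_normal((d_a, d_h))
v = rng.standard_normal(d_a)

e = np.array([v @ np.tanh(W_s @ s_prev + W_h @ h_j) for h_j in H])  # alignment scores e_{ij}
alpha = np.exp(e - e.max()); alpha /= alpha.sum()                   # softmax -> attention weights
c_i = alpha @ H                                                     # context: weighted sum of annotations
```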

Interpretation:

  • Attention learns soft alignments between source and target tokens.

  • The decoder directly looks back at all encoder states, solving the fixed bottleneck problem.

  • The attention weights $\alpha_{ij}$ provide interpretable alignment maps.

General attention mechanism: queries, keys, values

A general way to view attention:

  • We are given a set of values $v_i$ indexed by keys $k_i$.

  • We have a query $q$.

  • Attention returns a weighted sum of the values, where weights depend on how well the keys match the query.

Analogy: a hashtable or key–value store:

  • Keys $k_i$ index values $v_i$,

  • Query $q$ asks “which values are relevant now?”.

Mathematically:

  1. Compute a score between the query and each key, e.g. using a similarity function $K(q, k_i)$:

    • Cosine similarity:

      $$K(q, k_i) = \frac{q \cdot k_i}{\lVert q \rVert\,\lVert k_i \rVert}.$$

  2. Convert scores into a probability distribution via softmax:

    $$\alpha_i = \frac{\exp(\beta K(q, k_i))}{\sum_j \exp(\beta K(q, k_j))},$$

    where $\beta$ is a scaling parameter.

  3. Compute the attention output as a weighted sum:

    $$\text{Attn}(q; K, V) = \sum_i \alpha_i v_i.$$
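
A short NumPy sketch of this generic attention step, with cosine similarity as $K$ and a temperature $\beta$ chosen only for illustration:

```python
import numpy as np

def attend(q, K, V, beta=5.0):
    """Weighted sum of values V, weighted by a softmax over cosine similarities between q and the keys K."""
    sims = (K @ q) / (np.linalg.norm(K, axis=1) * np.linalg.norm(q))   # K(q, k_i) for each key
    scores = beta * sims
    alpha = np.exp(scores - scores.max())
    alpha /= alpha.sum()                                               # softmax over keys
    return alpha @ V                                                   # fixed-size summary of the values

rng = np.random.default_rng(0)
K, V, q = rng.standard_normal((7, 4)), rng.standard_normal((7, 3)), rng.standard_normal(4)
out = attend(q, K, V)          # shape (3,): independent of the number of key/value pairs
```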

Properties:

  • Produces a fixed-size representation regardless of the number of values.

  • The output is a selective summary of the values, determined by the query.

  • In neural models, queries, keys, and values are learned vectors.

Self-attention

In self-attention, queries, keys, and values all come from the same sequence.

Example:

  • Input sequence of token embeddings: $x_1,\dots,x_T$.

  • For each position $t$, we compute:

    $$q_t = W^Q x_t, \qquad k_t = W^K x_t, \qquad v_t = W^V x_t,$$

    where $W^Q, W^K, W^V$ are learned matrices.

Intuition:

  • Each position in the sequence attends to other positions to gather relevant information.

  • Self-attention can capture dependencies between tokens regardless of distance (short or long).

Self-attention was first used inside RNN architectures (e.g. adding a memory tape), but in transformers it becomes the core building block without recurrence.

Vectorized self-attention and scaled dot-product attention

Given a sequence of input vectors stacked as rows in a matrix $X \in \mathbb{R}^{T \times d}$:

  1. Compute queries, keys, and values:

    $$Q = X W^Q, \quad K = X W^K, \quad V = X W^V,$$

    where $W^Q, W^K, W^V \in \mathbb{R}^{d \times d_k}$ (or similar).

  2. Compute unnormalized attention scores via dot products:

    $$E = Q K^\top \in \mathbb{R}^{T \times T},$$

    where $E_{ij}$ is the score of token $i$ attending to token $j$.

  3. Apply softmax row-wise to get attention weights:

    $$A = \text{softmax}(E),$$

    so each row of $A$ sums to 1.

  4. Compute the attention output:

    $$\text{Output} = A V.$$

This is often written compactly as:

$$\text{Attention}(Q, K, V) = \text{softmax}(QK^\top)\,V.$$

Scaled dot-product attention

For large $d_k$, dot products can have large variance, making the softmax too peaked or unstable.

To stabilize, divide by $\sqrt{d_k}$:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V.$$

This is the scaled dot-product attention used in transformers.
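
A compact NumPy sketch of scaled dot-product self-attention over a toy sequence, following the four steps above:

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V, with a numerically stable row-wise softmax."""
    d_k = Q.shape[-1]
    E = Q @ K.T / np.sqrt(d_k)                        # (T, T) attention scores
    E -= E.max(axis=-1, keepdims=True)                # stability shift; does not change the softmax
    A = np.exp(E)
    A /= A.sum(axis=-1, keepdims=True)                # each row sums to 1
    return A @ V                                      # (T, d_v) outputs

rng = np.random.default_rng(0)
T, d, d_k = 6, 16, 8
X = rng.standard_normal((T, d))                       # token representations as rows
W_Q, W_K, W_V = (rng.standard_normal((d, d_k)) for _ in range(3))
out = scaled_dot_product_attention(X @ W_Q, X @ W_K, X @ W_V)   # shape (6, 8)
```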

Attention plus feed-forward layers

Self-attention alone is a linear operation with respect to the values $V$ (no element-wise non-linearities inside).

To enhance expressiveness, transformers add a position-wise feed-forward network after attention:

For each position $i$:

  1. Take the attention output vector $\text{output}_i$.

  2. Apply a small MLP (often two linear layers with a nonlinearity in between):

    $$m_i = \text{FFN}(\text{output}_i) = W_2\,\sigma(W_1\,\text{output}_i + b_1) + b_2,$$

    where $\sigma$ is typically ReLU or GELU.

This feed-forward network operates independently at each position, but with shared parameters across positions.
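
A minimal sketch of the position-wise feed-forward network, applied to every position with shared weights (the 4× hidden expansion is a common but illustrative choice):

```python
import numpy as np

def position_wise_ffn(X, W1, b1, W2, b2):
    """Apply the same two-layer MLP (with ReLU) independently at every position."""
    H = np.maximum(X @ W1 + b1, 0.0)      # (T, d_ff) hidden activations
    return H @ W2 + b2                    # (T, d_model) back to the model dimension

rng = np.random.default_rng(0)
T, d_model, d_ff = 6, 16, 64              # d_ff = 4 * d_model is a common (illustrative) choice
X = rng.standard_normal((T, d_model))     # e.g. the attention outputs, one row per position
W1, b1 = rng.standard_normal((d_model, d_ff)), np.zeros(d_ff)
W2, b2 = rng.standard_normal((d_ff, d_model)), np.zeros(d_model)
M = position_wise_ffn(X, W1, b1, W2, b2)  # same parameters reused across all positions
```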

Thus a transformer layer combines:

  • Multi-head self-attention for contextual mixing across positions,

  • Position-wise feed-forward networks for nonlinear transformations at each position.

Residual connections and layer normalization

To train deep transformer stacks effectively, two key techniques are used:

Residual connections

Instead of learning a mapping $H(x)$ directly, layers learn a residual function $F(x)$ and add the input back:

$$x^{(l+1)} = x^{(l)} + F(x^{(l)}).$$

In transformer layers, residual connections wrap both the attention and the feed-forward sublayers:

  • $x \to \text{SelfAttention}(x) \to$ add $x$,

  • then $x' \to \text{FFN}(x') \to$ add $x'$.

Residual connections help:

  • Maintain information as it flows through many layers,

  • Improve gradient flow during backpropagation.

Layer normalization

Layer normalization normalizes the activations across the features of a layer for each example:

  • For a vector $z$ (e.g. the features at a given position), compute the mean and variance:

    $$\mu = \frac{1}{d} \sum_{k=1}^d z_k, \qquad \sigma^2 = \frac{1}{d} \sum_{k=1}^d (z_k - \mu)^2.$$

  • Normalize and rescale:

    $$\text{LayerNorm}(z)_k = \gamma_k\, \frac{z_k - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta_k,$$

    with learnable parameters $\gamma_k$ and $\beta_k$.

Layer normalization:

  • Stabilizes training by reducing internal covariate shift,

  • Replaces batch normalization in transformer-style models (works well with variable-length sequences and small batches).
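
A small NumPy sketch combining both ideas: layer normalization as defined above, used in the post-norm residual pattern $\text{LayerNorm}(x + \text{Sublayer}(x))$ that appears in the encoder and decoder layers below (a simple linear map stands in for the sublayer):

```python
import numpy as np

def layer_norm(z, gamma, beta, eps=1e-5):
    """Normalize over the feature dimension of each position, then rescale with gamma, beta."""
    mu = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    return gamma * (z - mu) / np.sqrt(var + eps) + beta

def residual_sublayer(x, sublayer, gamma, beta):
    """Post-norm residual wrapper: LayerNorm(x + Sublayer(x))."""
    return layer_norm(x + sublayer(x), gamma, beta)

rng = np.random.default_rng(0)
T, d = 6, 16
x = rng.standard_normal((T, d))
gamma, beta = np.ones(d), np.zeros(d)
W = 0.1 * rng.standard_normal((d, d))
y = residual_sublayer(x, lambda z: z @ W, gamma, beta)   # the linear map stands in for attention or FFN
```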

Positional encodings

Self-attention treats the input as a set: the computation

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

does not depend on the order of positions; it is permutation-invariant with respect to the values.

However, for language and many other sequences, order matters. To introduce order information, we add a positional encoding to each token embedding.

Let:

  • $x_t$ be the embedding of the token at position $t$,

  • $p_t$ be its positional encoding.

We define:

$$\tilde{x}_t = x_t + p_t.$$

Then $\tilde{x}_t$ is used as input to the transformer (for queries, keys, and values).

Sinusoidal positional encodings

One popular choice uses fixed sinusoidal functions:

For model dimension $d_\text{model}$:

  • For even indices $2i$:

    $$\text{PE}(\text{pos}, 2i) = \sin\!\left(\frac{\text{pos}}{10000^{2i/d_\text{model}}}\right),$$

  • For odd indices $2i+1$:

    $$\text{PE}(\text{pos}, 2i+1) = \cos\!\left(\frac{\text{pos}}{10000^{2i/d_\text{model}}}\right).$$

Properties:

  • Different frequencies encode different granularities of position.

  • The representation is periodic in a controlled way, which can help extrapolate to longer sequences.

  • These encodings are fixed (not learned), though learned positional embeddings are also common.
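
A short NumPy sketch of the sinusoidal encoding table, one row per position (maximum length and model dimension are illustrative):

```python
import numpy as np

def sinusoidal_positional_encoding(max_len, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d))."""
    pos = np.arange(max_len)[:, None]                  # (max_len, 1) positions
    i = np.arange(0, d_model, 2)[None, :]              # even feature indices 2i
    angles = pos / np.power(10000.0, i / d_model)      # (max_len, d_model // 2)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)                       # even dimensions
    pe[:, 1::2] = np.cos(angles)                       # odd dimensions
    return pe

PE = sinusoidal_positional_encoding(max_len=50, d_model=16)
# x_tilde = token_embeddings + PE[:T]   # added to the embeddings before the first layer
```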

Multi-head attention

Single-head attention allows each position to attend to others using a single similarity pattern.

However, we may want to focus on different aspects of the input simultaneously (e.g. syntax vs semantics, local vs global context).

Multi-head attention uses multiple attention heads in parallel:

For $h$ heads:

  • For head $i$:

    $$\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V),$$

    with learned projections $W_i^Q, W_i^K, W_i^V$.

  • Concatenate all heads:

    $$\text{Concat}(\text{head}_1,\dots,\text{head}_h),$$

  • Apply a final linear projection:

    $$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)\, W^O.$$
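
A minimal NumPy sketch of multi-head attention built from the scaled dot-product attention defined earlier (the head count and dimensions are illustrative):

```python
import numpy as np

def attention(Q, K, V):
    """Scaled dot-product attention, as in the previous section."""
    E = Q @ K.T / np.sqrt(Q.shape[-1])
    A = np.exp(E - E.max(axis=-1, keepdims=True))
    A /= A.sum(axis=-1, keepdims=True)
    return A @ V

def multi_head_attention(Q, K, V, head_params, W_O):
    """Run each head with its own projections, concatenate, then apply the output projection W_O."""
    outs = [attention(Q @ p["W_Q"], K @ p["W_K"], V @ p["W_V"]) for p in head_params]
    return np.concatenate(outs, axis=-1) @ W_O

rng = np.random.default_rng(0)
T, d_model, n_heads = 6, 16, 4
d_k = d_model // n_heads
X = rng.standard_normal((T, d_model))
head_params = [dict(W_Q=rng.standard_normal((d_model, d_k)),
                    W_K=rng.standard_normal((d_model, d_k)),
                    W_V=rng.standard_normal((d_model, d_k))) for _ in range(n_heads)]
W_O = rng.standard_normal((n_heads * d_k, d_model))       # final output projection
out = multi_head_attention(X, X, X, head_params, W_O)     # self-attention: Q = K = V = X
```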

Benefits:

  • Each head can capture different types of relationships:

    • Short-range vs long-range,

    • Different dependency types,

    • Different subspaces of representation.

  • Overall, it increases model capacity without dramatically increasing depth.

Transformer encoder architecture

The transformer encoder is a stack of identical layers, each consisting of:

  1. Multi-head self-attention.

  2. Residual connection and layer normalization.

  3. Position-wise feed-forward network.

  4. Another residual connection and layer normalization.

For an encoder layer, with input $X$ (a sequence of vectors):

  1. Self-attention sublayer:

    $$\tilde{X} = \text{LayerNorm}\big(X + \text{MultiHeadSelfAttn}(X, X, X)\big).$$

  2. Feed-forward sublayer:

    $$Y = \text{LayerNorm}\big(\tilde{X} + \text{FFN}(\tilde{X})\big).$$

The encoder input is:

  • Token embeddings plus positional encodings.

Stacking several such layers yields deep contextual representations $H$ for the input sequence, to be used by decoders or other heads.
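
Putting the pieces together, a compact sketch of one encoder layer in the post-norm form above; for brevity it uses a single attention head, omits biases and learnable LayerNorm parameters, and repeats the helper definitions:

```python
import numpy as np

def layer_norm(z, eps=1e-5):
    """LayerNorm without learnable scale/shift, for brevity."""
    return (z - z.mean(-1, keepdims=True)) / np.sqrt(z.var(-1, keepdims=True) + eps)

def self_attention(X, W_Q, W_K, W_V):
    """Single-head scaled dot-product self-attention."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    E = Q @ K.T / np.sqrt(Q.shape[-1])
    A = np.exp(E - E.max(-1, keepdims=True)); A /= A.sum(-1, keepdims=True)
    return A @ V

def encoder_layer(X, p):
    """Post-norm encoder layer: attention sublayer, then feed-forward sublayer."""
    X_tilde = layer_norm(X + self_attention(X, p["W_Q"], p["W_K"], p["W_V"]))
    ffn = np.maximum(X_tilde @ p["W_1"], 0.0) @ p["W_2"]       # position-wise FFN (ReLU)
    return layer_norm(X_tilde + ffn)

rng = np.random.default_rng(0)
T, d, d_ff = 6, 16, 64
p = {k: 0.1 * rng.standard_normal(s) for k, s in
     [("W_Q", (d, d)), ("W_K", (d, d)), ("W_V", (d, d)), ("W_1", (d, d_ff)), ("W_2", (d_ff, d))]}
H = encoder_layer(rng.standard_normal((T, d)), p)   # stack several such layers for a deep encoder
```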

Transformer decoder architecture

The decoder also consists of stacked layers, each with three main sublayers:

  1. Masked multi-head self-attention (over the decoder inputs).

  2. Multi-head cross-attention (encoder–decoder attention).

  3. Position-wise feed-forward network.

Each sublayer is wrapped in residual connections and layer normalization.

Let $Z$ be the decoder input representations (shifted target embeddings plus positional encodings), and $H$ the encoder outputs.

For a decoder layer:

  1. Masked self-attention (causal masking):

    • The decoder at position $t$ should not attend to positions $> t$ (future tokens).

    • Implemented by masking out scores in $QK^\top$ before the softmax.

    • Sub-layer:

      $$\tilde{Z}_1 = \text{LayerNorm}\big(Z + \text{MaskedMultiHeadSelfAttn}(Z, Z, Z)\big).$$
  2. Encoder–decoder (cross) attention:

    • Queries come from the decoder ($\tilde{Z}_1$),

    • Keys and values come from the encoder outputs $H$:

      $$\tilde{Z}_2 = \text{LayerNorm}\big(\tilde{Z}_1 + \text{MultiHeadAttn}(\tilde{Z}_1, H, H)\big).$$

  3. Feed-forward sublayer:

    $$Y = \text{LayerNorm}\big(\tilde{Z}_2 + \text{FFN}(\tilde{Z}_2)\big).$$

Finally, a linear layer followed by softmax projects decoder outputs to vocabulary logits for next-token prediction.
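
A small NumPy sketch of the causal (look-ahead) mask used in the masked self-attention sublayer: scores for positions $> t$ are set to $-\infty$ before the softmax (dimensions are illustrative):

```python
import numpy as np

def masked_self_attention_weights(Q, K):
    """Attention weights with a causal mask: position t cannot attend to positions > t."""
    T = Q.shape[0]
    E = Q @ K.T / np.sqrt(Q.shape[-1])                  # raw scores, shape (T, T)
    mask = np.triu(np.ones((T, T), dtype=bool), k=1)    # True above the diagonal = future positions
    E = np.where(mask, -np.inf, E)                      # block attention to future tokens
    A = np.exp(E - E.max(axis=-1, keepdims=True))
    return A / A.sum(axis=-1, keepdims=True)            # lower-triangular weight matrix

rng = np.random.default_rng(0)
Z = rng.standard_normal((5, 8))
A = masked_self_attention_weights(Z, Z)
print(np.round(A, 2))    # strictly zero above the diagonal
```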

Key idea:

  • The decoder uses self-attention to model dependencies within the target sequence,

  • Cross-attention to condition on the entire encoded source, solving the bottleneck and vanishing gradient issues inherent in pure RNN-based seq2seq models.

Transformer design goals and complexity

The transformer architecture was designed with three main goals:

  • Low per-layer computational complexity (compared to RNNs).

  • Short path length between any pair of positions (facilitating long-range dependencies).

  • High parallelizability (important for GPU/TPU acceleration).

Rough comparisons (for sequence length $n$, model dimension $d$, convolution kernel size $k$):

  • Self-attention:

    • Complexity per layer: $O(n^2 d)$ (due to $QK^\top$),

    • Sequential operations: $O(1)$,

    • Maximum path length: $O(1)$ (any position can attend to any other in one step).

  • Recurrent layers:

    • Complexity per layer: $O(n d^2)$,

    • Sequential operations: $O(n)$ (cannot parallelize across time),

    • Maximum path length: $O(n)$.

  • Convolutional layers:

    • Complexity per layer: $O(k n d^2)$,

    • Sequential operations: $O(1)$,

    • Maximum path length: $O(\log_k n)$ (stacked convolutions expand the receptive field).

Conclusion:

  • Self-attention trades $O(n^2)$ complexity for constant path length and high parallelism.

  • For many tasks with moderate sequence lengths and sufficient compute, this trade-off is extremely favorable, enabling large-scale pretraining and very deep models.

Summary

  • RNNs process sequences with hidden state but struggle with long-term dependencies due to vanishing/exploding gradients.

  • LSTMs and GRUs introduce gates and additive memory paths to mitigate these issues.

  • Seq2seq models with encoder–decoder architectures can perform neural machine translation, but early models suffered from a fixed-size bottleneck.

  • Attention mechanisms let models compute context-dependent weighted sums over representations, solving the bottleneck and improving performance and interpretability.

  • Self-attention extends attention to interactions within a single sequence and is the core component of transformers.

  • Transformers rely on:

    • Multi-head self-attention to model rich dependencies,

    • Residual connections and layer normalization for deep, stable training,

    • Scaled dot-product attention for numerical stability,

    • Positional encodings to inject order information.

  • The transformer encoder and decoder architectures replace recurrence with stacks of attention and feed-forward layers, enabling:

    • Highly parallel computation,

    • Short paths between tokens,

    • Efficient modeling of long-range dependencies.

These ideas underpin modern large language models and many attention-based architectures in vision, speech, and beyond.