Part VI of VI

Advanced Topics and Current Frontiers

The Transformer architecture in full computational detail, recurrent networks, theoretical foundations of batch normalization and residual learning, and the frontiers of efficient deep learning

Contents

§22 The Transformer Architecture · §23 Recurrent Neural Networks and LSTMs · §24 Theoretical Foundations · §25 Current Frontiers in Efficient Deep Learning · § Complete References

§22 The Transformer Architecture

The Transformer (Vaswani et al., 2017) has become the dominant architecture in NLP and is increasingly displacing CNNs in vision. Its core innovation — self-attention — allows every position in a sequence to attend to every other position, replacing the sequential computation of RNNs with fully parallel computation. We derive every operation from first principles with complete cost analysis.

22.1 Self-Attention — Full Derivation

Scaled Dot-Product Attention

Given an input sequence $\bm{X} \in \R^{n \times d}$ where $n$ is the sequence length and $d$ is the embedding dimension:

Step 1 — Compute queries, keys, and values:

$$\bm{Q} = \bm{X}\bm{W}_Q, \quad \bm{K} = \bm{X}\bm{W}_K, \quad \bm{V} = \bm{X}\bm{W}_V$$

where $\bm{W}_Q, \bm{W}_K \in \R^{d \times d_k}$ and $\bm{W}_V \in \R^{d \times d_v}$.

Step 2 — Compute attention scores:

$$\bm{S} = \frac{\bm{Q}\bm{K}\T}{\sqrt{d_k}} \in \R^{n \times n}$$

Step 3 — Apply softmax (row-wise):

$$\bm{A} = \softmax(\bm{S}) \in \R^{n \times n}$$

Step 4 — Weighted sum of values:

$$\text{Attention}(\bm{Q}, \bm{K}, \bm{V}) = \bm{A}\bm{V} \in \R^{n \times d_v}$$
[Figure: scaled dot-product attention data flow — $\bm{X}$ $(n \times d)$ → $\bm{Q}, \bm{K}$ $(n \times d_k)$, $\bm{V}$ $(n \times d_v)$ → $\bm{Q}\bm{K}\T/\sqrt{d_k}$ $(n \times n)$ → softmax → $\bm{A}$ $(n \times n)$ → $\bm{A}\bm{V}$ $(n \times d_v)$ → output. Per-head cost: Q/K/V projections $3 \times 2nd \cdot d_k$; $\bm{Q}\bm{K}\T$ $2n^2 d_k$; softmax $\approx 5n^2$; $\bm{A}\bm{V}$ $2n^2 d_v$; output projection $2nd \cdot d_v$. Total: $O(n^2d)$ FLOPs plus $O(n^2)$ memory for the attention matrix.]
Figure 22.1. Scaled dot-product attention data flow. The input $\bm{X}$ is projected into queries, keys, and values. The attention matrix $\bm{A}$ is $n \times n$, creating the quadratic cost in sequence length. The attention-weighted values produce the output.
Self-Attention — Complete FLOP Breakdown (Single Head)

| Step | Operation | Shape | FLOPs |
|---|---|---|---|
| 1a. Q projection | $\bm{X}\bm{W}_Q$ | $(n \times d) \times (d \times d_k)$ | $2nd \cdot d_k$ |
| 1b. K projection | $\bm{X}\bm{W}_K$ | $(n \times d) \times (d \times d_k)$ | $2nd \cdot d_k$ |
| 1c. V projection | $\bm{X}\bm{W}_V$ | $(n \times d) \times (d \times d_v)$ | $2nd \cdot d_v$ |
| 2. Attention scores | $\bm{Q}\bm{K}\T$ | $(n \times d_k) \times (d_k \times n)$ | $2n^2 d_k$ |
| 2b. Scale by $1/\sqrt{d_k}$ | Element-wise | $n \times n$ | $n^2$ |
| 3. Softmax | Row-wise softmax | $n \times n$ | $\approx 5n^2$ |
| 4. Weighted values | $\bm{A}\bm{V}$ | $(n \times n) \times (n \times d_v)$ | $2n^2 d_v$ |
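The four steps map directly onto a few lines of code. Below is a minimal NumPy sketch (illustrative, not an optimized implementation); names and shapes follow the notation above:

```python
import numpy as np

def softmax(S, axis=-1):
    S = S - S.max(axis=axis, keepdims=True)   # subtract row max for numerical stability
    E = np.exp(S)
    return E / E.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(X, W_q, W_k, W_v):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v       # Step 1: projections
    S = (Q @ K.T) / np.sqrt(W_q.shape[1])     # Step 2: (n, n) scores, scaled by 1/sqrt(d_k)
    A = softmax(S)                            # Step 3: row-wise softmax
    return A @ V                              # Step 4: weighted sum of values, (n, d_v)

# Shape check with small dimensions:
n, d, d_k, d_v = 8, 16, 4, 4
rng = np.random.default_rng(0)
out = scaled_dot_product_attention(
    rng.normal(size=(n, d)),
    rng.normal(size=(d, d_k)), rng.normal(size=(d, d_k)), rng.normal(size=(d, d_v)))
assert out.shape == (n, d_v)
```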

22.2 Multi-Head Attention (MHA)

Multi-Head Attention

Split the $d$-dimensional representation into $h$ heads, each with dimension $d_k = d_v = d/h$:

$$\text{MHA}(\bm{X}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\bm{W}_O$$ $$\text{head}_i = \text{Attention}(\bm{X}\bm{W}_{Q_i}, \bm{X}\bm{W}_{K_i}, \bm{X}\bm{W}_{V_i})$$

where $\bm{W}_{Q_i}, \bm{W}_{K_i} \in \R^{d \times d_k}$, $\bm{W}_{V_i} \in \R^{d \times d_v}$, $\bm{W}_O \in \R^{d \times d}$.
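As a concrete reference, here is a per-head NumPy sketch of MHA; the head loop is kept explicit for clarity, whereas production code fuses all heads into batched matmuls:

```python
import numpy as np

def multi_head_attention(X, W_q, W_k, W_v, W_o, h):
    # W_q, W_k, W_v, W_o are (d, d); columns i*d_h:(i+1)*d_h belong to head i.
    n, d = X.shape
    d_h = d // h                                     # d_k = d_v = d / h
    Q, K, V = X @ W_q, X @ W_k, X @ W_v              # all heads projected at once
    heads = []
    for i in range(h):
        cols = slice(i * d_h, (i + 1) * d_h)
        S = (Q[:, cols] @ K[:, cols].T) / np.sqrt(d_h)   # (n, n) scores for head i
        E = np.exp(S - S.max(axis=-1, keepdims=True))
        A = E / E.sum(axis=-1, keepdims=True)            # row-wise softmax
        heads.append(A @ V[:, cols])                     # (n, d_h)
    return np.concatenate(heads, axis=-1) @ W_o          # concat, then output projection
```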

Multi-Head Attention — Complete Cost (per sample)

| Component | Parameters | FLOPs (per sample) |
|---|---|---|
| Q, K, V projections (all heads combined) | $3d^2$ | $3 \times 2nd^2 = 6nd^2$ |
| Attention scores (all heads) | 0 | $h \times 2n^2(d/h) = 2n^2 d$ |
| Softmax (all heads) | 0 | $h \times 5n^2 = 5hn^2$ |
| Value weighting (all heads) | 0 | $h \times 2n^2(d/h) = 2n^2 d$ |
| Output projection $\bm{W}_O$ | $d^2$ | $2nd^2$ |
| Total MHA | $\mathbf{4d^2}$ | $\mathbf{8nd^2 + 4n^2 d + 5hn^2}$ |

The dominant terms are the linear projections ($8nd^2$) and the attention computation ($4n^2d$). For short sequences ($n \ll d$), the projections dominate. For long sequences ($n \gg d$), the quadratic attention term dominates.

Crossover point: $8nd^2 = 4n^2d$ → $n = 2d$. For $d = 768$ (BERT-base), the crossover is at $n = 1{,}536$ tokens.

Attention Memory — The $n^2$ Problem

The attention matrix $\bm{A} \in \R^{B \times h \times n \times n}$ must be stored for backpropagation.

$$\text{Attention matrix memory} = B \times h \times n^2 \times \text{bytes\_per\_element}$$

For GPT-3 scale: $B=1$, $h=96$, $n=2048$, FP16: $1 \times 96 \times 2048^2 \times 2 \text{ bytes} = 805$ MB — just for one layer's attention matrix. With 96 layers, that is $\approx 77$ GB for attention matrices alone.
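The number falls straight out of the formula; a short sketch reproduces it:

```python
def attn_matrix_bytes(B, h, n, bytes_per_el=2):          # FP16 -> 2 bytes per element
    return B * h * n * n * bytes_per_el

per_layer = attn_matrix_bytes(B=1, h=96, n=2048)         # GPT-3 scale
print(per_layer / 1e6, "MB per layer;", 96 * per_layer / 1e9, "GB for 96 layers")
# -> ~805 MB per layer; ~77 GB (decimal) across 96 layers
```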

22.3 Position-Wise Feed-Forward Network (FFN)

FFN Block
$$\text{FFN}(\bm{x}) = \text{GELU}(\bm{x}\bm{W}_1 + \bm{b}_1)\bm{W}_2 + \bm{b}_2$$

where $\bm{W}_1 \in \R^{d \times d_{\text{ff}}}$, $\bm{W}_2 \in \R^{d_{\text{ff}} \times d}$, and typically $d_{\text{ff}} = 4d$.

FFN — Computational Cost

| Component | Parameters | FLOPs (per sample) |
|---|---|---|
| $\bm{x}\bm{W}_1 + \bm{b}_1$ (expand) | $d \cdot d_{\text{ff}} + d_{\text{ff}}$ | $2n \cdot d \cdot d_{\text{ff}}$ |
| GELU activation | 0 | $\approx 14 \cdot n \cdot d_{\text{ff}}$ |
| $\bm{h}\bm{W}_2 + \bm{b}_2$ (project back) | $d_{\text{ff}} \cdot d + d$ | $2n \cdot d_{\text{ff}} \cdot d$ |
| Total FFN | $\mathbf{2d \cdot d_{\text{ff}} + d_{\text{ff}} + d}$ | $\approx \mathbf{4nd \cdot d_{\text{ff}}}$ |

With $d_{\text{ff}} = 4d$: Parameters = $8d^2 + 5d \approx 8d^2$. FLOPs = $16nd^2$.

The FFN has 2× the parameters and 2× the FLOPs of MHA (at $8d^2$ vs $4d^2$ params, $16nd^2$ vs $8nd^2$ FLOPs for the linear parts). The FFN is typically the larger component of a Transformer block.
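A minimal sketch of the FFN block, using the common tanh approximation of GELU:

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def ffn(X, W1, b1, W2, b2):
    H = gelu(X @ W1 + b1)      # expand: (n, d) -> (n, d_ff)
    return H @ W2 + b2         # project back: (n, d_ff) -> (n, d)
```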

22.4 Positional Encoding

Sinusoidal Positional Encoding (Original Transformer)
$$\text{PE}(pos, 2i) = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \quad \text{PE}(pos, 2i+1) = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)$$

FLOPs: $nd$ (computed once, cached). Parameters: 0 (deterministic).

Learned positional embeddings (BERT, GPT): lookup table of $n_{\max} \times d$ parameters. FLOPs: 0 (just index into table). Memory: $n_{\max} \times d$ floats.

Rotary Position Embeddings (RoPE) (Su et al., 2021): applied multiplicatively to Q and K in each attention head. Cost: $\approx 6nd$ FLOPs per layer (small relative to attention).
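A sketch of the sinusoidal table, computed once and cached as noted above (assumes $d$ is even):

```python
import numpy as np

def sinusoidal_pe(n, d):
    pos = np.arange(n)[:, None]            # (n, 1) positions
    two_i = np.arange(0, d, 2)[None, :]    # even indices 2i = 0, 2, ..., d-2
    angles = pos / (10000.0 ** (two_i / d))
    pe = np.zeros((n, d))
    pe[:, 0::2] = np.sin(angles)           # PE(pos, 2i)
    pe[:, 1::2] = np.cos(angles)           # PE(pos, 2i+1)
    return pe
```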

22.5 Complete Transformer Block — Total Cost

One Transformer Block (Pre-Norm Variant)

LayerNorm → MHA → Residual → LayerNorm → FFN → Residual

| Component | Parameters | FLOPs (per sample) |
|---|---|---|
| LayerNorm 1 | $2d$ | $\approx 7nd$ |
| Multi-Head Attention | $4d^2$ | $8nd^2 + 4n^2d$ |
| Residual add | 0 | $nd$ |
| LayerNorm 2 | $2d$ | $\approx 7nd$ |
| FFN (with $d_{\text{ff}} = 4d$) | $8d^2 + 5d$ | $\approx 16nd^2$ |
| Residual add | 0 | $nd$ |
| Total per block | $\approx \mathbf{12d^2}$ | $\approx \mathbf{24nd^2 + 4n^2d}$ |

For an $L$-layer Transformer:

$$\text{Total params} \approx 12Ld^2 + Vd \quad \text{(the $Vd$ term is the token-embedding table, $V$ = vocabulary size)}$$ $$\text{Total FLOPs per sample} \approx L(24nd^2 + 4n^2d)$$
Worked Example — GPT-2 Scale Models

| Model | $L$ | $d$ | $h$ | $d_{\text{ff}}$ | Params | FLOPs/token |
|---|---|---|---|---|---|---|
| GPT-2 Small | 12 | 768 | 12 | 3072 | 117M | ~235M |
| GPT-2 Medium | 24 | 1024 | 16 | 4096 | 345M | ~690M |
| GPT-2 Large | 36 | 1280 | 20 | 5120 | 774M | ~1.55G |
| GPT-2 XL | 48 | 1600 | 25 | 6400 | 1.56B | ~3.1G |

Approximate formula: FLOPs per token $\approx 2P$ (where $P$ = parameter count). This holds because a forward pass through a dense layer with $P$ parameters costs $\approx 2P$ FLOPs — one multiply and one add per weight. This "$2P$ per token" rule is widely used for quick estimates of Transformer training costs.
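These formulas are easy to sanity-check in code. The sketch below plugs the block-level expressions into the GPT-2 Small configuration (counts are approximate: biases, LayerNorms, and the softmax term are ignored, and the 117M in the table follows the original GPT-2 report, slightly below the formula's estimate):

```python
def transformer_params(L, d, V=50257):       # V = GPT-2 vocabulary size
    return 12 * L * d**2 + V * d             # blocks + token-embedding table

def flops_per_token(L, d, n):
    return L * (24 * d**2 + 4 * n * d)       # L(24nd^2 + 4n^2 d) / n

P = transformer_params(L=12, d=768)          # GPT-2 Small
print(f"params ~ {P/1e6:.0f}M")              # ~124M (the report quotes 117M)
print(f"2P rule: {2*P/1e6:.0f}M FLOPs/token; block formula at n=1024: "
      f"{flops_per_token(12, 768, 1024)/1e6:.0f}M FLOPs/token")
```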

Transformer Training Cost — The $6PD$ Formula

For training a Transformer with $P$ parameters on $D$ tokens:

$$C \approx 6PD \text{ FLOPs}$$

This comes from the forward-pass cost of $\approx 2P$ FLOPs/token, multiplied by 3 because the backward pass costs roughly twice the forward pass (gradients are computed with respect to both activations and weights):

$$C = 3 \times 2P \times D = 6PD$$

For GPT-3 (175B params, 300B tokens): $C \approx 6 \times 175 \times 10^9 \times 300 \times 10^9 = 3.15 \times 10^{23}$ FLOPs ≈ 315 zettaFLOPs.

22.6 The Quadratic Attention Problem

The $O(n^2)$ Wall

Self-attention's cost has two components: $O(nd^2)$ for projections and $O(n^2d)$ for the attention matrix. The quadratic term creates severe scaling issues:

| Sequence Length $n$ | Attention FLOPs ($4n^2d$, $d = 1024$) | Attention Memory ($hn^2$, $h = 16$, FP16) |
|---|---|---|
| 512 | 1.07G | 8 MB |
| 2,048 | 17.2G | 128 MB |
| 8,192 | 275G | 2 GB |
| 32,768 | 4.4T | 32 GB |
| 131,072 | 70T | 512 GB |

Going from 512 to 131K tokens: 65,536× more attention FLOPs and memory. This is why long-context modeling is the central computational challenge of modern Transformer research.

22.7 Efficient Attention Variants

| Method | Complexity | Key Idea | Trade-off |
|---|---|---|---|
| FlashAttention (Dao et al., 2022) | $O(n^2d)$ FLOPs, $O(n)$ memory | Tiling to avoid materializing the $n \times n$ matrix; fused kernel | Exact (no approximation), 2–4× speedup |
| Multi-Query Attention (Shazeer, 2019) | KV cost reduced by $h\times$ | Share K, V across all heads; only Q differs | Minor accuracy loss; major KV cache reduction |
| Grouped-Query Attention (Ainslie et al., 2023) | Between MHA and MQA | $G$ groups of heads share KV ($1 < G < h$) | Interpolates MHA/MQA quality-efficiency |
| Linear Attention (Katharopoulos et al., 2020) | $O(nd^2)$ | Replace softmax with a kernel trick: $\phi(\bm{Q})\bigl(\phi(\bm{K})\T\bm{V}\bigr)$ | Avoids $n^2$, but quality gap |
| Sparse Attention (Child et al., 2019) | $O(n\sqrt{n})$ | Attend only to fixed sparse patterns | Misses some long-range interactions |
| Sliding Window (Beltagy et al., 2020) | $O(nw)$, $w$ = window size | Local attention within window $w$ | Limited global context |
FlashAttention — Why It Matters

FlashAttention (Dao et al., 2022) does not reduce the FLOP count — it computes exact standard attention. Its innovation is IO-awareness: it tiles the computation to fit within GPU SRAM (fast on-chip memory), avoiding writing the full $n \times n$ attention matrix to HBM (slow off-chip memory). The results:

Memory: $O(n^2) \to O(n)$ (never materializes full attention matrix)

Speed: 2–4× faster (fewer HBM reads/writes)

Accuracy: Bit-identical to standard attention

FlashAttention is now the default attention implementation in most frameworks.
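The heart of the method is the *online softmax*: attention can be computed block-by-block over K and V while maintaining a running row-maximum, a running normalizer, and an output accumulator. A NumPy sketch of that recurrence follows — it shows the memory idea only (memory is $O(n \cdot \text{block})$ rather than $O(n^2)$); the real kernel also tiles over Q and runs the loop in SRAM:

```python
import numpy as np

def tiled_attention(Q, K, V, block=64):
    n, d_k = Q.shape
    m = np.full(n, -np.inf)                        # running row max
    l = np.zeros(n)                                # running softmax denominator
    O = np.zeros((n, V.shape[1]))                  # running weighted sum
    for j in range(0, K.shape[0], block):
        S = (Q @ K[j:j + block].T) / np.sqrt(d_k)  # scores for this K/V block only
        m_new = np.maximum(m, S.max(axis=1))
        alpha = np.exp(m - m_new)                  # rescale old statistics to the new max
        P = np.exp(S - m_new[:, None])
        l = alpha * l + P.sum(axis=1)
        O = alpha[:, None] * O + P @ V[j:j + block]
        m = m_new
    return O / l[:, None]

# Agrees with the direct O(n^2)-memory computation:
rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(128, 32)) for _ in range(3))
S = Q @ K.T / np.sqrt(32)
P = np.exp(S - S.max(axis=1, keepdims=True))
assert np.allclose(tiled_attention(Q, K, V), (P / P.sum(axis=1, keepdims=True)) @ V)
```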

22.8 KV Cache for Autoregressive Inference

KV Cache

In autoregressive generation (GPT-style), each new token attends to all previous tokens. Without caching, generating token $t$ requires recomputing the K and V projections for all $t$ positions — $O(td^2)$ for the projections plus $O(t^2d)$ for attention — so generating a full sequence of $n$ tokens costs $O(n^2d^2 + n^3d)$, cubic in sequence length.

With KV cache: store the K and V matrices from all previous tokens. Generating token $t$ requires only:

(1) Computing $\bm{q}_t, \bm{k}_t, \bm{v}_t$ for the new token: $O(d^2)$

(2) Appending $\bm{k}_t, \bm{v}_t$ to the cache: $O(d)$

(3) Computing attention scores against the cache, $\bm{q}_t\bm{K}_{\text{cache}}\T$, and the weighted sum over $\bm{V}_{\text{cache}}$: $O(td)$

Total per token: $O(d^2 + td)$ instead of $O(td^2 + t^2d)$
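A single-layer, single-head sketch of one cached decoding step (row-vector convention as in §22.1; initialize the caches as arrays of shape $(0, d_k)$ and $(0, d_v)$ and thread them through the generation loop):

```python
import numpy as np

def decode_step(x_t, W_q, W_k, W_v, K_cache, V_cache):
    q = x_t @ W_q                                  # (1) project the new token: O(d^2)
    K_cache = np.vstack([K_cache, x_t @ W_k])      # (2) append k_t, v_t to the cache
    V_cache = np.vstack([V_cache, x_t @ W_v])
    s = (K_cache @ q) / np.sqrt(q.shape[0])        # (3) scores vs. all t cached keys: O(td)
    a = np.exp(s - s.max())
    a /= a.sum()                                   # softmax over the t positions
    return a @ V_cache, K_cache, V_cache           # weighted values: O(td)
```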

KV Cache Memory
$$\text{KV cache} = 2 \times L \times n \times d \times \text{bytes\_per\_element}$$

The factor of 2 is for both K and V. Per layer, per token: $2d$ values.

| Model | $L$ | $d$ | KV cache per token (FP16) | KV cache for 2K tokens |
|---|---|---|---|---|
| GPT-2 Small | 12 | 768 | 36.9 KB | 72 MB |
| LLaMA 7B | 32 | 4096 | 524 KB | 1.0 GB |
| LLaMA 70B | 80 | 8192 | 2.6 MB | 5.1 GB |
| GPT-3 175B | 96 | 12288 | 4.7 MB | 9.2 GB |

(Figures assume standard MHA, following the formula above.)

For large models serving many users simultaneously, the KV cache can consume more GPU memory than the model weights themselves. This is the primary motivation for Multi-Query Attention and Grouped-Query Attention, which reduce KV cache by $h\times$ and $h/G\times$ respectively.

References for §22

[1] Vaswani, A. et al. (2017). Attention Is All You Need. NeurIPS 2017.

[2] Dao, T. et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.

[3] Shazeer, N. (2019). Fast Transformer Decoding: One Write-Head is All You Need. arXiv:1911.02150.

[4] Ainslie, J. et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023.

[5] Su, J. et al. (2021). RoFormer: Enhanced Transformer with Rotary Position Embedding. arXiv:2104.09864.


§23 Recurrent Neural Networks and LSTMs

23.1 Vanilla RNN

RNN Cell
$$\bm{h}_t = \tanh(\bm{W}_{hh}\bm{h}_{t-1} + \bm{W}_{xh}\bm{x}_t + \bm{b}_h)$$ $$\bm{y}_t = \bm{W}_{hy}\bm{h}_t + \bm{b}_y$$

where $\bm{h}_t \in \R^d$ is the hidden state, $\bm{x}_t \in \R^m$ is the input, $d$ is the hidden size.

RNN — Cost per Time Step

| Component | Parameters | FLOPs per step |
|---|---|---|
| $\bm{W}_{hh}\bm{h}_{t-1}$ | $d^2$ | $2d^2$ |
| $\bm{W}_{xh}\bm{x}_t$ | $md$ | $2md$ |
| Addition + tanh | 0 | $\approx 7d$ |
| Total (per step, excl. output) | $d^2 + md$ | $\approx 2d(d+m)$ |
| Full sequence ($n$ steps) | $d^2 + md$ (shared!) | $2nd(d+m)$ |

Key property: FLOPs scale as $O(nd^2)$ — linear in sequence length, unlike the Transformer's $O(n^2d)$ attention term. However, RNN steps are sequential — step $t$ depends on step $t-1$ — preventing parallelization.
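A sketch of the forward loop makes the serial dependence of $\bm{h}_t$ on $\bm{h}_{t-1}$ explicit:

```python
import numpy as np

def rnn_forward(X, W_hh, W_xh, b_h):
    # X: (n, m) inputs; W_hh: (d, d); W_xh: (d, m); b_h: (d,)
    n = X.shape[0]
    d = W_hh.shape[0]
    h = np.zeros(d)
    H = np.zeros((n, d))
    for t in range(n):                               # n serial steps -- no parallelism over t
        h = np.tanh(W_hh @ h + W_xh @ X[t] + b_h)    # 2d^2 + 2md FLOPs per step
        H[t] = h
    return H
```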

23.2 LSTM (Long Short-Term Memory)

LSTM Cell — Full Equations

Forget gate: $\bm{f}_t = \sigma(\bm{W}_f[\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_f)$

Input gate: $\bm{i}_t = \sigma(\bm{W}_i[\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_i)$

Candidate cell: $\tilde{\bm{c}}_t = \tanh(\bm{W}_c[\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_c)$

Cell update: $\bm{c}_t = \bm{f}_t \odot \bm{c}_{t-1} + \bm{i}_t \odot \tilde{\bm{c}}_t$

Output gate: $\bm{o}_t = \sigma(\bm{W}_o[\bm{h}_{t-1}, \bm{x}_t] + \bm{b}_o)$

Hidden state: $\bm{h}_t = \bm{o}_t \odot \tanh(\bm{c}_t)$

LSTM — Complete Cost per Time Step

All four gates share the same structure: matrix multiply $[\bm{h}_{t-1}, \bm{x}_t]$ (size $d+m$) by a weight matrix (size $(d+m) \times d$), plus bias, plus activation. In practice, these are combined into one large matrix multiply:

$$[\bm{f}_t; \bm{i}_t; \tilde{\bm{c}}_t; \bm{o}_t] = \bm{W}_{4d} \cdot [\bm{h}_{t-1}; \bm{x}_t] + \bm{b}_{4d}$$

where $\bm{W}_{4d} \in \R^{4d \times (d+m)}$.

| Component | Parameters | FLOPs per step |
|---|---|---|
| Combined gate matmul | $4d(d+m) + 4d$ | $8d(d+m)$ |
| 3 sigmoids + 1 tanh | 0 | $\approx 18d$ |
| Cell update ($\bm{f} \odot \bm{c} + \bm{i} \odot \tilde{\bm{c}}$) | 0 | $3d$ (2 element-wise mults + 1 add) |
| Output ($\bm{o} \odot \tanh(\bm{c})$) | 0 | $\approx 7d$ (tanh + element-wise mult) |
| Total LSTM per step | $\approx 4d(d+m)$ | $\approx \mathbf{8d(d+m)}$ |
| Ratio vs. vanilla RNN | $4\times$ | $4\times$ |

An LSTM is exactly 4× the cost of a vanilla RNN per step, because it computes 4 gates instead of 1 hidden state transformation.
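One step with the fused gate matmul, as a NumPy sketch:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W4, b4):
    # W4: (4d, d+m) -- all four gates in one matmul; b4: (4d,)
    d = h_prev.shape[0]
    z = W4 @ np.concatenate([h_prev, x_t]) + b4   # 8d(d+m) FLOPs
    f = sigmoid(z[0*d:1*d])                       # forget gate
    i = sigmoid(z[1*d:2*d])                       # input gate
    c_tilde = np.tanh(z[2*d:3*d])                 # candidate cell
    o = sigmoid(z[3*d:4*d])                       # output gate
    c = f * c_prev + i * c_tilde                  # cell update
    h = o * np.tanh(c)                            # hidden state
    return h, c
```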

23.3 GRU (Gated Recurrent Unit)

GRU Cell

Reset gate: $\bm{r}_t = \sigma(\bm{W}_r[\bm{h}_{t-1}, \bm{x}_t])$

Update gate: $\bm{z}_t = \sigma(\bm{W}_z[\bm{h}_{t-1}, \bm{x}_t])$

Candidate: $\tilde{\bm{h}}_t = \tanh(\bm{W}[\bm{r}_t \odot \bm{h}_{t-1}, \bm{x}_t])$

Output: $\bm{h}_t = (1-\bm{z}_t) \odot \bm{h}_{t-1} + \bm{z}_t \odot \tilde{\bm{h}}_t$

Parameters: $3d(d+m)$ — $3\times$ vanilla RNN, $\frac{3}{4}\times$ LSTM

FLOPs per step: $\approx 6d(d+m)$ — $\frac{3}{4}\times$ LSTM
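The parameter formulas confirm the ratios directly (a sketch; biases follow the equations above, so the GRU as written has none):

```python
def recurrent_params(d, m):
    rnn  = d * (d + m) + d            # W_hh, W_xh, b_h
    gru  = 3 * d * (d + m)            # three gate matrices, no biases as written
    lstm = 4 * d * (d + m) + 4 * d    # four gate matrices + biases
    return rnn, gru, lstm

print(recurrent_params(d=1024, m=512))   # ratios ~ 1 : 3 : 4
```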

23.4 RNN vs. Transformer — Cost Comparison

Architecture Comparison for Sequence of Length $n$

| Metric | RNN/LSTM | Transformer |
|---|---|---|
| FLOPs (per layer, sequence) | $O(nd^2)$ | $O(n^2d + nd^2)$ |
| Parallelism (training) | Sequential ($n$ serial steps) | Fully parallel |
| Parallelism (inference) | Sequential | Parallel (but KV cache) |
| Long-range dependencies | $O(n)$ steps for signal to propagate | $O(1)$ — direct attention |
| Memory (training, per layer) | $O(nd)$ (store hidden states) | $O(n^2 + nd)$ (attention matrix) |
| Memory (inference) | $O(d)$ (just current hidden state) | $O(nLd)$ (KV cache) |
| Parameter sharing | Same params every step | Same params every position |
Why Transformers Won

Despite RNNs having lower FLOPs per sequence (linear vs quadratic in $n$), Transformers dominate because:

(1) Parallelism: The entire sequence is processed simultaneously during training, fully utilizing GPU parallelism. RNN training requires $n$ sequential steps, leaving most GPU cores idle.

(2) Gradient flow: Attention provides a direct path for gradients across any distance (like a "skip connection in time"), avoiding the vanishing/exploding gradient problem that plagues RNNs even with LSTM/GRU.

(3) Scaling: Transformer performance scales predictably with model size and data (scaling laws), while RNN performance saturates earlier.

The key trade-off: Transformers trade more FLOPs for better parallelism and gradient flow, which on modern GPU hardware results in faster wall-clock training time despite higher FLOP count.

References for §23

[6] Hochreiter, S. & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation, 9(8), 1735–1780.

[7] Cho, K. et al. (2014). Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation. EMNLP 2014.

[8] Elman, J. L. (1990). Finding Structure in Time. Cognitive Science, 14(2), 179–211.


§24 Theoretical Foundations

24.1 Why Batch Normalization Works

The original BatchNorm paper (Ioffe & Szegedy, 2015) proposed that BN works by reducing "internal covariate shift" — the change in the distribution of layer inputs as previous layers' parameters change. However, Santurkar et al. (2018) showed this explanation is incomplete.

The Smoothness Explanation (Santurkar et al., 2018)

BatchNorm makes the loss landscape significantly smoother. Specifically, for a network with BatchNorm:

$$\|\nabla L(\bm{\theta}_1) - \nabla L(\bm{\theta}_2)\| \leq \beta \|\bm{\theta}_1 - \bm{\theta}_2\|$$

The Lipschitz constant $\beta$ of the gradient (the "smoothness") is much smaller with BN than without. This means:

(1) The gradient changes more predictably, so gradient descent steps are more reliable.

(2) Larger learning rates can be used without overshooting.

(3) Fewer iterations are needed to converge.

Computational implication: BN's per-step cost is ~2% overhead, but it reduces training iterations by 2–5× (enabling 5–10× larger learning rates). The net effect is a substantial reduction in total training FLOPs.

24.2 Why Residual Connections Work

Gradient Highway Interpretation

For a residual block $\bm{y} = F(\bm{x}) + \bm{x}$, the gradient of the loss w.r.t. any early layer $l$ can be written as:

$$\pd{L}{\bm{x}_l} = \pd{L}{\bm{x}_L} \prod_{k=l}^{L-1}\left(\bm{I} + \pd{F_k}{\bm{x}_k}\right)$$

Expanding the product, every term contains $\bm{I}$ (the identity), which provides a direct additive path for gradients from loss to any layer. Even if all the $\pd{F_k}{\bm{x}_k}$ terms are small (vanishing), the gradient through the skip connections is exactly $\pd{L}{\bm{x}_L}$ — it does not decay.

Without residual connections: $\pd{L}{\bm{x}_l} = \pd{L}{\bm{x}_L}\prod_{k=l}^{L-1}\pd{F_k}{\bm{x}_k}$, which vanishes or explodes exponentially with depth.
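The contrast is easy to see numerically. The sketch below multiplies 50 layer Jacobians of small norm (an illustrative assumption: Jacobians drawn as small Gaussian matrices): the plain product collapses toward zero, while the residual product $\prod_k (\bm{I} + \bm{J}_k)$ retains the identity path.

```python
import numpy as np

rng = np.random.default_rng(0)
d, L = 32, 50
Js = [0.05 * rng.normal(size=(d, d)) for _ in range(L)]   # small per-layer Jacobians

plain = np.eye(d)
residual = np.eye(d)
for J in Js:
    plain = plain @ J                      # prod J_k: shrinks exponentially with depth
    residual = residual @ (np.eye(d) + J)  # prod (I + J_k): identity path survives

print(np.linalg.norm(plain))      # vanishingly small
print(np.linalg.norm(residual))   # O(10) -- comparable to ||I||_F = sqrt(32)
```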

Ensemble Interpretation (Veit et al., 2016)

An $L$-layer residual network can be viewed as an ensemble of $2^L$ paths of different lengths (each block can be either "on" or bypassed via the skip connection). Veit et al. showed empirically that:

(1) Residual networks primarily use paths of moderate length (~$\sqrt{L}$).

(2) Deleting individual residual blocks has minimal impact (unlike deleting layers in plain networks).

(3) The ensemble interpretation explains the smooth loss landscape and good generalization of ResNets.

24.3 The Lottery Ticket Hypothesis

The Lottery Ticket Hypothesis (Frankle & Carbin, 2019)

A randomly initialized dense network contains a subnetwork ("winning ticket") that, when trained in isolation with the same initialization, matches the full network's accuracy in at most the same number of training iterations.

Computational implications:

(1) Dense networks are significantly over-parameterized — they contain sparse subnetworks that are sufficient.

(2) Finding the winning ticket currently requires training the full network first, then pruning — so it doesn't save training FLOPs (yet).

(3) It provides theoretical justification for post-training pruning achieving high compression with minimal accuracy loss.

(4) If we could identify winning tickets before training, we could train sparse networks from scratch — reducing training FLOPs proportionally to the sparsity ratio.

References for §24

[9] Santurkar, S. et al. (2018). How Does Batch Normalization Help Optimization? NeurIPS 2018.

[10] Veit, A., Wilber, M. J., & Belongie, S. (2016). Residual Networks Behave Like Ensembles of Relatively Shallow Networks. NeurIPS 2016.

[11] Frankle, J. & Carbin, M. (2019). The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks. ICLR 2019.


§25 Current Frontiers in Efficient Deep Learning

25.1 Mixture of Experts (MoE)

Mixture of Experts

MoE replaces the dense FFN in each Transformer layer with $E$ "expert" FFNs, of which only $K \ll E$ are activated for each token. A learned gating network $G(\bm{x})$ selects which experts to activate.

$$\text{MoE}(\bm{x}) = \sum_{i=1}^{K} g_i(\bm{x}) \cdot \text{Expert}_i(\bm{x})$$

where $g_i$ are the top-$K$ gate values (from softmax over all $E$ experts).
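A sketch of the routing for a single token, matching the definition above (experts passed as callables; gate values taken from the softmax over all $E$ experts, without renormalization):

```python
import numpy as np

def moe_forward(x, W_gate, experts, K=2):
    # W_gate: (d, E); experts: list of E callables, each mapping (d,) -> (d,)
    logits = x @ W_gate
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()                                # softmax over all E experts
    top = np.argsort(probs)[-K:]                        # indices of the top-K gates
    return sum(probs[i] * experts[i](x) for i in top)   # only K experts are evaluated
```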

| Metric | Dense Transformer | MoE Transformer |
|---|---|---|
| Total parameters | $P$ | $\sim EP/2$ (experts have most params) |
| Active parameters per token | $P$ | $\sim P \cdot K/E$ (only $K$ experts active) |
| FLOPs per token | $2P$ | $\sim 2P \cdot K/E + \text{gating cost}$ |
| Memory (inference) | $\sim P$ (all params in memory) | $\sim EP/2$ (all experts must be loaded) |

Key insight: MoE allows scaling model capacity (total parameters) without proportionally scaling compute (FLOPs per token). For example, Switch Transformer (Fedus et al., 2022) uses $E=128$ experts with $K=1$, achieving GPT-3-level quality with a fraction of GPT-3's per-token FLOPs.

Challenge: All expert weights must reside in memory, so MoE models require more GPU memory than their FLOP count would suggest. Communication overhead in distributed settings is also significant.

25.2 State Space Models (SSMs)

State Space Models — Mamba (Gu & Dao, 2023)

SSMs are a class of sequence models based on linear recurrences that can be computed as convolutions during training (parallelizable) or as recurrences during inference (constant memory per step). The Mamba architecture adds selective gating (input-dependent state transitions), achieving Transformer-competitive performance.

| Metric | Transformer | Mamba (SSM) |
|---|---|---|
| Training FLOPs (per layer, sequence) | $O(nd^2)$ (dominates for $n < 2d$) $+\ O(n^2d)$ | $O(nd^2)$ (no quadratic term) |
| Training parallelism | Full | Full (via convolutional form) |
| Inference memory | $O(nLd)$ (KV cache) | $O(Ld)$ (hidden state only) |
| Inference FLOPs/token | $O(d^2 + nd)$ per layer | $O(d^2)$ per layer (no sequence dependence) |
| Long-range dependencies | Direct ($O(1)$ via attention) | Indirect (compressed into the state) |

SSMs offer linear scaling with sequence length during both training and inference, making them particularly attractive for very long sequences (100K+ tokens). The trade-off is that long-range information must be compressed into the fixed-size state $\bm{h}_t \in \R^{d_{\text{state}}}$, whereas attention can directly access any past token.
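A sketch of the recurrent (inference-time) form for a single input channel with a diagonal state matrix — this is the plain, non-selective linear SSM; Mamba additionally makes the transition input-dependent:

```python
import numpy as np

def ssm_scan(x, A, B, C):
    # x: (n,) input sequence; A, B, C: (d_state,) parameters of a diagonal SSM
    h = np.zeros_like(A)
    y = np.zeros(len(x))
    for t in range(len(x)):
        h = A * h + B * x[t]       # linear state update: O(d_state) memory, always
        y[t] = C @ h               # readout
    return y

# Training-time equivalence: y[t] = sum_s k[s] * x[t-s] with kernel k[s] = C @ (A**s * B);
# this convolutional form is what makes SSM training parallelizable.
```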

25.3 Neural Architecture Search (NAS)

NAS for Efficiency

NAS automates the design of neural architectures by searching over a design space. For efficiency-focused NAS, the search objective explicitly includes a cost term:

$$\text{Objective} = \text{Accuracy} \times \left(\frac{\text{FLOPs}}{T}\right)^{-w}$$

where $T$ is the target FLOP budget and $w$ controls the trade-off. Prominent results:

| Architecture | Search Method | Search Cost | Result |
|---|---|---|---|
| NASNet (Zoph et al., 2018) | RL-based | ~450 GPU-days | Better accuracy-per-FLOP than manual designs |
| EfficientNet-B0 (Tan & Le, 2019) | RL-based | ~1 GPU-day (efficient search) | 77.1% top-1 at 390M FLOPs |
| DARTS (Liu et al., 2019) | Gradient-based (differentiable) | ~1 GPU-day | Competitive with RL-based |

NAS has found architectures that dominate manually designed ones on the accuracy-efficiency Pareto frontier. The search cost has decreased from thousands of GPU-days to approximately 1 GPU-day.

25.4 Sparsity-Aware Hardware

Hardware Support for Sparsity

The gap between theoretical sparsity benefits and actual speedup is closing as hardware evolves:

| Hardware | Sparsity Support | Effective Speedup |
|---|---|---|
| NVIDIA A100 (Ampere) | 2:4 structured sparsity (50%) | 2× on sparse tensor cores |
| NVIDIA H100 (Hopper) | 2:4 + improved scheduling | 2× + better utilization |
| Cerebras CS-2 | Unstructured sparsity natively | Near-linear with sparsity |
| Graphcore IPU | Block sparsity support | Good for sparse workloads |

The 2:4 sparsity pattern (exactly 2 zeros per 4-element block) is particularly interesting: it provides guaranteed 2× speedup with hardware support, and neural networks can be trained to this pattern with minimal accuracy loss (~0.5%). This represents a clean, practical middle ground between dense computation and the unrealized promise of arbitrary unstructured sparsity.

25.5 The Efficiency Frontier — Where We Stand

Summary of Efficiency Techniques by Impact
Technique FLOPs Reduction Memory Reduction Maturity
Mixed Precision (FP16/BF16) ~0% (faster throughput) ~50% activations Standard practice
Efficient architectures (DW-Sep, bottleneck) 3–10× Proportional Standard practice
INT8 Quantization (inference) 2–4× throughput 75% Standard practice
Gradient checkpointing +33% (more FLOPs) $O(L) \to O(\sqrt{L})$ Standard practice
Knowledge distillation 3–10× (student vs teacher) Proportional Widely used
FlashAttention ~0% FLOPs, 2–4× speed $O(n^2) \to O(n)$ Standard practice
Structured pruning 2–5× Proportional Emerging
MoE $K/E$ active fraction Higher total (all experts) Frontier research
State Space Models $O(n) \to O(n)$ (linear) $O(1)$ per step (inference) Frontier research
2:4 Sparsity (hardware) 2× with HW support ~50% Emerging
NAS Architecture-dependent Architecture-dependent Maturing
The Grand Picture — Efficiency Research Trajectory

Over the past decade, the deep learning community has achieved roughly four orders of magnitude improvement in efficiency for a given accuracy target (e.g., ImageNet 76% top-1), through a combination of:

(1) Architecture innovations (~100× from AlexNet to EfficientNet): depthwise separable convolutions, bottleneck blocks, compound scaling, attention mechanisms

(2) Numerical precision (~4×): FP32 → FP16/BF16 → INT8 for inference

(3) Hardware (~10×): GPU generations (Kepler → Ampere → Hopper), tensor cores, memory bandwidth improvements

(4) Software optimization (~3×): operator fusion, FlashAttention, optimized GEMM libraries

The multiplicative combination of these factors ($100 \times 4 \times 10 \times 3 \approx 12{,}000\times$) explains the dramatic increase in what is practically trainable and deployable. Each factor represents a distinct dimension of optimization, and their multiplicative composition is the key insight: progress in any one dimension multiplies the benefit of all others.

References for §25

[12] Fedus, W. et al. (2022). Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. JMLR, 23(120), 1–39.

[13] Gu, A. & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv:2312.00752.

[14] Zoph, B., Vasudevan, V., Shlens, J., & Le, Q. V. (2018). Learning Transferable Architectures for Scalable Image Recognition. CVPR 2018. (NASNet)

[15] Liu, H., Simonyan, K., & Yang, Y. (2019). DARTS: Differentiable Architecture Search. ICLR 2019.

[11] Frankle, J. & Carbin, M. (2019). The Lottery Ticket Hypothesis. ICLR 2019.


§ Complete References

Transformer Architecture

[1] Vaswani, A. et al. (2017). Attention Is All You Need. NeurIPS 2017.

[2] Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.

[3] Shazeer, N. (2019). Fast Transformer Decoding: One Write-Head is All You Need. arXiv:1911.02150.

[4] Ainslie, J. et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023.

[5] Su, J. et al. (2021). RoFormer: Enhanced Transformer with Rotary Position Embedding. arXiv:2104.09864.

Recurrent Networks

[6] Hochreiter, S. & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation, 9(8), 1735–1780.

[7] Cho, K. et al. (2014). Learning Phrase Representations using RNN Encoder-Decoder. EMNLP 2014.

[8] Elman, J. L. (1990). Finding Structure in Time. Cognitive Science, 14(2), 179–211.

Theoretical Foundations

[9] Santurkar, S. et al. (2018). How Does Batch Normalization Help Optimization? NeurIPS 2018.

[10] Veit, A., Wilber, M. J., & Belongie, S. (2016). Residual Networks Behave Like Ensembles. NeurIPS 2016.

[11] Frankle, J. & Carbin, M. (2019). The Lottery Ticket Hypothesis. ICLR 2019.

Efficient Deep Learning Frontiers

[12] Fedus, W. et al. (2022). Switch Transformers. JMLR, 23(120), 1–39.

[13] Gu, A. & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv:2312.00752.

[14] Zoph, B. et al. (2018). Learning Transferable Architectures (NASNet). CVPR 2018.

[15] Liu, H. et al. (2019). DARTS: Differentiable Architecture Search. ICLR 2019.

Foundational Texts

[16] Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep Learning. MIT Press.

[17] Sze, V. et al. (2017). Efficient Processing of Deep Neural Networks. Proc. IEEE, 105(12), 2295–2329.


End of Part VI — Series Complete

This concludes the six-part deep dive into Deep Neural Networks: Computational Foundations and Efficiency Analysis.

Parts I–VI collectively cover 25 sections, from basic linear algebra through Transformer architectures and current frontiers, with every mathematical derivation, computational cost, and memory requirement tracked in full detail.
