Transformer Efficiency Techniques
Flash Attention
- Computes attention in tiles/blocks rather than materializing the full attention matrix
- Uses online softmax to compute attention incrementally, recomputing attention weights during the backward pass instead of storing them
- Reduces memory from O(N²) to O(N) while maintaining exact attention (not an approximation); see the tiled sketch after this list
- Key insight: memory I/O is the bottleneck, not FLOPs—trading compute for memory access is a net win on modern GPUs
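A minimal single-head sketch of the tiled/online-softmax idea in plain PyTorch (the block size, shapes, and Python loop are illustrative; the real kernel fuses this into SRAM-resident CUDA blocks):
```python
import torch

def tiled_attention(q, k, v, block=64):
    """q, k, v: (seq_len, head_dim). Never materializes the full (N, N) score matrix."""
    n, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((n, 1), float("-inf"))  # running max of scores per query row
    row_sum = torch.zeros(n, 1)                  # running softmax denominator

    for start in range(0, n, block):
        kb = k[start:start + block]              # key tile
        vb = v[start:start + block]              # value tile
        scores = (q @ kb.T) * scale              # scores against this tile only

        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)  # rescale previously accumulated results
        p = torch.exp(scores - new_max)            # tile-local unnormalized softmax

        out = out * correction + p @ vb
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        row_max = new_max

    return out / row_sum                         # single final normalization

# Sanity check against standard (fully materialized) attention
q, k, v = (torch.randn(256, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) / 64 ** 0.5, dim=-1) @ v
print(torch.allclose(tiled_attention(q, k, v), ref, atol=1e-4))  # True
```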
KV Cache
- During autoregressive generation, keys and values for previous tokens don't change—cache them instead of recomputing
- Each new token only needs to compute its own query/key/value, append the new K/V to the cache, and attend to the cached history
- Reduces per-token inference from O(n²) to O(n) attention computation
- Trade-off: cache memory grows linearly with sequence length n, which is why techniques like sliding window attention help for long contexts (see the decode-loop sketch after this list)
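A minimal decode-loop sketch of KV caching for a single attention head; the projection matrices and dimensions are toy placeholders rather than a real model:
```python
import torch

d_model = 64
wq, wk, wv = (torch.randn(d_model, d_model) * d_model ** -0.5 for _ in range(3))

def decode_step(x_t, k_cache, v_cache):
    """x_t: (1, d_model) hidden state of the newest token only."""
    q_t = x_t @ wq                                     # query for the new token
    k_cache = torch.cat([k_cache, x_t @ wk], dim=0)    # append the new key
    v_cache = torch.cat([v_cache, x_t @ wv], dim=0)    # append the new value
    # Attend over the cached history: O(n) work per step instead of O(n^2)
    scores = (q_t @ k_cache.T) * d_model ** -0.5
    out = torch.softmax(scores, dim=-1) @ v_cache
    return out, k_cache, v_cache

k_cache = torch.empty(0, d_model)
v_cache = torch.empty(0, d_model)
for _ in range(8):                                     # stand-in for 8 generation steps
    x_t = torch.randn(1, d_model)                      # pretend hidden state of the new token
    out, k_cache, v_cache = decode_step(x_t, k_cache, v_cache)
print(k_cache.shape)                                   # torch.Size([8, 64]) -- grows with length
```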
Grouped Query Attention (GQA)
- Instead of separate K/V heads per query head (MHA) or one K/V head for all queries (MQA), use g groups
- Multiple query heads share the same K/V head within a group (e.g., 32 query heads, 8 KV heads = 4 queries per KV)
- Reduces KV cache size by a factor of n_heads / n_kv_heads while retaining most of MHA's expressiveness (see the sketch after this list)
- Sweet spot between MQA's efficiency and MHA's quality—used in Llama 2 70B, Mistral, etc.
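A minimal GQA sketch assuming 8 query heads sharing 2 KV heads (head counts and shapes are illustrative); each KV head is simply broadcast to its group of query heads before ordinary attention:
```python
import torch
import torch.nn.functional as F

n_q_heads, n_kv_heads, head_dim, seq = 8, 2, 32, 16
group = n_q_heads // n_kv_heads                  # query heads sharing each KV head

q = torch.randn(n_q_heads, seq, head_dim)
k = torch.randn(n_kv_heads, seq, head_dim)       # KV cache is group-times smaller than MHA
v = torch.randn(n_kv_heads, seq, head_dim)

# Expand each KV head to its group of query heads, then run standard attention
k_expanded = k.repeat_interleave(group, dim=0)   # (n_q_heads, seq, head_dim)
v_expanded = v.repeat_interleave(group, dim=0)
out = F.scaled_dot_product_attention(q, k_expanded, v_expanded)
print(out.shape)                                 # torch.Size([8, 16, 32])
```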
Rotary Embeddings (RoPE)
- Encodes position by rotating query and key vectors in 2D subspaces based on position index
- Relative position naturally emerges: the attention score q_m^T k_n depends only on the offset m - n due to rotation properties
- The rotation matrix applied at position m uses angles m·θ_i, where θ_i = 10000^(-2i/d) (see the rotation sketch after this list)
- No learned position embeddings—positions are encoded through geometric rotation
- Enables length extrapolation (with modifications like NTK-aware scaling or YaRN) beyond training context
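A minimal RoPE sketch using the standard base of 10000: it rotates (even, odd) pairs of each vector by angles m·θ_i and numerically checks the relative-position property (the split-half output layout below differs from the paper's interleaving but gives identical attention scores):
```python
import torch

def rope(x, base=10000.0):
    """x: (seq_len, dim) queries or keys; dim must be even."""
    seq_len, dim = x.shape
    theta = base ** (-torch.arange(0, dim, 2).float() / dim)  # theta_i = base^(-2i/d)
    angles = torch.arange(seq_len).float()[:, None] * theta   # (seq_len, dim/2): m * theta_i
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, 0::2], x[:, 1::2]                           # the 2D subspaces being rotated
    return torch.cat([x1 * cos - x2 * sin,                    # rotate each pair by m * theta_i
                      x1 * sin + x2 * cos], dim=-1)

# Same content at every position, so the score should depend only on m - n
q = torch.randn(1, 64).repeat(32, 1)
k = torch.randn(1, 64).repeat(32, 1)
qr, kr = rope(q), rope(k)
print(torch.allclose(qr[5] @ kr[3], qr[10] @ kr[8], atol=1e-4))  # True: both have m - n = 2
```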
Notable LLM Architectures & Papers
Flash Attention
- Tri Dao's IO-aware algorithm that avoids materializing the attention matrix to HBM
- Tiles computation into blocks that fit in SRAM, computing softmax incrementally via online normalization trick
- Recomputes attention in backward pass rather than storing activations—trades extra FLOPs for O(N²) → O(N) memory
- Flash Attention 2 adds better work partitioning across warps and reduces non-matmul FLOPs for ~2x speedup
- Flash Attention 3 targets Hopper GPUs: warp specialization to overlap softmax with asynchronous tensor-core matmuls and TMA data movement, plus FP8 support (a PyTorch usage sketch follows the paper reference)
Paper: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
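In practice you usually call an existing kernel rather than writing one; a minimal usage sketch assuming PyTorch 2.x, whose scaled_dot_product_attention can dispatch to a FlashAttention backend on supported CUDA GPUs (and falls back to other backends elsewhere):
```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32  # flash kernels need fp16/bf16
q, k, v = (torch.randn(1, 8, 1024, 64, device=device, dtype=dtype)
           for _ in range(3))                                  # (batch, heads, seq, head_dim)

# Causal attention without materializing the (seq, seq) score matrix in HBM
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)                                               # torch.Size([1, 8, 1024, 64])
```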
OLMoE
- Mixture-of-Experts variant of OLMo using top-k routing (OLMoE-1B-7B activates k=8 of 64 experts per token)
- Each token is routed to a small subset of experts—e.g., ~1B active params from ~7B total (see the routing sketch after this entry)
- Uses auxiliary load balancing loss to prevent expert collapse (all tokens going to few experts)
- Achieves better performance per FLOP than dense models by decoupling parameters from compute
Paper: OLMoE: Open Mixture-of-Experts Language Models
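A minimal sketch of top-k routing with a Switch-style load-balancing auxiliary loss; the expert count, router, and loss weighting here are illustrative rather than OLMoE's exact configuration:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKMoE(nn.Module):
    def __init__(self, d_model=64, n_experts=8, k=2, d_ff=256):
        super().__init__()
        self.k, self.n_experts = k, n_experts
        self.router = nn.Linear(d_model, n_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
            for _ in range(n_experts))

    def forward(self, x):                                  # x: (tokens, d_model)
        probs = F.softmax(self.router(x), dim=-1)          # (tokens, n_experts)
        topk_p, topk_idx = probs.topk(self.k, dim=-1)
        topk_p = topk_p / topk_p.sum(dim=-1, keepdim=True) # renormalize over the chosen experts

        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):          # only k of n_experts run per token
            token_idx, slot = (topk_idx == e).nonzero(as_tuple=True)
            if token_idx.numel():
                out[token_idx] += topk_p[token_idx, slot, None] * expert(x[token_idx])

        # Auxiliary loss: fraction of tokens per expert times its mean router prob,
        # minimized when routing is uniform (prevents expert collapse)
        frac = F.one_hot(topk_idx, self.n_experts).float().sum(dim=(0, 1)) / (x.size(0) * self.k)
        aux_loss = self.n_experts * (frac * probs.mean(dim=0)).sum()
        return out, aux_loss

y, aux = TopKMoE()(torch.randn(16, 64))
print(y.shape, aux.item())
```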
Montessori Instruct
- Optimizes the teacher LLM to generate synthetic training data tailored to the student's learning preferences
- Uses influence functions to measure how each synthetic data point affects the student's loss on a reference set—positive influence = helpful, negative = harmful
- Constructs preference pairs from high- vs. low-influence data and trains the teacher with DPO to favor generating influential examples (see the sketch after this entry)
- Key insight: a weaker teacher optimized for the student outperforms a stronger teacher (GPT-4o) using standard synthesis
Paper: Montessori-Instruct: Generate Influential Training Data Tailored for Student Learning
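A rough sketch of the preference-pair construction step, assuming influence scores for each synthetic example have already been estimated (e.g., as the change in the student's reference loss after training on it); the grouping and pairing heuristic are illustrative, not the paper's exact recipe:
```python
from collections import defaultdict
from dataclasses import dataclass

@dataclass
class SyntheticExample:
    prompt: str
    response: str
    influence: float   # > 0: training on it lowered the student's reference loss

def build_dpo_pairs(examples):
    """Pair the most- and least-influential response per prompt as (chosen, rejected),
    so DPO pushes the teacher toward generating data the student learns most from."""
    by_prompt = defaultdict(list)
    for ex in examples:
        by_prompt[ex.prompt].append(ex)

    pairs = []
    for prompt, group in by_prompt.items():
        ranked = sorted(group, key=lambda ex: ex.influence, reverse=True)
        if len(ranked) >= 2 and ranked[0].influence > ranked[-1].influence:
            pairs.append({"prompt": prompt,
                          "chosen": ranked[0].response,     # high influence on the student
                          "rejected": ranked[-1].response}) # low or negative influence
    return pairs

# Toy example: two candidate responses to the same instruction, scored by influence
demo = [SyntheticExample("Explain KV caching.", "a detailed, correct answer ...", 0.12),
        SyntheticExample("Explain KV caching.", "an off-topic answer ...", -0.05)]
print(build_dpo_pairs(demo))
```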