AI Notes

Personal notes on AI concepts, techniques, and architectures.

Transformer Efficiency Techniques

Flash Attention

  • Computes attention in tiles/blocks rather than materializing the full $N \times N$ attention matrix
  • Uses online softmax to compute attention incrementally, recomputing values during the backward pass instead of storing them (see the toy sketch after this list)
  • Reduces memory from $O(N^2)$ to $O(N)$ while maintaining exact attention (not an approximation)
  • Key insight: memory I/O is the bottleneck, not FLOPs—trading compute for memory access is a net win on modern GPUs
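
A minimal NumPy sketch of the online-softmax idea from the bullets above: K/V are processed in blocks with a running max and running normalizer, so the full $N \times N$ score matrix is never formed. The function name, block size, and single-head shapes are illustrative; the real kernel also tiles over queries and keeps the blocks in GPU SRAM.

```python
import numpy as np

def flash_attention_like(q, k, v, block_size=64):
    """Tiled attention with online softmax: process K/V in blocks, keeping a
    running max and normalizer per query row. Exact up to float error."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros_like(q)
    row_max = np.full(n, -np.inf)          # running max of scores per query
    row_sum = np.zeros(n)                  # running softmax normalizer per query

    for start in range(0, n, block_size):
        k_blk = k[start:start + block_size]              # one K tile
        v_blk = v[start:start + block_size]              # matching V tile
        scores = (q @ k_blk.T) * scale                   # (N, B) scores for this tile only

        new_max = np.maximum(row_max, scores.max(axis=1))
        correction = np.exp(row_max - new_max)           # rescale previously accumulated values
        p = np.exp(scores - new_max[:, None])            # tile-local unnormalized probabilities

        out = out * correction[:, None] + p @ v_blk
        row_sum = row_sum * correction + p.sum(axis=1)
        row_max = new_max

    return out / row_sum[:, None]

# sanity check against the naive implementation that materializes the full matrix
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((256, 32)) for _ in range(3))
s = (q @ k.T) / np.sqrt(32)
p = np.exp(s - s.max(axis=1, keepdims=True))
ref = (p / p.sum(axis=1, keepdims=True)) @ v
assert np.allclose(flash_attention_like(q, k, v), ref)
```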

KV Cache

  • During autoregressive generation, the keys $K$ and values $V$ for previous tokens don't change; cache them instead of recomputing them
  • Each new token only needs to compute its own $K/V$ and attend to the cached history
  • Reduces per-token $K/V$ computation from $O(N)$ to $O(1)$; attention over the cached history is still linear in the context length
  • Trade-off: cache memory grows linearly with sequence length $L$, which is why techniques like sliding window attention help for long contexts (a toy decode loop is sketched below)
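
A toy decode loop illustrating the cache in NumPy. The random vectors stand in for hidden states, and `W_q`/`W_k`/`W_v` for learned projection layers; both are hypothetical stand-ins. Each step computes K/V only for the new token and appends them, while attention still reads the whole cached history.

```python
import numpy as np

def attend(q, K, V):
    """Single-query attention against the cached key/value history."""
    s = (K @ q) / np.sqrt(q.shape[-1])
    w = np.exp(s - s.max())
    w /= w.sum()
    return w @ V

# hypothetical projections; in a real model these are learned layers
rng = np.random.default_rng(0)
d = 16
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))

K_cache = np.empty((0, d))   # grows by one row per generated token: O(L) memory
V_cache = np.empty((0, d))
outputs = []

for _ in range(8):                          # toy autoregressive loop
    x = rng.standard_normal(d)              # stand-in for the current token's hidden state
    q, k, v = W_q @ x, W_k @ x, W_v @ x     # only the NEW token's projections: O(1) work
    K_cache = np.vstack([K_cache, k])       # past K/V are never recomputed, just appended
    V_cache = np.vstack([V_cache, v])
    outputs.append(attend(q, K_cache, V_cache))  # attention still scans the whole cache
```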

Grouped Query Attention (GQA)

  • Instead of separate K/V heads per query head (MHA) or one K/V head for all queries (MQA), use $G$ groups
  • Multiple query heads share the same K/V head within a group (e.g., $H_q = 8$ query heads, $H_{kv} = 2$ KV heads, i.e., 4 query heads per KV head)
  • Reduces KV cache size by a factor of $H_q / H_{kv}$ while retaining most of MHA's expressiveness (sketched below)
  • Sweet spot between MQA's efficiency and MHA's quality—used in Llama 2 70B, Mistral, etc.
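
A NumPy sketch of the sharing pattern for a single token position attending to a cached context. The head counts and shapes are illustrative: 8 query heads use only 2 cached K/V heads, so the cache is 4x smaller than full MHA.

```python
import numpy as np

def grouped_query_attention(q, k, v, n_q_heads=8, n_kv_heads=2):
    """GQA for one token position. q: (n_q_heads, d); k, v: (n_kv_heads, T, d).
    Each group of n_q_heads // n_kv_heads query heads shares one K/V head."""
    group = n_q_heads // n_kv_heads
    d = q.shape[-1]
    outs = []
    for h in range(n_q_heads):
        kv = h // group                        # which shared K/V head this query head uses
        s = (k[kv] @ q[h]) / np.sqrt(d)        # (T,) scores against the shared keys
        w = np.exp(s - s.max())
        w /= w.sum()
        outs.append(w @ v[kv])                 # (d,) weighted sum of the shared values
    return np.stack(outs)                      # (n_q_heads, d)

# the KV cache stores 2 heads instead of 8: 4x smaller per token
rng = np.random.default_rng(0)
q = rng.standard_normal((8, 64))
k = rng.standard_normal((2, 128, 64))
v = rng.standard_normal((2, 128, 64))
out = grouped_query_attention(q, k, v)         # (8, 64)
```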

Rotary Embeddings (RoPE)

  • Encodes position by rotating query and key vectors in 2D subspaces based on the position index $m$ (toy implementation below)
  • Relative position naturally emerges: $q_m \cdot k_n$ depends on the positions only through $(m - n)$, due to the rotation properties
  • The rotation matrix $R_\theta^m$ applied at position $m$ uses per-pair frequencies $\theta_i = 10000^{-2i/d}$, i.e., rotation angle $m\theta_i$ for the $i$-th 2D pair
  • No learned position embeddings—positions are encoded through geometric rotation
  • Enables length extrapolation (with modifications like NTK-aware scaling or YaRN) beyond training context
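
A toy NumPy implementation using the interleaved-pair convention from the original RoPE formulation (the function name and dimensions are illustrative). The assert checks the relative-position property from the second bullet: shifting both positions by the same offset leaves the dot product unchanged.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate each 2D pair (x[2i], x[2i+1]) of a query or key vector by
    angle pos * theta_i, with theta_i = base^(-2i/d)."""
    d = x.shape[-1]
    i = np.arange(d // 2)
    theta = base ** (-2.0 * i / d)            # per-pair frequencies
    angle = pos * theta
    cos, sin = np.cos(angle), np.sin(angle)
    x1, x2 = x[0::2], x[1::2]                 # the paired coordinates
    out = np.empty_like(x)
    out[0::2] = x1 * cos - x2 * sin           # 2x2 rotation applied pairwise
    out[1::2] = x1 * sin + x2 * cos
    return out

# dot products depend only on the relative position m - n
rng = np.random.default_rng(0)
q, k = rng.standard_normal(64), rng.standard_normal(64)
a = rope(q, pos=10) @ rope(k, pos=3)      # m - n = 7
b = rope(q, pos=107) @ rope(k, pos=100)   # m - n = 7, both shifted by 97
assert np.isclose(a, b)
```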

Notable LLM Architectures & Papers

Flash Attention

  • Tri Dao's IO-aware algorithm that avoids materializing the $N \times N$ attention matrix to HBM (rough memory arithmetic below)
  • Tiles computation into blocks that fit in SRAM, computing softmax incrementally via online normalization trick
  • Recomputes attention in the backward pass rather than storing activations, trading extra FLOPs for a reduction in activation memory from $O(N^2)$ to $O(N)$
  • Flash Attention 2 adds better work partitioning across warps and reduces non-matmul FLOPs for ~2x speedup
  • Flash Attention 3 targets Hopper GPUs: warp specialization, overlapping softmax with asynchronous tensor-core matmuls, and FP8 support

Paper: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
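
Back-of-envelope numbers behind the IO argument. All sizes are illustrative assumptions (fp16, one head, typical-looking block sizes), not figures from the paper.

```python
# Why tiling helps: the full score matrix would round-trip through HBM, while one
# FlashAttention-style tile fits in the ~200 KB of SRAM per SM on recent data-center GPUs.
n = 8192                                   # sequence length
d = 128                                    # head dimension
b_r = b_c = 128                            # query / key block sizes (illustrative)

full_scores_bytes = n * n * 2              # fp16 N x N scores for a single head
tile_bytes = (b_r * b_c                    # score tile
              + 2 * b_c * d                # K and V blocks
              + 2 * b_r * d) * 2           # Q block and output accumulator, fp16
print(f"full score matrix: {full_scores_bytes / 2**20:.0f} MiB")
print(f"one tile's working set: {tile_bytes / 2**10:.0f} KiB")
```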

OLMoE

  • Mixture-of-Experts variant of OLMo using top-k routing (k = 8 of 64 experts per layer in OLMoE-1B-7B)
  • Each token is routed to a subset of experts; roughly 1B active params out of 7B total
  • Uses an auxiliary load-balancing loss to prevent expert collapse (all tokens being routed to a few experts); see the routing sketch below
  • Achieves better performance per FLOP than dense models by decoupling parameters from compute

Paper: OLMoE: Open Mixture-of-Experts Language Models
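
A generic NumPy sketch of top-k routing with a Switch-style load-balancing term, to make the bullets concrete. The router/expert shapes, the 8-of-64 toy setup, and the exact form of the auxiliary loss are illustrative assumptions, not OLMoE's actual implementation.

```python
import numpy as np

def moe_layer(x, W_router, expert_weights, k=8):
    """Top-k MoE routing over a batch of token vectors x: (T, d).
    Each token runs only its k highest-scoring experts; outputs are mixed with
    renormalized router weights. Also returns a Switch-style auxiliary
    load-balancing loss (fraction of routed tokens per expert x mean router
    probability) that discourages expert collapse."""
    T, d = x.shape
    n_exp = expert_weights.shape[0]
    logits = x @ W_router                               # (T, n_exp) router scores
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)

    top = np.argsort(-probs, axis=1)[:, :k]             # k chosen experts per token
    out = np.zeros_like(x)
    frac = np.zeros(n_exp)                              # fraction of routing slots per expert
    for t in range(T):
        w = probs[t, top[t]]
        w = w / w.sum()                                 # renormalize over the active experts
        for weight, e in zip(w, top[t]):
            out[t] += weight * (expert_weights[e] @ x[t])   # only k of n_exp experts run
            frac[e] += 1.0 / (T * k)
    aux_loss = n_exp * np.sum(frac * probs.mean(axis=0))
    return out, aux_loss

# toy 8-of-64 setup with random linear "experts"
rng = np.random.default_rng(0)
d, n_exp = 32, 64
W_router = rng.standard_normal((d, n_exp))
expert_weights = rng.standard_normal((n_exp, d, d))
out, aux = moe_layer(rng.standard_normal((16, d)), W_router, expert_weights)
```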

Montessori Instruct

  • Optimizes the teacher LLM to generate synthetic training data tailored to the student's learning preferences
  • Uses influence functions to measure how each synthetic data point affects the student's loss on a reference set: positive influence = helpful, negative = harmful (toy pairing sketch below)
  • Constructs preference pairs from high/low influence data and trains teacher with DPO to favor generating influential examples
  • Key insight: a weaker teacher optimized for the student outperforms a stronger teacher (GPT-4o) using standard synthesis

Paper: Montessori-Instruct: Generate Influential Training Data Tailored for Student Learning
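
A toy sketch of the influence-to-preference-pair loop described above. It is pure Python with made-up numbers; `local_influence`, the probe losses, and the pairing rule are simplified stand-ins for the paper's pipeline, and the resulting pairs would then be fed to a DPO trainer for the teacher.

```python
def local_influence(ref_loss_before, ref_loss_after):
    """Influence of one synthetic example: how much fine-tuning the student on it
    reduced the student's loss on a held-out reference set."""
    return ref_loss_before - ref_loss_after   # positive = helpful, negative = harmful

def build_dpo_pairs(prompts, candidates_per_prompt, influences_per_prompt):
    """For each prompt, pair the most- and least-influential teacher generations
    as (chosen, rejected) preference data for DPO training of the teacher."""
    pairs = []
    for prompt, cands, infl in zip(prompts, candidates_per_prompt, influences_per_prompt):
        ranked = sorted(zip(cands, infl), key=lambda ci: ci[1], reverse=True)
        chosen, rejected = ranked[0][0], ranked[-1][0]
        pairs.append({"prompt": prompt, "chosen": chosen, "rejected": rejected})
    return pairs

# made-up numbers: two prompts, three candidate synthetic examples each
prompts = ["task A", "task B"]
cands = [["a1", "a2", "a3"], ["b1", "b2", "b3"]]
infl = [[local_influence(0.90, 0.85),   # helpful
         local_influence(0.90, 0.91),   # harmful
         local_influence(0.90, 0.88)],
        [local_influence(1.20, 1.10),
         local_influence(1.20, 1.25),
         local_influence(1.20, 1.19)]]
pairs = build_dpo_pairs(prompts, cands, infl)   # feed to a DPO trainer for the teacher
```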