Quick orientation
You’re comfortable with Transformers and KV caches, but you want a clear path from intuition → math → engineering. This post shows:
- What attention sinks are and why they appear.
- How softmax makes them a feature, not a bug.
- Why evicting them destabilizes long-context generation.
- A tiny cache policy that fixes it, plus practical guardrails.
1) An intuition you can’t unsee
Imagine your model mid-generation, producing a mundane connective like “and” or “the.” There’s no strong evidence for any specific past token, yet attention weights must still sum to 1. Where does that probability mass go? In many trained LMs, a few early positions (often the very first tokens) consistently soak up the “excess.” These are attention sinks.
They act like ballast in a ship: most of the time you barely notice them, but remove them and the whole vessel rocks violently. That’s exactly what happens when a sliding window quietly evicts the beginning of the conversation.
2) Softmax makes sinks inevitable
2.1 Softmax as a forced budget
For a single head at step \(t\), attending over \(m\) cached keys:
\[ \alpha_{t,j} \;=\; \frac{\exp(z_{t,j})}{\sum_{i=1}^{m}\exp(z_{t,i})}, \qquad z_{t,j} \;=\; \frac{Q_t \cdot K_j}{\sqrt{d_k}}. \]
The weights \(\alpha_{t,\cdot}\) must sum to 1 no matter what. During pretraining, harmless, slightly high baselines at a few positions (e.g., BOS) get reinforced by gradient flow and become reliable “default recipients” of probability mass: the sinks.
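A minimal pure-Python sketch of the softmax above (the softmax helper and the example logits are illustrative, not taken from any particular library). It shows the forced budget: with no preferred key the weights still sum to 1, and a small lift on one position pulls extra mass toward it.

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# No evidence for any key: identical logits still spend the full budget of 1.
flat = softmax([0.0] * 8)
print(flat[0], sum(flat))

# A small lift on position 0 (a would-be sink) pulls extra mass toward it.
lifted = softmax([0.4] + [0.0] * 7)
print(lifted[0])
```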
2.2 A toy model of sink bias
Suppose the \(S\) earliest positions carry a small logit lift \(\delta>0\) over the rest:
\[ z_s = \mu + \delta \quad (s\le S), \qquad z_r = \mu \quad (S < r \le W). \]
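Softmax denominator: \[ D = \sum_{j=1}^{W} e^{z_j} = e^{\mu}\left(S e^{\delta} + (W-S)\right). \] With sinks present, \(D\) has a low-variance floor \(S e^{\mu+\delta}\), which dominates once \(S e^{\delta}\) is comparable to \(W-S\). Evict the sinks (so ordinary tokens fill all \(W\) slots) and \(D\) drops to \(e^{\mu}W\): the normalization then rests entirely on the noisy recent logits, making the distribution more sensitive to small logit changes.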
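How large is that floor? Dividing the sink terms by \(D\) gives the sink share in closed form, using only the toy model's definitions:
\[ \frac{S e^{\mu+\delta}}{D} \;=\; \frac{S e^{\delta}}{S e^{\delta} + (W-S)}, \]
which exceeds one half exactly when
\[ S e^{\delta} > W - S \quad\Longleftrightarrow\quad \delta > \ln\frac{W-S}{S}. \]
For \(W=512\) and \(S=4\) the threshold is \(\ln 127 \approx 4.84\); the lift needed for the sinks to dominate grows only logarithmically with the window.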
2.3 Sensitivity and stability
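The softmax Jacobian gives \[ \frac{\partial \alpha_{t,j}}{\partial z_{t,k}} = \alpha_{t,j}\left(\mathbb{1}[j=k]-\alpha_{t,k}\right), \] where \(\mathbb{1}[j=k]\) is 1 when \(j=k\) and 0 otherwise (written this way to avoid a clash with the sink lift \(\delta\)). Every entry is bounded in magnitude by \(\alpha_{t,j}\). A larger \(D\) (thanks to sinks) keeps each recent token's weight small, so logit perturbations move those weights only a little, damping error propagation through layers and time. Intuitively: sinks keep attention “cooler” when evidence is weak, because spare mass goes to the sinks instead of being scattered across unrelated recent tokens.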
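A quick numerical check in pure Python (the toy numbers \(W=64\), \(S=4\), \(\delta=5\) are assumptions, not measurements from a real model): the analytic Jacobian matches a finite difference, every entry is bounded by \(\alpha_{t,j}\), and a recent key's weight is much less sensitive to its own logit when sinks are present.

```python
import math

def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def jacobian(alpha):
    """J[j][k] = d alpha_j / d z_k = alpha_j * (1[j == k] - alpha_k)."""
    n = len(alpha)
    return [[alpha[j] * ((1.0 if j == k else 0.0) - alpha[k]) for k in range(n)]
            for j in range(n)]

# 1) Analytic Jacobian vs. a forward finite difference in logit z_1.
z = [0.3, -1.0, 2.0]
alpha = softmax(z)
J = jacobian(alpha)
eps = 1e-6
bumped = list(z)
bumped[1] += eps
fd_column = [(a - b) / eps for a, b in zip(softmax(bumped), alpha)]

# 2) Every entry is bounded in magnitude by alpha_j.
bounded = all(abs(J[j][k]) <= alpha[j] for j in range(3) for k in range(3))

# 3) Sensitivity of the most recent key's weight to its own logit:
#    4 sinks lifted by delta = 5 versus a window of ordinary keys only.
S, W, delta = 4, 64, 5.0
with_sinks = softmax([delta] * S + [0.0] * (W - S))
without_sinks = softmax([0.0] * W)
sens_with = with_sinks[-1] * (1.0 - with_sinks[-1])
sens_without = without_sinks[-1] * (1.0 - without_sinks[-1])
print(bounded, sens_with, sens_without)
```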
3) Why naïve sliding windows collapse
In a sliding window of width \(W\), you retain only the most recent \(W\) tokens. Early sinks eventually fall off the left edge. Two compounding issues follow:
- Normalization shock. Without the sink floor, \(D\) is smaller and more volatile; softmax behaves “hotter,” so head weights swing too much between unrelated keys.
- Value over-mixing. Unstable attention blends unrelated \(V_i\) vectors; across layers and steps, the noise compounds into gibberish.
Field note: if you plot attention during a long, low-information stretch, you can watch heads latching onto those first few tokens. Evict them and the plot turns chaotic almost immediately.
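A toy simulation of the normalization shock (assumptions: each window slot is treated as a fixed key whose logit is redrawn as Gaussian noise every step, sinks sit at a fixed lift \(\delta\), and no real model is involved). It reports the mean total-variation distance between consecutive attention distributions with the sinks kept versus evicted; mean_step_drift is a hypothetical helper.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def mean_step_drift(W, S, delta, sigma, steps, keep_sinks, seed=0):
    """Average total-variation distance between attention at consecutive steps.

    Toy assumptions: recent logits are redrawn every step as N(0, sigma^2)
    noise (a stand-in for low-information queries). With keep_sinks, the
    first S slots carry a fixed logit delta; otherwise all W slots are
    ordinary noisy keys, as after a naive sliding window evicts the sinks.
    """
    rng = random.Random(seed)
    n_sinks = S if keep_sinks else 0
    prev, total_drift = None, 0.0
    for _ in range(steps):
        recent = [rng.gauss(0.0, sigma) for _ in range(W - n_sinks)]
        alpha = softmax([delta] * n_sinks + recent)
        if prev is not None:
            total_drift += 0.5 * sum(abs(a - b) for a, b in zip(alpha, prev))
        prev = alpha
    return total_drift / (steps - 1)

kept = mean_step_drift(W=64, S=4, delta=5.0, sigma=1.0, steps=200, keep_sinks=True)
evicted = mean_step_drift(W=64, S=4, delta=5.0, sigma=1.0, steps=200, keep_sinks=False)
print(f"mean drift with sinks: {kept:.3f}, after eviction: {evicted:.3f}")
```

The absolute numbers depend entirely on the assumed noise level and lift; the point is only the direction of the gap.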
4) A tiny fix that works: pin a few early tokens
The minimal, model-agnostic remedy is a cache policy:
- Pin the first \(S\) tokens permanently in the KV cache (e.g., \(S=4\)).
- Slide a window of size \(W-S\) over the rest.
Formally, with time \(t\):
\[ \text{KV\_cache}(t) \;=\; \{(K_j,V_j)\}_{j=1}^{S} \;\cup\; \{(K_j,V_j)\}_{j=t-(W-S)+1}^{t}. \]
This preserves the stabilizing denominator floor while keeping memory \(\mathcal{O}(W)\) and compute per token \(\mathcal{O}(W\cdot d)\).
5) Choosing S and W: rules of thumb
- \(S\) (pinned sinks): start with \(S=4\) if your model wasn’t trained with a dedicated sink token. If your pretraining inserted a special start token designed as a sink, try \(S=1\).
- \(W\) (total window): pick for latency/VRAM. A good starting point is the largest window your target hardware can serve at your batch size while staying under roughly 95% memory utilization.
- Heuristic check: if attention heatmaps show several heads consistently attending to positions 1–4, keep \(S\ge 4\). If most mass concentrates on the first token, you can experiment with \(S=1\) or \(2\).
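One way to turn “pick for latency/VRAM” into a number, as a back-of-the-envelope sketch: per position, the KV cache stores a key and a value vector for every layer and KV head, so its size is 2 × layers × KV heads × head dim × bytes per element. The model shape, batch size, 16-position page size, and 16 GiB budget below are hypothetical, and the budget should already exclude weights, activations, and whatever headroom you want to keep.

```python
def kv_bytes_per_position(n_layers, n_kv_heads, head_dim, bytes_per_elem):
    """Bytes needed to cache one position's keys and values across all layers."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem  # 2 = K and V

def max_window(kv_budget_bytes, batch_size, n_layers, n_kv_heads, head_dim,
               bytes_per_elem=2, page_size=16):
    """Largest W (sinks included) that fits the KV budget, rounded down to whole pages."""
    per_pos = kv_bytes_per_position(n_layers, n_kv_heads, head_dim, bytes_per_elem)
    w = kv_budget_bytes // (per_pos * batch_size)
    return (w // page_size) * page_size

# Hypothetical model: 32 layers, 8 KV heads of dim 128, fp16 cache (2 bytes),
# batch of 8 sequences, 16 GiB left for the KV cache.
W = max_window(16 * 2**30, batch_size=8, n_layers=32, n_kv_heads=8, head_dim=128)
print(W)
```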
6) Implementation (step-by-step)
6.1 Index policy
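# t: current step (1-indexed for clarity)
# W: total KV window (including sinks)
# S: number of pinned early tokens
def sink_aware_kv_indices(t, W, S):
    """Return the 1-indexed positions to keep in the KV cache at step t."""
    assert 1 <= S < W, "W must leave room for at least one recent token"
    # Pinned sinks: positions 1..S (only those that exist yet when t < S).
    sinks = list(range(1, min(S, t) + 1))
    # Recent window: the last W - S positions, never overlapping the sinks.
    start_recent = max(S + 1, t - (W - S) + 1)
    recent = list(range(start_recent, t + 1))
    return sinks + recent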
6.2 Integrating with your cache
- When you append new \(K_t, V_t\), compute keep with sink_aware_kv_indices and evict keys/values not in it.
- With paged KV caches, pin the page(s) that hold indices \(1..S\); page the rest normally.
- Under model parallelism, ensure every shard pins the same logical positions. If you shard by sequence (context parallelism), pin on the rank that “owns” \(1..S\); if you shard by heads or hidden dimension (tensor parallelism), every rank keeps the same indices.
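A toy sketch of the integration step: a single-head cache (the class name SinkAwareKVCache is hypothetical) that, on each append, recomputes the kept positions with the same rule as sink_aware_kv_indices and evicts everything else. A real cache would hold per-layer, per-head tensors, but because the policy selects token positions, one rule serves every head.

```python
class SinkAwareKVCache:
    """Toy KV cache for one head that applies the pin-S / slide-(W - S) policy.

    Positions are 1-indexed. Keys and values are opaque objects here.
    """

    def __init__(self, W, S):
        assert 1 <= S < W, "W must leave room for at least one recent token"
        self.W, self.S = W, S
        self.t = 0
        self.entries = {}  # position -> (key, value)

    def keep_indices(self):
        """Positions the policy retains at the current step."""
        sinks = range(1, min(self.S, self.t) + 1)
        start_recent = max(self.S + 1, self.t - (self.W - self.S) + 1)
        return set(sinks) | set(range(start_recent, self.t + 1))

    def append(self, key, value):
        """Store K_t, V_t for the next step, then evict positions outside keep."""
        self.t += 1
        self.entries[self.t] = (key, value)
        keep = self.keep_indices()
        for pos in [p for p in self.entries if p not in keep]:
            del self.entries[pos]

    def positions(self):
        return sorted(self.entries)


cache = SinkAwareKVCache(W=8, S=2)
for step in range(1, 21):
    cache.append(f"k{step}", f"v{step}")
print(cache.positions())
```

Recomputing the full keep set costs \(\mathcal{O}(W)\) per step, which is fine for a sketch; an optimized cache would evict only the one position that just slid out of the recent window.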
6.3 Mixed precision & quantization
- Pinning stabilizes activations, often reducing outliers, which is handy if you’re flirting with 4-bit KV.
- If needed, keep the keys and values of the first \(S\) tokens in a safer, higher-precision format (e.g., fp16) while the rest run in int4.
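A toy illustration of that mixed-precision split, assuming a simple symmetric per-vector int4 scheme (integers in [-8, 7] plus one float scale per vector). Production int4 KV formats differ, Python floats stand in for fp16 here, and only value vectors are shown; keys would be handled the same way. All function names are hypothetical.

```python
def quantize_int4(vec):
    """Symmetric per-vector int4: integers in [-8, 7] plus one float scale."""
    peak = max(abs(x) for x in vec)
    scale = peak / 7 if peak > 0 else 1.0
    q = [max(-8, min(7, round(x / scale))) for x in vec]
    return q, scale

def dequantize_int4(q, scale):
    return [qi * scale for qi in q]

def store_values(values, S):
    """Keep the first S value vectors at full precision; quantize the rest to int4."""
    stored = []
    for pos, v in enumerate(values, start=1):
        if pos <= S:
            stored.append(("fp", list(v)))
        else:
            stored.append(("int4", quantize_int4(v)))
    return stored

def load_values(stored):
    return [data if kind == "fp" else dequantize_int4(*data) for kind, data in stored]

values = [[0.5, -1.2, 3.0], [0.1, 0.2, -0.3], [1.0, -0.9, 0.05]]
restored = load_values(store_values(values, S=1))
print(restored)
```

The sink row comes back exactly, while each quantized entry is off by at most half a quantization step.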
6.4 Telemetry you should log
- Per-head attention mass on indices \(1..S\) over time.
- Perplexity vs. position (watch for post-eviction spikes in the control baseline).
- Activation max/percentiles before/after the policy change.
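A minimal logging sketch for the first telemetry item, assuming you can read each head's attention weights for the current query as a list ordered by cache position (sinks first, as the pinned policy guarantees once \(t \ge S\)). SinkMassLog and sink_mass_per_head are hypothetical names, not part of any framework.

```python
from collections import defaultdict

def sink_mass_per_head(attn, S):
    """attn[h] holds head h's weights over cached keys in position order.

    Returns, for every head, the total mass on the first S cached keys.
    """
    return [sum(weights[:S]) for weights in attn]

class SinkMassLog:
    """Record per-head sink mass at every decoding step and summarize it."""

    def __init__(self):
        self.history = defaultdict(list)  # head -> [mass at step 1, step 2, ...]

    def record(self, attn, S):
        for h, mass in enumerate(sink_mass_per_head(attn, S)):
            self.history[h].append(mass)

    def mean(self, h):
        xs = self.history[h]
        return sum(xs) / len(xs)

log = SinkMassLog()
log.record([[0.7, 0.1, 0.1, 0.1], [0.25, 0.25, 0.25, 0.25]], S=1)
log.record([[0.6, 0.2, 0.1, 0.1], [0.1, 0.3, 0.3, 0.3]], S=1)
print(log.mean(0), log.mean(1))
```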
7) A worked example (thought experiment)
Suppose a head’s logits among recent tokens hover around \(\mu\), with the pinned first four tokens around \(\mu+\delta\) where \(\delta=0.4\). With \(W=512\), \(S=4\):
\[ D_{\text{with sinks}} = e^{\mu}\left(4e^{0.4}+508\right) \approx e^{\mu}(4\cdot1.49 + 508) \approx e^{\mu}\cdot 514. \]
If you drop the sinks, \(D = e^{\mu}\cdot 512\). That looks similar, and at \(\delta=0.4\) the sinks carry only about 1.2% of the mass (\(5.97/514\)). But the gap widens quickly as \(\delta\) grows, the effect compounds across many heads and layers, and, crucially, the pinned terms are stable over time. In noisy stretches, that stability is the difference between smooth continuation and a runaway drift.
The practical lesson: you don’t need a big \(S\); you need a consistent one.
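The worked example's arithmetic in a few lines of Python (same toy model; the function names are illustrative). Sweeping \(\delta\) shows how quickly the sinks' share of the budget grows: at \(\delta=5\) they already hold just over half of it.

```python
import math

def denominator_ratio(W, S, delta):
    """D / e^mu for S sinks lifted by delta plus W - S ordinary keys (toy model)."""
    return S * math.exp(delta) + (W - S)

def sink_share(W, S, delta):
    """Fraction of one head's attention budget that lands on the sinks."""
    return S * math.exp(delta) / denominator_ratio(W, S, delta)

W, S = 512, 4
for delta in (0.4, 2.0, 5.0):
    print(f"delta={delta}: D/e^mu={denominator_ratio(W, S, delta):.1f}, "
          f"sink share={sink_share(W, S, delta):.3f}")
```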
8) Debugging & diagnostics
8.1 How to tell you’ve found sinks
- On low-information tokens (e.g., punctuation, fillers), several heads place non-trivial mass on positions 1–\(S\).
- The pattern persists across prompts and layers (though magnitudes differ).
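The two criteria above can be turned into a rough detector: on low-information query steps, compare each head's mass on positions \(1..S\) with what uniform attention would give (\(S/m\) over \(m\) keys). The function name and the 4× threshold are hypothetical knobs to tune, not established values.

```python
def find_sink_heads(attn_on_dull_steps, S, ratio=4.0):
    """Flag heads whose mass on positions 1..S is well above uniform.

    attn_on_dull_steps[h] is a list of attention vectors for head h, one per
    low-information query step, each over the cached keys in position order.
    Uniform attention over m keys would put S / m on positions 1..S.
    """
    flagged = []
    for h, steps in enumerate(attn_on_dull_steps):
        excess = [sum(w[:S]) / (S / len(w)) for w in steps]
        if sum(excess) / len(excess) >= ratio:
            flagged.append(h)
    return flagged

head0 = [[0.6] + [0.4 / 7] * 7]  # heavy mass on position 1
head1 = [[0.125] * 8]            # perfectly uniform
print(find_sink_heads([head0, head1], S=1))
```

Running the same check across several prompts and layers covers the second criterion: a genuine sink head should be flagged consistently.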
8.2 When things still go sideways
- Collapse after minutes: increase \(S\) from 2→4, or bump \(W\) by 64–128.
- Latency spikes: your eviction is too granular. Switch to page-aligned eviction and pin whole pages that include 1–\(S\).
- Quantization artifacts: keep the sink values at a safer precision or recalibrate the quantizer with sink-heavy traces.
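For the page-aligned fix, a sketch of the same policy at page granularity (assumed layout: 1-indexed position p lives in 0-indexed page (p - 1) // page_size; the function name is hypothetical). Keeping whole pages trades a little extra memory for coarse, cheap eviction.

```python
def sink_aware_pages(t, W, S, page_size):
    """Page-aligned variant of the pin-S / slide-(W - S) policy.

    Returns the sorted 0-indexed page ids to keep at step t. Because whole
    pages are kept, the pinned page(s) may also hold some non-sink positions,
    and the oldest recent page may hold positions that already left the window.
    """
    assert 1 <= S < W and page_size >= 1
    # Every page touching positions 1..S is pinned.
    sink_pages = set(range((min(S, t) - 1) // page_size + 1))
    # Recent pages: from the page holding the window's first position to the current page.
    start_recent = max(S + 1, t - (W - S) + 1)
    first_recent_page = (start_recent - 1) // page_size
    last_page = (t - 1) // page_size
    recent_pages = set(range(first_recent_page, last_page + 1))
    return sorted(sink_pages | recent_pages)

print(sink_aware_pages(100, W=32, S=4, page_size=16))
```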
9) Variations you may encounter
- Dedicated sink token: Some training setups prepend a special token whose learned key/value acts as a clean reservoir, letting you use \(S=1\) at inference.
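- Per-head scalar “sink”: An architectural variant adds a learnable scalar per head to capture spare probability mass. It is parameter-cheap and can be added to an existing architecture, though the scalars still have to be learned through training or fine-tuning.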
- Alternative normalizations: Methods that relax the sum-to-one constraint can reduce sink formation, but they change the model’s behavior and training dynamics.
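A sketch of the per-head scalar variant, assuming the common formulation where \(e^{b_h}\) is added to the softmax denominator (exact placement varies by architecture). With a very negative \(b_h\) it reduces to ordinary softmax; otherwise the weights sum to less than 1 and the sink absorbs the rest.

```python
import math

def softmax_with_sink(logits, sink_logit):
    """Softmax whose denominator also includes exp(sink_logit).

    sink_logit plays the role of a learned per-head scalar b_h. The returned
    weights sum to less than 1; the remainder is the mass the sink absorbs,
    which in this sketch is attached to no value vector.
    """
    m = max(max(logits), sink_logit)  # shift for numerical stability
    exps = [math.exp(z - m) for z in logits]
    denom = sum(exps) + math.exp(sink_logit - m)
    return [e / denom for e in exps]

# Three equally uninformative keys and a sink logit of ln 3:
# the sink absorbs half of the budget.
weights = softmax_with_sink([0.0, 0.0, 0.0], sink_logit=math.log(3.0))
print(sum(weights))
```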
10) FAQ
Q1: Why not just increase the window?
You should, if you can. But even with a big \(W\), evicting the wrong early tokens can still trigger instability. Pinning a handful costs almost nothing and protects you when memory pressure forces a smaller \(W\).
Q2: Is this equivalent to adding BOS every time?
No. The model’s learned distribution uses specific early positions as sinks. Re-injecting BOS later doesn’t replicate their learned key/value statistics.
Q3: Do all heads use the same sinks?
Not exactly. Some heads use them heavily, others barely. Pinning the first \(S\) positions still covers every head: the policy selects token positions, and every layer and head stores its own keys and values for those same positions.
11) Deployment checklist
- Pick \(W\) for your latency/VRAM target; start \(S=4\) if unsure.
- Pin indices \(1..S\) in your KV cache; evict others page-wise.
- Log sink mass per head; verify stability on dull segments.
- Quantization? Keep \(1..S\) at higher precision if needed.
- Load test: run long (hours-long), low-entropy transcripts and confirm nothing collapses.
Note: The parts in white above are my own insights and reflections :D. I reshaped this piece with the assistance of large language models.