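A KV cache is an inference-time optimization for transformer models during autoregressive text generation. It stores the key (K) and value (V) vectors of every token already processed. Each new step then computes K and V only for the newest token and reuses the cached ones, instead of recomputing them for the whole prefix. It trades extra memory for less computation.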
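The idea can be sketched in pure Python for a single attention head. This is a toy under stated assumptions, not how any particular library implements it:

- One layer, one head, no batching.
- Dimension `D = 4`.
- Random projection weights and random token embeddings.
- Tokens are fed in directly rather than sampled from the model.

All names (`KVCache`, `decode_step`, `no_cache_step`) are hypothetical. Each `decode_step` projects only the newly fed token, appends its K and V to the cache, and attends over everything cached so far. This mirrors the "The cat sat on" example.

```python
import math
import random

D = 4  # toy head dimension (hypothetical)


def rand_matrix(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, keys, values):
    """Scaled dot-product attention of one query over the given keys/values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]  # numerically stable softmax
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


class KVCache:
    """Stores one K vector and one V vector per processed token."""

    def __init__(self):
        self.keys = []
        self.values = []

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)


def decode_step(x, cache, Wq, Wk, Wv):
    """Cached decoding: project only the new token, then attend over the cache."""
    q = matvec(Wq, x)
    cache.append(matvec(Wk, x), matvec(Wv, x))  # K_t, V_t computed once, reused later
    return attend(q, cache.keys, cache.values)


def no_cache_step(prefix, Wq, Wk, Wv):
    """Uncached baseline: recompute K and V for every token in the prefix."""
    keys = [matvec(Wk, x) for x in prefix]
    values = [matvec(Wv, x) for x in prefix]
    return attend(matvec(Wq, prefix[-1]), keys, values)


random.seed(0)
Wq, Wk, Wv = (rand_matrix(D, D) for _ in range(3))
tokens = ["The", "cat", "sat"]
embed = {t: [random.uniform(-1, 1) for _ in range(D)] for t in tokens}  # toy embeddings

cache = KVCache()
outputs = []
for step, tok in enumerate(tokens, start=1):
    outputs.append(decode_step(embed[tok], cache, Wq, Wk, Wv))
    print(f"Step {step}: fed {tok!r}, cache now holds {len(cache.keys)} K/V pairs")
```

By step 3 the cache already holds K₁, V₁ and K₂, V₂, so only K₃, V₃ are computed. `no_cache_step` shows what the cache avoids: it reprojects the whole prefix each time. Both paths produce the same attention output.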
Example: generating "The cat sat on"
Step 1: Generate "cat"
- Input: "The"
- Compute K₁, V₁ for "The"
- Cache: K=[K₁], V=[V₁]
Step 2: Generate "sat"
- Input: "cat" (context so far: "The cat")
- Compute K₂, V₂ for "cat"
- Cache: K=[K₁, K₂], V=[V₁, V₂]
- Reuse K₁, V₁ (no recomputation!)
Step 3: Generate "on"
- Input: "sat" (context so far: "The cat sat")
- Compute K₃, V₃ for "sat"
- Cache: K=[K₁, K₂, K₃], V=[V₁, V₂, V₃]
- Reuse K₁, V₁, K₂, V₂ (no recomputation!)

Per-Channel Pruning

Definition: For each channel (column), selectively remove some token entries.
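How it works: Treat each channel as a column over all cached tokens, and decide entry by entry which tokens to keep in that column. Every channel gets its own keep/remove mask, so a token can be kept in one channel and removed in another.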
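A minimal pure-Python sketch of the column-wise decision follows. It rests on two assumptions:

- **Selection rule.** Keep the `keep` largest-magnitude token entries in each channel. The text does not say how entries are chosen, so this rule is illustrative only.
- **Helpers and data.** `prune_per_channel` and `show` are hypothetical helpers, and the matrix is a toy 4-token × 6-channel slice of a key cache (rows = tokens, columns = channels).

```python
def prune_per_channel(matrix, keep):
    """Per-channel pruning sketch: within each channel (column), keep the `keep`
    token entries with the largest magnitude and mark the rest as pruned (None).

    `matrix` is a list of token rows, each a list of channel values.
    The magnitude criterion is an illustrative assumption.
    """
    n_tokens, n_channels = len(matrix), len(matrix[0])
    pruned = [[None] * n_channels for _ in range(n_tokens)]
    for c in range(n_channels):
        # Rank the tokens of this channel by |value|; each channel gets its own ranking.
        ranked = sorted(range(n_tokens), key=lambda t: abs(matrix[t][c]), reverse=True)
        for t in ranked[:keep]:
            pruned[t][c] = matrix[t][c]
    return pruned


def show(matrix):
    """Print rows in the Token1: [...] style, with '-' for pruned entries."""
    for t, row in enumerate(matrix, start=1):
        cells = ", ".join("-" if x is None else f"{x:+.1f}" for x in row)
        print(f"Token{t}: [{cells}]")


# Toy 4-token x 6-channel slice of a key cache (hypothetical values).
K = [
    [0.9, -0.1, 0.5, 0.7, -0.2, 0.8],
    [0.4, 0.6, -0.3, 0.1, 0.9, -0.5],
    [-0.2, 0.8, 0.6, 0.3, 0.7, 0.1],
    [0.7, 0.2, -0.9, 0.5, 0.3, 0.6],
]
pruned = prune_per_channel(K, keep=3)
show(pruned)
```

Every channel keeps exactly three of its four token entries. Token4 nevertheless survives in all six channels, while the other tokens keep four each, because the pattern is decided column by column. A fixed `keep` budget is a simplification: the hand-drawn example keeps two or three entries per channel.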
Example:
Original:

| Token | Ch1 | Ch2 | Ch3 | Ch4 | Ch5 | Ch6 |
|---|---|---|---|---|---|---|
| Token1 | a | b | c | d | e | f |
| Token2 | g | h | i | j | k | l |
| Token3 | m | n | o | p | q | r |
| Token4 | s | t | u | v | w | x |

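After Per-Channel Pruning:

| Token | Ch1 | Ch2 | Ch3 | Ch4 | Ch5 | Ch6 |
|---|---|---|---|---|---|---|
| Token1 | a | - | c | d | - | f |
| Token2 | g | h | - | - | k | l |
| Token3 | - | n | o | p | q | - |
| Token4 | s | - | u | v | - | x |
| Kept (per channel) | 3 of 4 (75%) | 2 of 4 (50%) | 3 of 4 (75%) | 3 of 4 (75%) | 2 of 4 (50%) | 3 of 4 (75%) |

Per-Token Pruning

Definition: For each token (row), selectively remove some channel entries.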
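How it works: Treat each token as a row over all channels, and decide entry by entry which channels to keep for that token. Every token gets its own keep/remove mask, so different tokens can keep different channels, and even different numbers of channels.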
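A minimal pure-Python sketch of the row-wise decision follows. It rests on two assumptions:

- **Selection rule.** Keep the `keep` largest-magnitude channel entries for each token. The text does not specify a criterion, so this rule is illustrative only.
- **Helpers and data.** `prune_per_token` and `show` are hypothetical helpers, and the input is a toy 4-token × 6-channel matrix with made-up values.

```python
def prune_per_token(matrix, keep):
    """Per-token pruning sketch: within each token (row), keep the `keep`
    channel entries with the largest magnitude and mark the rest as pruned (None).

    The magnitude criterion is an illustrative assumption.
    """
    pruned = []
    for row in matrix:
        # Rank this token's channels by |value|; each token gets its own ranking.
        ranked = sorted(range(len(row)), key=lambda c: abs(row[c]), reverse=True)
        kept = set(ranked[:keep])
        pruned.append([x if c in kept else None for c, x in enumerate(row)])
    return pruned


def show(matrix):
    """Print rows in the Token1: [...] style, with '-' for pruned entries."""
    for t, row in enumerate(matrix, start=1):
        cells = ", ".join("-" if x is None else f"{x:+.1f}" for x in row)
        print(f"Token{t}: [{cells}]")


# Toy 4-token x 6-channel slice of a key cache (hypothetical values).
K = [
    [0.9, -0.1, 0.5, 0.7, -0.2, 0.8],
    [0.4, 0.6, -0.3, 0.1, 0.9, -0.5],
    [-0.2, 0.8, 0.6, 0.3, 0.7, 0.1],
    [0.7, 0.2, -0.9, 0.5, 0.3, 0.6],
]
pruned = prune_per_token(K, keep=4)
show(pruned)
```

Every token keeps exactly four of its six channels. Because each token picks different channels, the per-channel counts come out uneven (two or three of four). A fixed `keep` budget is again a simplification: the hand-drawn example lets Token3 keep five entries.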
Example:
Original:

| Token | Ch1 | Ch2 | Ch3 | Ch4 | Ch5 | Ch6 |
|---|---|---|---|---|---|---|
| Token1 | a | b | c | d | e | f |
| Token2 | g | h | i | j | k | l |
| Token3 | m | n | o | p | q | r |
| Token4 | s | t | u | v | w | x |

After Per-Token Pruning:

| Token | Ch1 | Ch2 | Ch3 | Ch4 | Ch5 | Ch6 | Kept (per token) |
|---|---|---|---|---|---|---|---|
| Token1 | a | b | - | - | e | f | 4 of 6 (67%) |
| Token2 | g | - | i | j | - | l | 4 of 6 (67%) |
| Token3 | m | n | - | p | q | r | 5 of 6 (83%) |
| Token4 | - | t | u | - | w | x | 4 of 6 (67%) |

| Aspect | Per-Channel Pruning | Per-Token Pruning |
|---|---|---|
| Direction | Vertical (across tokens) | Horizontal (across channels) |
| Unit | Channel vector | Token vector |
| Sparsity Pattern | Different for each channel | Different for each token |
| What's Removed | Token entries within channels | Channel entries within tokens |