We’re examining how Llama models manage time and memory inside attention. The core implementation lives in llama/model.py from the Meta Llama codebase—a compact Transformer that wires together rotary embeddings and a KV cache to make long‑context inference practical. I’m Mahmoud Zalt, an AI solutions architect, and we’ll unpack how this file turns raw tensors into an efficient, time‑aware attention pipeline you can reuse in your own systems.
Our goal is to build a precise mental model for Llama’s attention path—how a token flows from embedding to logits, how its position is encoded with RoPE, and how the KV cache lets the model remember thousands of tokens without recomputing history.
The Core Transformer File
The llama/model.py file defines the full Llama Transformer as it runs at inference time. It contains configuration, normalization, rotary positional embeddings, attention, feed‑forward layers, and the stacked Transformer module that produces logits.
Project: meta-llama/llama
llama/
├── __init__.py
├── model.py <-- core Transformer definition
├── tokenizer.py
├── train.py / serve.py
└── ...
Call graph (simplified):
Transformer.forward
├─ tok_embeddings(tokens)
├─ freqs_cis slice (RoPE table)
├─ build causal mask
├─ for each TransformerBlock:
│ └─ Attention + FeedForward
├─ norm(h)
└─ output(h) -> logits
The main components of llama/model.py we care about when we talk about time and memory are:
- ModelArgs: configuration dataclass, including KV cache limits.
- precompute_freqs_cis and apply_rotary_emb: rotary positional embedding pipeline.
- Attention: multi‑head attention with grouped queries and a KV cache.
- TransformerBlock: pre‑norm attention + feed‑forward with residuals.
- Transformer: token embeddings, stack of blocks, final norm + projection.
Encoding Time with Rotary Embeddings
Llama does not add positional vectors to token embeddings. Instead, it uses rotary positional embeddings (RoPE) to encode position directly into the geometry of the query and key vectors. Time becomes a rotation, not an extra feature.
Configuration: Bounding How Far Back We Remember
The ModelArgs dataclass captures both architecture and cache limits:
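from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None  # None means one KV head per query head
    vocab_size: int = -1  # set by tokenizer
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    max_batch_size: int = 32  # first dimension of the KV cache
    max_seq_len: int = 2048  # second dimension of the KV cache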
max_batch_size and max_seq_len are the hard limits of the model’s "memory" during generation. They set the size of the KV cache per layer and therefore cap how many tokens you can remember per request without reallocation.
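To make those limits concrete, here is a rough sizing sketch. It assumes one key tensor and one value tensor per layer, each shaped [max_batch_size, max_seq_len, n_kv_heads, head_dim], 2 bytes per element (fp16 or bf16), and no model parallelism. The helper name kv_cache_bytes and the n_kv_heads=8 setting are hypothetical, chosen only for illustration.

```python
from typing import Optional


def kv_cache_bytes(
    dim: int = 4096,
    n_layers: int = 32,
    n_heads: int = 32,
    n_kv_heads: Optional[int] = None,
    max_batch_size: int = 32,
    max_seq_len: int = 2048,
    bytes_per_element: int = 2,  # assumption: fp16 or bf16 cache
) -> int:
    """Estimate the total KV cache size in bytes across all layers."""
    kv_heads = n_heads if n_kv_heads is None else n_kv_heads
    head_dim = dim // n_heads
    per_tensor = max_batch_size * max_seq_len * kv_heads * head_dim
    # Two tensors (keys and values) per layer.
    return 2 * n_layers * per_tensor * bytes_per_element


GIB = 1024 ** 3
full = kv_cache_bytes()                 # the defaults shown above
grouped = kv_cache_bytes(n_kv_heads=8)  # hypothetical grouped-query setting
print(full / GIB, grouped / GIB)        # prints 32.0 8.0
```

Under these assumptions the default limits reserve about 32 GiB for the cache alone, which is why max_batch_size and max_seq_len are worth tuning to the actual workload rather than leaving at their defaults.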
Precomputing Time as Complex Phases
RoPE is implemented via complex exponentials. The function precompute_freqs_cis builds a table of unit complex numbers—one for each position and frequency—up to a configured maximum sequence length:
import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # One rotation frequency per pair of feature dimensions.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    # Angle for every (position, frequency) pair, then unit complex numbers.
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
Conceptually, this creates a matrix where each row is a position index and each column is a rotation frequency. Each entry is a point on the complex unit circle whose angle grows linearly with position.
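The same table can be sketched in plain Python to inspect its structure without PyTorch. This is an illustrative analogue, not the file's implementation: the name freqs_cis_table is hypothetical, and it uses nested lists of Python complex numbers instead of a complex64 tensor.

```python
import cmath
import math


def freqs_cis_table(dim: int, end: int, theta: float = 10000.0):
    """Rows are positions, columns are rotation frequencies."""
    freqs = [1.0 / (theta ** (i / dim)) for i in range(0, dim, 2)]
    # cmath.rect(1.0, angle) is the unit complex number with that angle.
    return [[cmath.rect(1.0, pos * f) for f in freqs] for pos in range(end)]


table = freqs_cis_table(dim=8, end=4)
print(len(table), len(table[0]))  # prints 4 4

# Every entry sits on the unit circle.
print(all(math.isclose(abs(z), 1.0) for row in table for z in row))  # prints True

# The first column has frequency 1.0, so its angle equals the position.
print(round(cmath.phase(table[2][0]), 6))  # prints 2.0
```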
Rotating Queries and Keys
When attention runs, Llama transforms queries and keys into complex pairs, multiplies them by the precomputed phases for the current positions, and converts them back to real tensors. That’s handled by apply_rotary_emb:
from typing import Tuple

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # View consecutive feature pairs as complex numbers.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Complex multiplication rotates each pair; flatten back to real features.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
The helper reshape_for_broadcast lines up freqs_cis with the batch, sequence, head, and feature dimensions, and asserts that the shapes match. The key property here is that rotations are norm‑preserving: Q and K magnitudes stay the same, but their directions rotate in a position‑dependent way. Relative position becomes relative angle between Q and K.
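Both properties can be checked numerically on a toy example. The sketch below is plain Python rather than the file's tensor code; the helper names rotate and score and the two toy frequencies are assumptions made for illustration.

```python
import cmath
import math


def rotate(pairs, pos, freqs):
    """Rotate each complex feature pair by the angle pos * freq."""
    return [z * cmath.rect(1.0, pos * f) for z, f in zip(pairs, freqs)]


def score(q, k):
    """Real dot product of two vectors stored as complex pairs."""
    return sum((a * b.conjugate()).real for a, b in zip(q, k))


freqs = [1.0, 0.1]  # toy frequencies, one per feature pair
q = [complex(1.0, 2.0), complex(-0.5, 0.3)]
k = [complex(0.7, -1.0), complex(2.0, 0.4)]

# Rotation preserves the magnitude of every pair.
q_rot = rotate(q, 5, freqs)
print(all(math.isclose(abs(a), abs(b)) for a, b in zip(q, q_rot)))  # prints True

# The same offset (2) at different absolute positions gives the same score.
s1 = score(rotate(q, 3, freqs), rotate(k, 1, freqs))
s2 = score(rotate(q, 10, freqs), rotate(k, 8, freqs))
print(math.isclose(s1, s2))  # prints True
```

The second check is the practical payoff: the attention score between a query and a key depends on how far apart they are, not on where they sit in the sequence.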
KV Cache: Remembering the Past Efficiently
RoPE tells us how a single position is represented. The KV cache explains how the model keeps all previous positions around without recomputing them at every step. Instead of regenerating keys and values for the entire prefix, Llama stores them once and appends as new tokens arrive.
The Notebook Analogy
A useful way to think about the KV cache is a growing notebook per layer and per head. For each batch element, every time you process a new chunk of tokens, you write their keys and values to the next empty lines in the notebook. Later tokens can read the whole notebook, but you never rewrite old pages.
Allocating the Notebook
The Attention module owns that notebook. In __init__, it pre‑allocates cache tensors sized by max_batch_size and max_seq_len:
self.cache_k = torch.zeros(
    (
        args.max_batch_size,
        args.max_seq_len,
        self.n_local_kv_heads,
        self.head_dim,
    )
).cuda()
self.cache_v = torch.zeros(
    (
        args.max_batch_size,
        args.max_seq_len,
        self.n_local_kv_heads,
        self.head_dim,
    )
).cuda()
This is a deliberate trade‑off: reserve a large, fixed slab of GPU memory up front to avoid per‑request allocations and keep indexing simple ([batch, position, head, dim]).
Writing and Reading from the Cache
On each forward pass, Attention.forward computes Q, K, V for the current chunk, writes K and V into the cache at the correct offset, and then reads all history (past + current) when computing attention scores:
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
...
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
The slice start_pos : start_pos + seqlen is the new page being written; : start_pos + seqlen is the full notebook seen by the current chunk. The cache never changes shape during a run—only which part of it is filled.
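The write-then-read pattern can be mimicked with a fixed-size Python list. This is a toy model for a single sequence and a single head, with strings standing in for key vectors; the class name ToyKVCache is hypothetical.

```python
class ToyKVCache:
    """Fixed-size cache for one sequence; each slot holds one token's key."""

    def __init__(self, max_seq_len: int):
        # Pre-allocate every slot up front, as the real cache does with zeros.
        self.slots = [None] * max_seq_len

    def write_and_read(self, start_pos: int, new_keys: list) -> list:
        # Write the new chunk at its offset ...
        for i, key in enumerate(new_keys):
            self.slots[start_pos + i] = key
        # ... then read the cached prefix plus the new chunk.
        return self.slots[: start_pos + len(new_keys)]


cache = ToyKVCache(max_seq_len=8)
print(cache.write_and_read(0, ["k0", "k1", "k2"]))  # prints ['k0', 'k1', 'k2']
print(cache.write_and_read(3, ["k3"]))              # prints ['k0', 'k1', 'k2', 'k3']
print(len(cache.slots))                             # prints 8
```

The first call plays the role of processing a prompt, the second a single decode step. The storage length stays at 8 throughout; only the filled region grows.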
Grouped‑Query Attention with repeat_kv
Llama often uses fewer KV heads than query heads (n_kv_heads < n_heads) to reduce memory. This is a grouped‑query or multi‑query attention pattern, where several query heads share the same KV head group. The helper repeat_kv repeats KV heads along the head dimension:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
In our notebook analogy, this is equivalent to multiple readers sharing the same notes: you don’t create new KV entries, you just let more query heads attend to the existing ones.
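The head ordering produced by the expand-then-reshape step is easy to miss, so here is a list-based sketch of it. It models only the head axis, with strings standing in for per-head tensors; the function name repeat_kv_heads and the 8-to-2 head ratio are illustrative assumptions.

```python
def repeat_kv_heads(heads: list, n_rep: int) -> list:
    """Repeat each KV head n_rep times, keeping the copies adjacent."""
    if n_rep == 1:
        return heads
    return [head for head in heads for _ in range(n_rep)]


n_heads, n_kv_heads = 8, 2
n_rep = n_heads // n_kv_heads
print(repeat_kv_heads(["kv0", "kv1"], n_rep))
# prints ['kv0', 'kv0', 'kv0', 'kv0', 'kv1', 'kv1', 'kv1', 'kv1']
```

With this layout, query head h reads from KV head h // n_rep, so consecutive query heads form a group that shares one set of keys and values.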
Causal Masking with a Growing Cache
The Transformer module has to ensure each token only reads from the past and itself, never from the future. With a cache, the score matrix for the current chunk has shape (seqlen, cache_len + seqlen), so the causal mask needs to account for both the already‑cached prefix and the current block.
@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int):
    _bsz, seqlen = tokens.shape
    h = self.tok_embeddings(tokens)
    self.freqs_cis = self.freqs_cis.to(h.device)
    freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
    mask = None
    if seqlen > 1:
        mask = torch.full(
            (seqlen, seqlen), float("-inf"), device=tokens.device
        )
        mask = torch.triu(mask, diagonal=1)
        mask = torch.hstack([
            torch.zeros((seqlen, start_pos), device=tokens.device),
            mask,
        ]).type_as(h)
    for layer in self.layers:
        h = layer(h, start_pos, freqs_cis, mask)
The zeros on the left of mask correspond to the fully visible cached prefix; the upper‑triangular block forbids attention to future tokens within the current chunk. Combined with the KV cache, this enforces strict causality while still letting every step see the full history.
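Printing a small mask makes the two regions visible. The sketch below rebuilds the same pattern with nested Python lists; the function name causal_mask is hypothetical, and unlike the method above it also returns a mask when seqlen is 1 (a single row of zeros, which has the same effect as no mask).

```python
NEG_INF = float("-inf")


def causal_mask(seqlen: int, start_pos: int) -> list:
    """Rows are new tokens; columns are cached positions, then new ones."""
    return [
        [0.0 if col <= start_pos + row else NEG_INF
         for col in range(start_pos + seqlen)]
        for row in range(seqlen)
    ]


for row in causal_mask(seqlen=3, start_pos=2):
    print(row)
# prints
# [0.0, 0.0, 0.0, -inf, -inf]
# [0.0, 0.0, 0.0, 0.0, -inf]
# [0.0, 0.0, 0.0, 0.0, 0.0]
```

The first two columns are the cached prefix and are always visible. The remaining three columns form the triangular block, where each new token sees itself and earlier tokens in the chunk.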
Design Constraints and Refactors
Once the happy path is clear—RoPE encodes time, the cache stores history, the mask enforces causality—we can look at the pragmatic constraints the implementation introduces, and how the original report suggests tightening them up.
Device‑Agnostic Caches
In Attention.__init__, the KV caches are allocated directly on CUDA with .cuda(). That’s fine for GPU‑only deployment, but it fights model.to(device), makes CPU‑only testing awkward, and bakes a specific accelerator into your model definition.
| Aspect | Current Design | Refactored Design |
|---|---|---|
| Allocation | Ad‑hoc tensors on CUDA in __init__ | Registered buffers moved by model.to(device) |
| Portability | Tied to GPUs | Works on any PyTorch device |
| Testing | Requires CUDA hardware | CPU tests possible |
The refactor is to turn cache_k and cache_v into registered buffers and avoid hard‑coding CUDA in the constructor. In forward, you still ensure they match the device and dtype of the query tensor, but you no longer fight the framework’s device semantics.
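A minimal sketch of that refactor, assuming PyTorch is installed, might look like the following. The class CachedAttentionStub is hypothetical and shows only the allocation, not attention itself. Marking the buffers as non-persistent is an added assumption rather than something the source prescribes; it keeps per-request state out of checkpoints.

```python
import torch
from torch import nn


class CachedAttentionStub(nn.Module):
    """Hypothetical stub that only shows cache allocation, not attention."""

    def __init__(self, max_batch_size: int, max_seq_len: int,
                 n_kv_heads: int, head_dim: int):
        super().__init__()
        shape = (max_batch_size, max_seq_len, n_kv_heads, head_dim)
        # persistent=False keeps the caches out of state_dict().
        self.register_buffer("cache_k", torch.zeros(shape), persistent=False)
        self.register_buffer("cache_v", torch.zeros(shape), persistent=False)


stub = CachedAttentionStub(max_batch_size=2, max_seq_len=16,
                           n_kv_heads=4, head_dim=8)
print(list(stub.state_dict()))  # prints [] because the buffers are non-persistent
stub = stub.to(torch.float16)   # buffers follow .to(...) like parameters do
print(stub.cache_k.dtype)       # prints torch.float16
```

The same .to(...) call accepts a device, so moving the module to a GPU moves the caches with it and CPU-only tests need no special handling.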
Explicit Cache Bounds
The cache indexing relies on the caller respecting max_batch_size and max_seq_len. If you accidentally send a larger batch or longer context, you get subtle indexing bugs or shape mismatches instead of a clear error.
The suggested change is to add explicit checks in Attention.forward before writing into the cache, comparing the current batch size and start_pos + seqlen against the cache shape. That turns silent misuse into immediate, debuggable failures, without touching the core algorithm.
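One way such a guard could look is sketched below. Slice bounds past the end of a tensor are clamped rather than rejected, which is why the check has to be explicit. The function name check_cache_bounds and the error messages are hypothetical; in the real module the limits would come from the cache tensor's shape.

```python
def check_cache_bounds(bsz: int, start_pos: int, seqlen: int,
                       max_batch_size: int, max_seq_len: int) -> None:
    """Raise a descriptive error before any cache write can go out of range."""
    if bsz > max_batch_size:
        raise ValueError(
            f"batch size {bsz} exceeds max_batch_size={max_batch_size}"
        )
    if start_pos < 0:
        raise ValueError(f"start_pos must be non-negative, got {start_pos}")
    end = start_pos + seqlen
    if end > max_seq_len:
        raise ValueError(
            f"start_pos + seqlen = {end} exceeds max_seq_len={max_seq_len}"
        )


# A request within limits passes silently.
check_cache_bounds(bsz=4, start_pos=0, seqlen=128,
                   max_batch_size=32, max_seq_len=2048)

# One token past the end of the cache fails with a message that names the limit.
try:
    check_cache_bounds(bsz=4, start_pos=2048, seqlen=1,
                       max_batch_size=32, max_seq_len=2048)
except ValueError as err:
    print(err)  # prints start_pos + seqlen = 2049 exceeds max_seq_len=2048
```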
Training vs. Inference Paths
Transformer.forward is decorated with @torch.inference_mode(), which disables gradient tracking. That’s exactly what you want for serving, but it makes this method unsuitable for training.
The report’s pattern is to extract a shared _forward_impl that contains the actual computation, then keep forward as a thin, inference‑only wrapper around it. Training code calls _forward_impl inside a gradient‑enabled context. This keeps the public inference API simple, while making the execution mode explicit.
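The wrapper pattern can be sketched on a deliberately tiny module, assuming PyTorch is installed. TinyModel is a hypothetical stand-in for Transformer with a single linear layer and no KV cache.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for Transformer: one linear layer."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 2)

    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        # Shared computation; the caller decides whether gradients are tracked.
        return self.proj(x)

    @torch.inference_mode()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Public inference entry point: no autograd bookkeeping.
        return self._forward_impl(x)


model = TinyModel()
x = torch.randn(3, 4)
print(model(x).requires_grad)                # prints False
print(model._forward_impl(x).requires_grad)  # prints True
model._forward_impl(x).sum().backward()      # training-style call
print(model.proj.weight.grad is not None)    # prints True
```

In the real file the attention layers also write into the KV cache during the forward pass, so a training path would need to decide how to handle those writes. This sketch leaves that question open.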
Concurrency: One Cache per Story
The KV cache is mutable state shared across calls for a given Transformer instance. If you try to use the same model object concurrently from multiple threads or async tasks, you will interleave writes into the same cache and corrupt each sequence’s history.
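One possible way to enforce exclusive use in a threaded server is to hand out model instances from a pool, so that no two callers ever share a cache. The sketch below uses only the standard library; ModelPool and borrow are hypothetical names, and plain strings stand in for Transformer objects.

```python
import queue
from contextlib import contextmanager


class ModelPool:
    """Hypothetical pool that hands each caller an exclusive model instance."""

    def __init__(self, models):
        self._free = queue.Queue()
        for model in models:
            self._free.put(model)

    @contextmanager
    def borrow(self):
        model = self._free.get()   # blocks until an instance is free
        try:
            yield model
        finally:
            self._free.put(model)  # always return it, even on error


pool = ModelPool(["model_a", "model_b"])  # strings stand in for Transformer objects
with pool.borrow() as first, pool.borrow() as second:
    print(first != second)  # prints True: no two callers share a cache
```

This trades memory for safety, since every pooled instance carries its own cache. The blocking get() suits threads; async code would need a non-blocking equivalent.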
What to Steal for Your Own Models
Llama’s core model file shows a clean, pragmatic answer to the question this article started with: how do you let a Transformer remember thousands of tokens without drowning in computation and memory? You encode time as rotations on Q/K with RoPE, and you keep the past in a fixed‑shape KV cache that grows logically but not physically.
- Make time a geometric property. Rotary embeddings push positional information into the angles of Q and K instead of into separate positional vectors. This keeps the architecture simple and makes relative position differences intrinsic to attention scores.
- Treat the KV cache as a first‑class API concept. Pre‑allocate it, bound it with explicit config (max_batch_size, max_seq_len), guard it with assertions, and be honest about its mutability and concurrency model. The cache is not an implementation detail; it’s how the model remembers.
- Align implementation with runtime realities. Device‑agnostic buffers, clear separation between training and inference paths, and cache shapes tuned to your workload make the difference between a research model and a production system.
When you design or refactor Transformer‑style systems, start from the same questions Llama’s model.py answers: How is time represented? Where is the past stored? What are the hard limits of that storage? And how does the code make those contracts obvious to the next engineer who reads it, including you six months from now?
Once those answers are clear, you can scale sequence lengths and throughput without losing control over correctness or cost—exactly the balance Llama strikes in its treatment of time and attention.





