Zalt Blog

Deep Dives into Code & Architecture at Scale

When Transformers Learn To Listen

By Mahmoud Zalt
Code Cracking
30m read

When transformers learn to listen, they stop being just text models and become full speech partners. Curious how that shift changes what we build? 🎧

We often talk about transformers as text engines, but Whisper’s core model is a reminder that the same machinery can listen just as well as it reads. In this walkthrough, we’ll unpack how a surprisingly compact Python file wires convolutions, attention, caching, and alignment into a production‑grade speech‑to‑text brain—and what we can learn from its design.

I’m Mahmoud Zalt, and together we’ll use this file as a case study in building a clean, scalable transformer encoder–decoder that has to run fast in the wild, not just look pretty on paper.

The Model Sitting Quietly in the Middle

Before we dive into layers and tensors, it helps to see where this file lives in the bigger picture. Whisper’s model.py isn’t a CLI, a training loop, or a data loader. It’s the model layer: the core brain every other piece of the system calls into.

project-root/
  whisper/
    __init__.py
    decoding.py
    transcribe.py
    model.py   <-- defines core Whisper transformer
      - ModelDimensions
      - LayerNorm, Linear, Conv1d wrappers
      - MultiHeadAttention
      - ResidualAttentionBlock
      - AudioEncoder (encoder stack)
      - TextDecoder (decoder stack)
      - Whisper (top-level model: exposes decode, detect_language, transcribe)
Figure 1. model.py as the pure model nucleus; decoding and transcription live beside it, not inside it.

That separation is intentional. This file only knows about tensors, shapes, and model dimensions. Everything else—language detection, beam search, CLI behavior—stays in neighboring modules like decoding.py and transcribe.py. The result is high cohesion (everything here is about the model) and low coupling (no I/O, no argument parsing).

The central story in this file is how to turn a dense research‑grade transformer into a practical, production‑ready speech model without drowning in complexity.

The main character in that story is the Whisper class, which takes a single dataclass, ModelDimensions, and wires together an audio encoder, a text decoder, attention blocks, and a few carefully chosen convenience methods: embed_audio, logits, forward, decode, detect_language, and transcribe.
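Here is a rough illustration of what a single configuration object buys you: a dependency-free sketch of a ModelDimensions-style dataclass. The ten fields mirror those of Whisper’s ModelDimensions. The example values are illustrative and resemble a small multilingual checkpoint. The frozen flag, the divisibility check in __post_init__, and the text_head_dim helper are hypothetical additions, not part of Whisper.

```python
from dataclasses import dataclass


@dataclass(frozen=True)  # frozen here so a config cannot drift after construction
class ModelDimensions:
    # Encoder (audio) side
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    # Decoder (text) side
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int

    def __post_init__(self) -> None:
        # Hypothetical sanity check: multi-head attention splits the state evenly across heads.
        for state, heads, side in (
            (self.n_audio_state, self.n_audio_head, "audio"),
            (self.n_text_state, self.n_text_head, "text"),
        ):
            if state % heads != 0:
                raise ValueError(f"{side}: n_state={state} is not divisible by n_head={heads}")

    @property
    def text_head_dim(self) -> int:
        # Hypothetical helper: per-head width used inside decoder attention.
        return self.n_text_state // self.n_text_head


# Illustrative values resembling a small multilingual Whisper model.
dims = ModelDimensions(
    n_mels=80, n_audio_ctx=1500, n_audio_state=384, n_audio_head=6, n_audio_layer=4,
    n_vocab=51865, n_text_ctx=448, n_text_state=384, n_text_head=6, n_text_layer=4,
)
print(dims.text_head_dim)  # 64
```

Every shape in the model can then be derived from this one object, which is the property the rest of the file relies on.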

To understand what this model gets right—and where it hides sharp edges—we’ll first walk the encoder path, then the decoder, then zoom into attention, caching, and alignment.

From Spectrograms to Transformer States

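Whisper doesn’t consume waveforms directly at this layer. Instead, it expects mel spectrograms (a time × frequency representation of audio) shaped as (batch_size, n_mels, n_frames). The number of frames must reduce to exactly n_audio_ctx after the convolutions. In practice that means 3,000 frames for a 30‑second window, which the stride‑2 convolution halves to 1,500. The AudioEncoder turns this into the dense sequence of states the decoder will later attend to.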

At a high level, the encoder does three things:

  1. Two 1D convolutions with GELU activations. The second uses stride 2, which halves the time axis.
  2. Add a fixed sinusoidal positional embedding.
  3. Feed the resulting sequence through a stack of transformer blocks.
class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x
Figure 2. AudioEncoder: two convs, a hard assertion, then a standard transformer stack.

That assertion is subtle but important. It ensures the time dimension after the convolutions exactly matches the length of the registered positional embedding. If you feed in mel spectrograms with the wrong context length, the model doesn’t try to be clever: it fails fast with "incorrect audio shape". One caveat is that Python strips assert statements when run with -O, so this guard disappears in optimized mode.
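You can check which inputs satisfy that assertion by running the standard 1D‑convolution length formula through both layers. Both layers use kernel 3 and padding 1, with strides 1 and 2, as in Figure 2. The sketch below is plain arithmetic and assumes the target is n_audio_ctx = 1,500 time steps.

```python
def conv1d_out_len(length: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    # Standard output-length formula for a 1D convolution with dilation 1.
    return (length + 2 * padding - kernel_size) // stride + 1


def encoder_time_steps(n_frames: int) -> int:
    # conv1: kernel 3, stride 1, padding 1 -> keeps the length
    after_conv1 = conv1d_out_len(n_frames, kernel_size=3, stride=1, padding=1)
    # conv2: kernel 3, stride 2, padding 1 -> roughly halves the length
    return conv1d_out_len(after_conv1, kernel_size=3, stride=2, padding=1)


for frames in (3000, 2999, 3001):
    print(frames, "->", encoder_time_steps(frames))
# 3000 -> 1500
# 2999 -> 1500
# 3001 -> 1501
```

By this arithmetic, a 2,999‑frame input would also reduce to 1,500 steps and pass the check, while 3,001 frames would trip it. The assertion guards the post‑convolution length, not the raw frame count.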

The positional embedding itself is built using classic sinusoidal embeddings:

import numpy as np
import torch


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
Figure 3. Fixed sinusoidal positions: no training required, always the same for a given config.

Using fixed sinusoids here has a practical upside: the encoder’s notion of “time” is entirely determined by ModelDimensions. No learned weights sit behind it. The positional buffer is computed once at construction and reused on every forward pass.
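Here is a plain‑Python restatement of the sinusoids() formula in Figure 3, using the same math with lists instead of tensors. It makes two properties visible. The inverse timescales form a geometric sequence from 1 down to 1/max_timescale. The sine and cosine halves are concatenated rather than interleaved.

```python
import math


def sinusoids_py(length: int, channels: int, max_timescale: float = 10000.0) -> list[list[float]]:
    """Pure-Python restatement of the sinusoids() formula, for inspection."""
    assert channels % 2 == 0
    half = channels // 2
    log_increment = math.log(max_timescale) / (half - 1)
    # Geometric sequence of inverse timescales, from 1 down to 1 / max_timescale
    inv_timescales = [math.exp(-log_increment * i) for i in range(half)]
    rows = []
    for t in range(length):
        scaled = [t * inv for inv in inv_timescales]
        # First half of the channels holds sines, second half cosines (concatenated)
        rows.append([math.sin(s) for s in scaled] + [math.cos(s) for s in scaled])
    return rows


pe = sinusoids_py(length=4, channels=8)
print(pe[0])  # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
```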

The cost of this design is rigidity. The encoder assumes a fixed n_audio_ctx; push it beyond that and you need to change ModelDimensions and retrain. For a deployment‑oriented model, that’s a deliberate trade‑off: predictable performance over arbitrary flexibility.

Teaching the Decoder To Listen

Once the encoder has produced a sequence of audio features, the TextDecoder turns token IDs into logits, conditioning on that audio. Conceptually, we have three ingredients:

  • A learned token embedding + positional embedding.
  • A stack of residual attention blocks, each with self‑attention and cross‑attention.
  • A final projection that reuses the token embedding weights (weight tying).
class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits
Figure 4. TextDecoder: causal self‑attention over tokens plus cross‑attention over audio features.
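The last step of forward, x @ token_embedding.weight.T, is ordinary weight tying. Here it is reduced to one hidden vector and a toy vocabulary, using illustrative numbers in plain Python. The logit for each token is the dot product of the hidden state with that token’s embedding row:

```python
def tied_logits(hidden: list[float], token_embedding: list[list[float]]) -> list[float]:
    # Weight tying: the output projection is the transpose of the embedding table,
    # so the logit for token v is the dot product of the hidden state with row v.
    return [sum(h * e for h, e in zip(hidden, row)) for row in token_embedding]


# Toy vocabulary of 3 tokens with 2-dimensional embeddings (illustrative numbers).
embedding = [
    [1.0, 0.0],  # token 0
    [0.0, 1.0],  # token 1
    [1.0, 1.0],  # token 2
]
print(tied_logits([0.5, 2.0], embedding))  # [0.5, 2.0, 2.5]
```

Tying saves a separate n_vocab × n_state output matrix and keeps input and output token representations in the same space.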

There are two notable details here.

First, the causal mask. It is precomputed as a buffer of shape (n_ctx, n_ctx), with -inf above the diagonal. When passed into attention, those -inf entries ensure tokens can’t attend to the future. This is what makes decoding autoregressive: position i can only see positions ≤ i.
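You can check the effect of the -inf entries by hand. This plain‑Python sketch builds the same upper‑triangular mask, the equivalent of fill_(-inf).triu_(1), and applies a softmax to masked scores. The helper names are illustrative:

```python
import math


def causal_mask(n_ctx: int) -> list[list[float]]:
    # -inf strictly above the diagonal, 0 on and below it (like fill_(-inf).triu_(1))
    return [[-math.inf if j > i else 0.0 for j in range(n_ctx)] for i in range(n_ctx)]


def softmax(row: list[float]) -> list[float]:
    m = max(row)
    exps = [math.exp(v - m) for v in row]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


scores = [[1.0, 2.0, 3.0]] * 3  # identical raw scores for every query position
mask = causal_mask(3)
weights = [softmax([s + m for s, m in zip(srow, mrow)]) for srow, mrow in zip(scores, mask)]
print(weights[0])  # [1.0, 0.0, 0.0]: position 0 can only see itself
```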

Second, the offset. When a key–value (KV) cache is used, the decoder might be called multiple times with additional tokens each time. The offset is the length of the cached sequence so far. Instead of always using positions starting at 0, the decoder slices the learned positional embedding to start at offset. That way, token 101 gets the same positional embedding whether you decode all 101 tokens in one shot or in 101 steps.
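The offset bookkeeping can be sketched with a toy positional table whose rows are just labels. The table is an illustrative stand‑in for the learned n_text_ctx × n_state parameter. Decoding a 3‑token prompt and then two single tokens yields the same positions as one 5‑token pass:

```python
# Toy stand-in for the learned positional table: row p is just the label "pos{p}".
n_text_ctx = 8
positional_embedding = [f"pos{p}" for p in range(n_text_ctx)]


def positions_for(new_tokens: int, offset: int) -> list[str]:
    # Mirrors positional_embedding[offset : offset + x.shape[-1]]
    return positional_embedding[offset : offset + new_tokens]


one_shot = positions_for(new_tokens=5, offset=0)

incremental, cached_len = [], 0
for step_tokens in (3, 1, 1):  # e.g. a 3-token prompt, then one token per step
    incremental += positions_for(step_tokens, offset=cached_len)
    cached_len += step_tokens  # offset = length of the KV cache so far

print(one_shot == incremental)  # True
```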

Notice how the TextDecoder API stays honest: it takes two tensors—x for tokens, xa for encoded audio—and returns logits. It doesn’t know about beam search or temperature; those concerns are delegated to whisper.decoding, keeping the model pure.

Attention That Respects the Hardware

So far we’ve treated attention as a black box. The interesting part of Whisper’s implementation is that it tries to balance mathematical clarity with hardware efficiency. It does this with a custom multi‑head attention module that can optionally switch to PyTorch’s fused scaled dot‑product kernels.

class MultiHeadAttention(nn.Module):
    use_sdpa = True

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
            a = scaled_dot_product_attention(
                q, k, v, is_causal=mask is not None and n_ctx > 1
            )
            out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
            qk = None
        else:
            qk = (q * scale) @ (k * scale).transpose(-1, -2)
            if mask is not None:
                qk = qk + mask[:n_ctx, :n_ctx]
            qk = qk.float()

            w = F.softmax(qk, dim=-1).to(q.dtype)
            out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
            qk = qk.detach()

        return out, qk
Figure 5. Multi‑head attention: one path for fused SDPA, another for explicit softmax attention.

This function is called in every encoder and decoder layer, so it’s the main hot path. A few things stand out:

  • It reshapes q, k, and v into (batch, heads, time, head_dim) and back, matching the conventional multi‑head layout.
  • When scaled_dot_product_attention is available, it uses that, letting PyTorch handle kernel fusion and memory optimizations.
  • When it falls back, it computes qk explicitly, applies the mask, softmaxes, and forms the weighted sum.
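One detail of the fallback path is worth checking. Instead of dividing q·k by sqrt(d_head) once, qkv_attention multiplies both q and k by d_head^-0.25. The two factors multiply to d_head^-0.5, so the result is identical. The sketch below verifies that numerically with toy vectors. The file doesn’t say why it splits the scale; one plausible reason is to keep intermediate magnitudes smaller in reduced precision.

```python
import math

d_head = 64                # illustrative head width
scale = d_head ** -0.25    # same expression as in qkv_attention

q = [0.1 * i for i in range(d_head)]            # toy query vector
k = [0.05 * (d_head - i) for i in range(d_head)]  # toy key vector

dot = sum(a * b for a, b in zip(q, k))
split = sum((a * scale) * (b * scale) for a, b in zip(q, k))  # scale q and k separately
standard = dot / math.sqrt(d_head)                            # textbook 1/sqrt(d) scaling

print(math.isclose(split, standard))  # True
```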

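The performance profile in the report highlights this as the central cost. Attention takes O(batch * heads * n_ctx^2 * d_head) time, and materializing the score matrix takes O(batch * heads * n_ctx^2) memory. The SDPA path doesn’t change the compute asymptotically, but it reduces constants dramatically. Its memory‑efficient kernels can also avoid materializing the full score matrix at all.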
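For a sense of scale, here is back‑of‑the‑envelope arithmetic for the score matrix that the fallback path materializes (qk is cast to float32). The head count and batch size are illustrative assumptions:

```python
def score_matrix_bytes(batch: int, heads: int, q_len: int, k_len: int, bytes_per_elem: int = 4) -> int:
    # One attention score per (batch, head, query position, key position)
    return batch * heads * q_len * k_len * bytes_per_elem


# Illustrative: encoder self-attention over 1,500 frames, 6 heads, batch of 1, float32 scores.
per_layer = score_matrix_bytes(batch=1, heads=6, q_len=1500, k_len=1500)
print(per_layer)  # 54000000 bytes, about 54 MB per layer if materialized
```

The cost grows linearly with batch size and quadratically with context length. Fused kernels that avoid materializing this matrix account for much of the memory saving.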

There is, however, a design smell hiding here: MultiHeadAttention.use_sdpa is a class attribute used as a global flag and toggled by the disable_sdpa context manager:

@contextmanager
def disable_sdpa():
    prev_state = MultiHeadAttention.use_sdpa
    try:
        MultiHeadAttention.use_sdpa = False
        yield
    finally:
        MultiHeadAttention.use_sdpa = prev_state

Aspect          | Current design                        | Suggested improvement
--------------- | ------------------------------------- | ---------------------------------
Configuration   | Global flag on the class              | Per‑instance flag self.use_sdpa
Concurrency     | All instances share the same switch   | Each module decides independently
Experimentation | Hard to mix SDPA and manual attention | Easy to mix per layer or per model

In a single‑threaded script, this global toggle is perfectly fine. A service handling concurrent requests is different. One request entering disable_sdpa() changes the attention path for every MultiHeadAttention in the process for as long as it stays inside the block. That includes other requests and other model instances. The report recommends turning use_sdpa into an instance field and adjusting disable_sdpa to operate on a specific module.
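Here is a minimal sketch of that recommendation. It uses a hypothetical Attention class in place of the real module so it runs without PyTorch. The flag lives on each instance, and the context manager takes the module it should affect.

```python
from contextlib import contextmanager


class Attention:
    """Hypothetical stand-in for MultiHeadAttention with a per-instance SDPA switch."""

    def __init__(self, use_sdpa: bool = True) -> None:
        self.use_sdpa = use_sdpa

    def path(self) -> str:
        return "sdpa" if self.use_sdpa else "manual"


@contextmanager
def disable_sdpa(module: Attention):
    # Only touches the module it is given, then restores its previous state.
    prev_state = module.use_sdpa
    try:
        module.use_sdpa = False
        yield module
    finally:
        module.use_sdpa = prev_state


a, b = Attention(), Attention()
with disable_sdpa(a):
    print(a.path(), b.path())  # manual sdpa
print(a.path())  # sdpa
```

In a real model you would apply the same toggle to every MultiHeadAttention found via model.modules(). This still doesn’t isolate two requests that share one model instance; for that, the switch has to be per call.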

This is a recurring lesson: global state is tempting, but per‑instance configuration scales much better, especially once your model leaves the notebook and lands in a server.

KV Cache: The Secret Latency Weapon

Now that we’ve seen how attention works per step, the next question is: how do we make autoregressive decoding fast enough for real‑time or near‑real‑time transcription? Whisper’s answer is a key–value cache wired through PyTorch forward hooks.

The Whisper class exposes this via install_kv_cache_hooks:

class Whisper(nn.Module):
    ...
    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks
Figure 6. KV cache hooks: retrofitting efficient incremental decoding onto a standard transformer stack.

Here’s what’s happening:

  1. We walk the decoder and, for every MultiHeadAttention layer, attach hooks to its key and value projection modules.
  2. Each time those projections run, save_to_cache either initializes the cache entry or appends the new time steps along dimension 1.
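  3. A forward hook’s non‑None return value replaces the module’s output, so the projection hands attention the full cached history. On each decoding step only the new tokens are projected, and the prefix’s keys and values are reused instead of recomputed. For cross‑attention, the audio keys and values are computed once and then read straight from the cache.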
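The hook’s branching can be simulated with Python lists standing in for tensors, where list length plays the role of dimension 1. Module names as dictionary keys and N_TEXT_CTX = 448 are assumptions for the illustration. The real cache is keyed by the key and value Linear modules themselves.

```python
N_TEXT_CTX = 448  # decoder context length assumed for this illustration


def save_to_cache(cache: dict, module: str, output: list) -> list:
    # Same branching as the hook: store as-is on first sight, or when the new output
    # is longer than the text context (cross-attention over audio); otherwise append.
    if module not in cache or len(output) > N_TEXT_CTX:
        cache[module] = output
    else:
        cache[module] = cache[module] + output
    return cache[module]  # a hook's non-None return value replaces the module output


cache: dict = {}
save_to_cache(cache, "cross.key", ["a"] * 1500)        # audio features: 1,500 frames
save_to_cache(cache, "self.key", ["t0", "t1", "t2"])   # prompt tokens
for step in range(3, 6):
    full_keys = save_to_cache(cache, "self.key", [f"t{step}"])  # one new token per step

print(len(cache["cross.key"]), full_keys)  # 1500 ['t0', 't1', 't2', 't3', 't4', 't5']
```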

The performance report calls out this path as a hot spot for long sequences, but also a major latency win when used properly. That’s why one of the suggested observability metrics is whisper_decoder_token_latency_ms with a target P95 under 10 ms per token on typical hardware.

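There is a subtle behavioral contract in save_to_cache. A new output is stored as‑is, rather than appended, in two cases: when it is seen for the first time, or when its time dimension exceeds n_text_ctx. In practice the second condition identifies cross‑attention. Its keys and values are projected from the audio features (1,500 encoder frames versus a 448‑token text context in the released models), so they are cached once and never concatenated. The check does not cap the self‑attention cache, which keeps growing one step at a time. The practical limit is the decoder’s positional embedding, which only has n_text_ctx rows, so callers are expected to stay within that budget. None of this is obvious from the API alone. The report therefore suggests either enforcing n_text_ctx strictly (by raising) or documenting the behavior clearly so callers don’t assume infinite history.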

Combined with the decoder’s offset logic, this caching changes the per‑token cost of self‑attention. Without a cache, step t recomputes attention over the whole prefix, which is about t^2 scores. With cached keys and values, it computes one new query against t keys, which is about t scores. This is what makes Whisper responsive even on long utterances.
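The difference is easy to quantify. This sketch counts self‑attention scores for generating n tokens one at a time. It is a simplified model that ignores the prompt and cross‑attention:

```python
def attention_scores_without_cache(n_tokens: int) -> int:
    # Step t re-runs self-attention over the whole prefix: t queries x t keys
    return sum(t * t for t in range(1, n_tokens + 1))


def attention_scores_with_cache(n_tokens: int) -> int:
    # Step t computes one new query against t cached keys
    return sum(t for t in range(1, n_tokens + 1))


for n in (4, 100):
    print(n, attention_scores_without_cache(n), attention_scores_with_cache(n))
# 4 30 10
# 100 338350 5050
```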

Alignment Heads and Hidden Contracts

So far we’ve focused on the main forward path. Whisper also needs to align tokens to timestamps, and it does that by designating some decoder attention heads as “alignment heads”. This is implemented as a sparse buffer on the Whisper class.

By default, every head in the second half of the decoder layers is marked as an alignment head:

all_heads = torch.zeros(
    self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
)
all_heads[self.dims.n_text_layer // 2 :] = True
self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

For advanced use cases, there’s a way to override this set via a compact binary encoding:

def set_alignment_heads(self, dump: bytes):
    array = np.frombuffer(
        gzip.decompress(base64.b85decode(dump)), dtype=bool
    ).copy()
    mask = torch.from_numpy(array).reshape(
        self.dims.n_text_layer, self.dims.n_text_head
    )
    self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

This code is elegant in its concision, but it hides a fairly complex contract:

  • The dump must be a boolean array (one byte per head) that was gzip‑compressed and then base85‑encoded.
  • The total number of elements must be exactly n_text_layer * n_text_head.
  • If any of that is off, you get a cryptic reshape or decoding error.

The report flags this as a “complex implicit contract”. The suggested refactor is simple but powerful: validate the decoded array size before reshaping and raise a descriptive ValueError when it doesn’t match expectations. That turns a mysterious runtime failure into an actionable configuration error.
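Here is a plain‑Python sketch of both sides of that contract. The encoder produces a dump in the format set_alignment_heads reads: one byte per head, gzip, then base85. The decoder is hypothetical and validates the input, raising a descriptive ValueError before any reshape. It returns nested lists rather than a sparse tensor, so it illustrates the checks rather than serving as a drop‑in replacement.

```python
import base64
import gzip
import zlib


def encode_alignment_heads(mask: list[list[bool]]) -> bytes:
    # One byte per head (0 or 1), row-major over (layer, head); gzip, then base85.
    raw = bytes(int(flag) for row in mask for flag in row)
    return base64.b85encode(gzip.compress(raw))


def decode_alignment_heads(dump: bytes, n_text_layer: int, n_text_head: int) -> list[list[bool]]:
    # Hypothetical validating decoder: check encoding and size before reshaping.
    try:
        raw = gzip.decompress(base64.b85decode(dump))
    except (ValueError, OSError, EOFError, zlib.error) as exc:
        raise ValueError(f"alignment heads dump is not gzipped base85 data: {exc}") from exc
    expected = n_text_layer * n_text_head
    if len(raw) != expected:
        raise ValueError(
            f"alignment heads dump has {len(raw)} entries, expected "
            f"n_text_layer * n_text_head = {n_text_layer} * {n_text_head} = {expected}"
        )
    flags = [byte != 0 for byte in raw]
    return [flags[i * n_text_head:(i + 1) * n_text_head] for i in range(n_text_layer)]


mask = [[False, False], [True, False], [False, True]]  # 3 layers x 2 heads (toy)
dump = encode_alignment_heads(mask)
print(decode_alignment_heads(dump, 3, 2) == mask)  # True
```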

This section of the file also demonstrates a pattern Whisper uses elsewhere. Structural data (masks, position embeddings, alignment heads) lives in buffers that stay on the module, follow it across device moves, and never receive gradient updates. The causal mask and alignment heads are registered with persistent=False, so they are not stored in checkpoints. It’s a clean way to keep model‑shape metadata attached to the module itself.

Hard Lessons From a Soft Interface

We’ve walked the main flow—audio in, tokens out—and peeked into attention, caching, and alignment. Let’s zoom back out and look at the big lessons developers can take from this file when building their own models or integrating Whisper.

Lesson 1: Shape contracts are part of your API

AudioEncoder uses a hard assertion to guard against mismatched audio context. Most other entry points, like embed_audio and logits, assume the caller will pass correctly shaped tensors. When that assumption breaks, PyTorch emits generic shape errors.

The report recommends adding explicit validation in these methods—checking mel.ndim, mel.shape[1] against dims.n_mels, ensuring tokens.ndim == 2, and validating the audio features shape. This has almost no runtime cost but dramatically improves developer experience when integrating the model.
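Those checks can be written as small functions over shape tuples, so they work with torch.Size (a tuple subclass) as well as plain tuples. The function names, the messages, and the token‑length bound are hypothetical. n_frames is the frame count your model expects, which is 3,000 for a 30‑second window.

```python
from typing import Sequence


def check_mel_shape(shape: Sequence[int], n_mels: int, n_frames: int) -> None:
    # Hypothetical edge check for embed_audio-style inputs: (batch, n_mels, n_frames)
    if len(shape) != 3:
        raise ValueError(f"mel must be 3-D (batch, n_mels, n_frames), got shape {tuple(shape)}")
    if shape[1] != n_mels:
        raise ValueError(f"mel has {shape[1]} mel bins, model expects n_mels={n_mels}")
    if shape[2] != n_frames:
        raise ValueError(f"mel has {shape[2]} frames, model expects {n_frames}; pad or trim the audio")


def check_tokens_shape(shape: Sequence[int], n_text_ctx: int) -> None:
    # Hypothetical edge check for logits-style inputs: (batch, n_tokens)
    if len(shape) != 2:
        raise ValueError(f"tokens must be 2-D (batch, n_tokens), got shape {tuple(shape)}")
    if shape[1] > n_text_ctx:
        raise ValueError(f"{shape[1]} tokens exceed n_text_ctx={n_text_ctx}")


check_mel_shape((1, 80, 3000), n_mels=80, n_frames=3000)  # passes silently
try:
    check_mel_shape((80, 3000), n_mels=80, n_frames=3000)  # forgot the batch axis
except ValueError as err:
    print(err)
```

With a real tensor you would call check_mel_shape(mel.shape, dims.n_mels, n_frames) at the top of the public method.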

In other words, treat shapes and dtypes as part of your public API surface and fail fast with clear messages when they’re wrong.

Lesson 2: Don’t hide global switches in helpers

The disable_sdpa context manager is convenient, but because it flips a class‑level flag, it effectively changes the behavior of every attention layer in every instance of MultiHeadAttention in the process.

For small scripts this is a non‑issue. For long‑running services, it introduces a race: one request can accidentally slow down another simply by wrapping a decode call in disable_sdpa(). The suggested refactor—to move use_sdpa to instances—changes this from a global to a local concern.

As a general pattern, any time you introduce a global knob for performance or behavior, ask how it behaves under concurrency and whether you’d be better served by a per‑instance or per‑call parameter.
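One way to get a per‑call switch, not taken from Whisper, is a contextvars.ContextVar. Each asyncio task runs in its own copy of the context, and values set in one thread are not visible in another. The names below are hypothetical. Attention code would read sdpa_enabled() instead of a class attribute.

```python
import contextvars
from contextlib import contextmanager

# Hypothetical request-scoped switch, isolated per asyncio task / thread context.
_use_sdpa = contextvars.ContextVar("use_sdpa", default=True)


def sdpa_enabled() -> bool:
    return _use_sdpa.get()


@contextmanager
def sdpa_disabled():
    token = _use_sdpa.set(False)
    try:
        yield
    finally:
        _use_sdpa.reset(token)  # restore whatever this context saw before


other_request = contextvars.copy_context()  # snapshot taken before this "request" opts out
with sdpa_disabled():
    print(sdpa_enabled(), other_request.run(sdpa_enabled))  # False True
print(sdpa_enabled())  # True
```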

Lesson 3: Performance optimizations need observability

Whisper’s model code already includes the hooks needed to make decoding fast: SDPA integration and a KV cache. But the report goes further, recommending concrete metrics:

  • whisper_encoder_forward_latency_ms to catch regressions in the audio encoder.
  • whisper_decoder_token_latency_ms to understand user‑visible latency.
  • whisper_attention_memory_bytes and whisper_kv_cache_size_bytes to detect OOM risks as context lengths or batch sizes grow.
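Here is a minimal, hypothetical way to start collecting such numbers in‑process. It times each hot‑path call with time.perf_counter and computes a nearest‑rank P95. Only the metric name comes from the list above. The recorder and helper are illustrative, and a production service would export to its metrics backend instead.

```python
import time
from collections import defaultdict
from contextlib import contextmanager

# Hypothetical in-process recorder: metric name -> list of latencies in milliseconds.
samples_ms: dict = defaultdict(list)


@contextmanager
def timed(metric: str):
    start = time.perf_counter()
    try:
        yield
    finally:
        samples_ms[metric].append((time.perf_counter() - start) * 1000.0)


def p95(values: list) -> float:
    # Nearest-rank percentile: smallest value with at least 95% of samples at or below it
    if not values:
        raise ValueError("no samples recorded")
    ordered = sorted(values)
    rank = -(-95 * len(ordered) // 100)  # ceil(0.95 * n) using integer arithmetic
    return ordered[rank - 1]


for _ in range(20):
    with timed("whisper_decoder_token_latency_ms"):
        sum(range(10_000))  # stand-in for one decoder step

print(len(samples_ms["whisper_decoder_token_latency_ms"]))  # 20
print(p95([float(v) for v in range(1, 101)]))  # 95.0
```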

The underlying idea is simple: never ship a performance optimization that you can’t observe. Without metrics, it’s hard to know whether SDPA is actually used, whether caches are growing as expected, or why latency spikes under certain workloads.

Lesson 4: Keep the model pure, the rest can follow

One of the most elegant choices in this file is what it doesn’t do. The Whisper class exposes:

  • embed_audio for encoder‑only passes,
  • logits and forward for core model evaluation, and
  • aliases to decode, detect_language, and transcribe from neighboring modules.

But it never reaches out to files, sockets, or CLIs. Inputs and outputs are always plain tensors. That purity makes the model safe to use in everything from research notebooks to high‑throughput services and simplifies testing: you can exercise almost everything with small synthetic tensors.

Lesson 5: Small details preserve numerical health

Finally, a quieter but important theme: type handling. Whisper subclasses PyTorch’s LayerNorm, Linear, and Conv1d. LayerNorm upcasts its input to float32 before normalizing and casts the result back to the input dtype. Linear and Conv1d cast their weights and biases to the dtype of the incoming activations. This is crucial for mixed‑precision inference where some layers may run in float16 or bfloat16.

It’s easy to overlook these “plumbing” details, but they reduce subtle numerical issues and make it more likely that the model behaves consistently across hardware configurations.

Bringing it home

Whisper’s model.py is more than a transformer implementation. It’s a compact blueprint for turning a research architecture into something you can embed into real systems: careful about shapes, pragmatic about performance, and disciplined in what it owns.

If you’re designing your own model stack, a few concrete actions to borrow today are:

  • Introduce a single configuration object (like ModelDimensions) that fully describes your model’s shape.
  • Add explicit, descriptive input validation at the edges of your public API.
  • Make performance toggles (like SDPA vs. manual attention) per‑instance, not global.
  • Expose observability hooks—latency and memory metrics—for your hot paths.
  • Keep the model pure: tensors in, tensors out; push everything else to a higher layer.

When transformers learn to listen, as Whisper does here, it’s not only the architecture that matters. It’s the engineering discipline around that architecture that turns a paper idea into a reliable tool.

Full Source Code

Here's the full source code of the file that inspired this article.
Read on GitHub


Thanks for reading! I hope this was useful. If you have questions or thoughts, feel free to reach out.

Content Creation Process: This article was generated via a semi-automated workflow using AI tools. I prepared the strategic framework, including specific prompts and data sources. From there, the automation system conducted the research, analysis, and writing. The content passed through automated verification steps before being finalized and published without manual intervention.

Mahmoud Zalt

About the Author

I’m Zalt, a technologist with 15+ years of experience, passionate about designing and building AI systems that move us closer to a world where machines handle everything and humans reclaim wonder.

Let's connect if you're working on interesting AI projects, looking for technical advice or want to discuss your career.
