
The Token-and-Duration Transducer (TDT) extends RNN-T by jointly predicting what token to emit and how many frames that token covers. This lets the model skip multiple encoder frames per step during inference instead of advancing one at a time, yielding up to 2.82x faster decoding with comparable or better accuracy.
Word Error Rate (WER) is a useful metric to optimise, but if your model takes 10 seconds to transcribe 1 second of audio, nobody's shipping it. The Hugging Face Open ASR Leaderboard tracks both accuracy and speed. At the time of writing, within the leaderboard's top 10, Nvidia's Parakeet TDT models are more than 3x ahead of the nearest competition in RTFx (inverse real-time factor, i.e. throughput: how many seconds of audio the model can process per second of wall-clock time).
These models are significantly faster than the competition while maintaining competitive WERs. The mechanism? A modification to the RNN-Transducer called the Token-and-Duration Transducer (TDT). In this post, we'll first look at how RNN-T and TDT work at inference time to build intuition for why TDT is faster, then circle back to explain how each model is trained.
Without going into too much detail, there are a few ways to train a speech-to-text model: CTC, AED, decoder-only, or RNN-T/TDT. Each of these has pros and cons; for a full comparison see Desh's Analysis.
RNN-T hits a useful middle ground: it has enough modeling capacity to capture label dependencies (unlike CTC), its autoregressive component is lightweight (unlike AED/Decoder-only), and it can be trained end-to-end with a well-understood loss function. RNN-T is already fast - much faster than AED or Decoder-only models, and only somewhat slower than CTC. But there's still space to speed up.
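An RNN-T consists of three components: an encoder, which turns the audio into a sequence of T frame representations; a predictor, which summarises the labels emitted so far; and a joint network (the joiner), which combines one encoder frame with one predictor state into a distribution P(v∣t,u) over the next symbol.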
Here t indexes the encoder time-step, u indexes how many output labels have been emitted so far (position in the target sequence), and v denotes any candidate vocabulary symbol (including blank) when we write P(v∣t,u).
At inference time, the model decodes greedily by stepping through encoder frames:
```python
# RNN-T Greedy Decoding (simplified)
t = 0
u = 0
output = []
while t < T:
    logits = joint(encoder[t], predictor(output))
    token = argmax(logits)
    if token == BLANK:
        t += 1  # advance ONE frame
    else:
        output.append(token)  # stay at same t, advance u
        u += 1
```
At each step, the model either emits blank (advance one frame) or emits a token (stay at the same frame, advance the label index). The key observation in speeding this up will be to allow frame-skipping.
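The loop above can be made runnable by swapping in a stand-in for the joint network. The sketch below is illustrative only: `toy_joint`, its emission schedule, and the `max_symbols_per_frame` guard are assumptions made for this example rather than part of any real model. It also counts joint calls, so the cost of one-frame-at-a-time decoding is visible.

```python
BLANK = 0  # index of the blank symbol in this toy vocabulary


def argmax(values):
    """Index of the largest value (first one wins on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def rnnt_greedy_decode(joint, num_frames, max_symbols_per_frame=10):
    """Greedy RNN-T decoding. Returns (emitted tokens, number of joint calls)."""
    t, output, calls, emitted_here = 0, [], 0, 0
    while t < num_frames:
        logits = joint(t, output)
        calls += 1
        token = argmax(logits)
        if token == BLANK or emitted_here >= max_symbols_per_frame:
            t += 1            # advance ONE frame
            emitted_here = 0
        else:
            output.append(token)  # stay at the same frame, advance u
            emitted_here += 1
    return output, calls


def toy_joint(t, output):
    """Hypothetical joint: emits token 1 then token 2 at frame 0, blank elsewhere."""
    schedule = {(0, 0): 1, (0, 1): 2}  # (frame, tokens emitted so far) -> token
    target = schedule.get((t, len(output)), BLANK)
    logits = [0.0, 0.0, 0.0]
    logits[target] = 5.0
    return logits


tokens, calls = rnnt_greedy_decode(toy_joint, num_frames=8)
print(tokens, calls)
```

With 8 frames and 2 emitted tokens, the decoder makes one joint call per emitted token plus one blank per frame.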
For a 10-second utterance at 80ms frame rate (after subsampling), that's ~125 sequential joint network calls at minimum. Most of those will be blanks - in typical speech, tokens are sparse relative to frames. The model spends most of its time predicting "nothing is happening" one frame at a time. The joint network is cheap per call, but the sequential one-frame-at-a-time structure leaves performance on the table.
TDT addresses this.
The core idea of TDT (Xu et al., 2023): instead of predicting just a token at each step, jointly predict the token and how many frames it covers.
In standard RNN-T, the joint network outputs a single distribution over ∣V∣+1 symbols (vocabulary + blank). In TDT, the joint network outputs two independent distributions: a token distribution P_T(v∣t,u) over the same ∣V∣+1 symbols, and a duration distribution P_D(d∣t,u) over d∈D,
where D is a predefined set of durations. A typical choice is D={0,1,2,3,4}, though the set can be configured - for example, {1,2,3,4} (omitting 0) is also valid.
The two heads share the same encoder and predictor representations but are independently normalized (separate softmax operations):
```python
# TDT Joint Network Output
logits = joint(encoder[t], predictor(output))  # shape: [V + 1 + |D|]

# Split into token and duration logits
token_logits = logits[:V+1]     # shape: [V + 1]
duration_logits = logits[V+1:]  # shape: [|D|]

# Independent softmax
token_probs = softmax(token_logits)
duration_probs = softmax(duration_logits)
```
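As a runnable illustration of the split, here is a minimal NumPy sketch. The vocabulary size, the duration set, and the random logits are placeholders chosen for the example; a real joint network would produce the logits.

```python
import numpy as np


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    z = np.exp(x - np.max(x))
    return z / z.sum()


V = 5                        # assumed vocabulary size, excluding blank
DURATIONS = [0, 1, 2, 3, 4]  # assumed duration set D

rng = np.random.default_rng(0)
logits = rng.normal(size=V + 1 + len(DURATIONS))  # stand-in for the joint output

# Split into the two heads and normalise each one separately.
token_probs = softmax(logits[:V + 1])
duration_probs = softmax(logits[V + 1:])

print(token_probs.sum(), duration_probs.sum())
```

Each head sums to one on its own, which is what "independently normalized" means here: raising a duration logit never takes probability away from a token.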
The inference speedup is immediate. Compare with the RNN-T loop above:
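```python
# TDT Greedy Decoding (simplified)
# D is an ordered list of durations, e.g. [0, 1, 2, 3, 4]
t = 0
output = []
while t < T:
    logits = joint(encoder[t], predictor(output))
    token_logits, duration_logits = logits[:V+1], logits[V+1:]
    token = argmax(token_logits)
    duration = D[argmax(duration_logits)]  # map head index to a duration
    if token == BLANK:
        t += max(1, duration)  # skip MULTIPLE frames!
    else:
        output.append(token)
        t += duration  # can also skip frames on token emission
```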
Instead of advancing one frame at a time, the model can skip over stretches of silence or steady-state audio. If the model predicts blank with duration 4, it skips 4 frames in one step - reducing joint network calls for that stretch proportionally.
Let's trace through a concrete example. Suppose we have 8 encoder frames (T=8), target "hi" → tokens [h, i], and durations D={0,1,2,3}:
```
t=0: joint(enc[0], pred([]))     → token=h (p=0.8),      duration=0 (p=0.7) → emit 'h', stay at t=0
t=0: joint(enc[0], pred([h]))    → token=i (p=0.6),      duration=2 (p=0.5) → emit 'i', jump to t=2
t=2: joint(enc[2], pred([h, i])) → token=blank (p=0.9),  duration=3 (p=0.6) → skip to t=5
t=5: joint(enc[5], pred([h, i])) → token=blank (p=0.95), duration=3 (p=0.8) → skip to t=8 → DONE!
```
4 joint network calls, where standard RNN-T would need 10 for the same output (8 blanks, one per frame, plus 2 token emissions). That's the speedup.
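The trace can be replayed in code. In this sketch the joint network is replaced by a lookup table (`SCRIPT`) holding the argmax decisions from the trace; the table, the one-hot scores, and the helper names are assumptions made for illustration, not model output.

```python
VOCAB = ["<blank>", "h", "i"]
DURATIONS = [0, 1, 2, 3]
BLANK = 0

# (frame t, tokens emitted so far) -> (token index, duration), copied from the trace.
SCRIPT = {(0, 0): (1, 0), (0, 1): (2, 2), (2, 2): (0, 3), (5, 2): (0, 3)}


def one_hot(index, size):
    return [1.0 if i == index else 0.0 for i in range(size)]


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


def scripted_joint(t, output):
    """Hypothetical joint: returns (token scores, duration scores) for this state."""
    token, duration = SCRIPT[(t, len(output))]
    return one_hot(token, len(VOCAB)), one_hot(DURATIONS.index(duration), len(DURATIONS))


def tdt_greedy_decode(joint, num_frames):
    """Greedy TDT decoding. Returns (emitted tokens, number of joint calls)."""
    t, output, calls = 0, [], 0
    while t < num_frames:
        token_scores, duration_scores = joint(t, output)
        calls += 1
        token = argmax(token_scores)
        duration = DURATIONS[argmax(duration_scores)]
        if token == BLANK:
            t += max(1, duration)  # blanks always advance at least one frame
        else:
            output.append(token)
            t += duration          # tokens may stay put (d=0) or skip ahead
    return [VOCAB[i] for i in output], calls


tokens, calls = tdt_greedy_decode(scripted_joint, num_frames=8)
print(tokens, calls)
```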
The TDT paper reports up to 2.82x faster inference than standard RNN-T on speech recognition tasks, with comparable or better accuracy. The speedup is more pronounced on longer utterances with more silence.
Now that we've seen what these models do at inference time, let's understand how they're trained. This requires a bit more machinery.
During training, we have the audio X and we know the correct transcription y, but we typically don't know the correct frame-to-word alignment.
Suppose we have T=8 encoder frames and the target transcription y = (the, quick, brown, fox), so U=4.
There are many potential ways to get the exact same transcript, for example:
```
Path A (early speech): ∅, the, quick, brown, fox, ∅, ∅, ∅, ∅, ∅, ∅, ∅   (orange) → all tokens emitted by t=4, rest is silence
Path B (spread out):   the, ∅, ∅, quick, ∅, brown, ∅, ∅, fox, ∅, ∅, ∅   (pink)   → tokens spread across the utterance
Path C (late speech):  ∅, ∅, ∅, ∅, ∅, the, quick, brown, ∅, ∅, fox, ∅   (blue)   → speech starts late, around t=5
```
Remember that each time we output a blank symbol ∅, we increment t (the time-frame of the encoder) and each time we output a token, we feed that back into the predictor to get the next predictor output (increment u by one).
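To make the bookkeeping concrete, the sketch below walks each of the three example paths, moving t on a blank and u on a token, and records the lattice nodes visited. The helper name `walk` and the use of the string "∅" for blank are choices made for this example.

```python
BLANK = "∅"
TARGET = ["the", "quick", "brown", "fox"]

PATHS = {
    "A": [BLANK, "the", "quick", "brown", "fox"] + [BLANK] * 7,
    "B": ["the", BLANK, BLANK, "quick", BLANK, "brown", BLANK, BLANK, "fox", BLANK, BLANK, BLANK],
    "C": [BLANK] * 5 + ["the", "quick", "brown", BLANK, BLANK, "fox", BLANK],
}


def walk(path):
    """Return (transcript, visited nodes) for an alignment path."""
    t, u, transcript = 0, 0, []
    nodes = [(t, u)]
    for symbol in path:
        if symbol == BLANK:
            t += 1                      # blank: move to the next encoder frame
        else:
            transcript.append(symbol)   # token: feed it back to the predictor
            u += 1
        nodes.append((t, u))
    return transcript, nodes


results = {name: walk(path) for name, path in PATHS.items()}
for name, (transcript, nodes) in results.items():
    print(name, transcript, nodes[-1])
```

All three paths produce the same transcript and finish at the same node, (t=8, u=4); they differ only in the route taken through the grid.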
Our goal now is to maximise the probability of the correct transcript, irrespective of the alignment (which we don't yet know). RNN-T's solution is to maximise the total probability summed over all possible alignments. We visualise this by constructing a lattice, which encodes every possible frame-to-word alignment.
The joint network produces a probability distribution P(v∣t,u) at every node (t,u) in a T×(U+1) grid (the lattice), where T is the number of encoder frames and U is the number of target tokens. In the above example, for t=3 and u=2, we evaluate:

$$P(v \mid t=3, u=2) = \mathrm{joint}\big(\mathrm{encoder}[3],\ \mathrm{predictor}([\text{the}, \text{quick}])\big)$$
the joiner called on the 3rd frame of the encoder output, and the predictor called on the first 2 model outputs. This gives us a probability distribution over the entire vocab, V, plus blank, ∅. For training, we only care about the probability of the next correct token (in this case "brown") or blank, ∅ - so we just show these two transitions in the lattice:
Every valid path from bottom-left [start] to top-right [end] emits exactly the target sequence and is a different valid alignment. Different paths through this lattice correspond to different timings of the same transcription.
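To get the probability of a given path (alignment), we take the product of all token/blank probabilities along that path. For example, for Path A above:

$$P(\text{Path A}) = P(\varnothing \mid 0,0)\, P(\text{the} \mid 1,0)\, P(\text{quick} \mid 1,1)\, P(\text{brown} \mid 1,2)\, P(\text{fox} \mid 1,3) \prod_{t=1}^{7} P(\varnothing \mid t,4)$$

where (as described above) each factor is written P(v∣t,u): t is the time-frame index of the encoder output, and u is the amount of the transcript that the predictor has seen so far.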
The total probability of y (the correct transcription) is defined as the sum over all such paths:

$$P(y \mid X) = \sum_{a \in A(y)} P(a \mid X)$$
where A(y) is the set of all valid alignments for y.
This probability P(y∣X) is the objective we will try to maximise in training. Or more accurately, we will try to minimise the negative log-likelihood:

$$\mathcal{L}_{\text{RNNT}} = -\log P(y \mid X)$$
So, our loss is completely agnostic to the alignment the model wants to use, we just want to maximise the total probability mass running through this lattice.
Now we need to efficiently calculate this loss, L_RNNT, and the relevant gradients.
As long as we stay on this training lattice, we will produce the correct transcript. The probability of staying on this lattice is the thing we will try to maximise - so we want to boost the chance of any transitions on this lattice (scaled by the impact they have on the final probability).
It's worth noting here that the output of the joiner is normalised, so increasing the chance of e.g. the token y_1="the" will implicitly decrease the chance of all other tokens here, e.g. y_1="then".
So how do we get this probability?
If we start from no transcript - with probability 1 (we must start with nothing yet transcribed) - we can get the chance of moving in either valid direction:
So, the chance of going up in the lattice - emitting z_{0,0}="the" - is, say, 0.4. The chance of emitting z_{0,0}=∅ is, say, 0.5. This is quite good: it means the chance of emitting any other token is only 0.1, indicating a well-trained model.
Here we'll keep α(t,u) as the probability of getting to a node (from the start). So, what's the chance of progressing any further through this lattice:
To get to node (t=2,u=1), i.e. two blanks and one correct "the" token, we have three possible paths:
```
Path 1: "the", ∅, ∅   (↑, →, →)   P_1 = 0.4 * 0.1 * 0.1 = 0.004
Path 2: ∅, "the", ∅   (→, ↑, →)   P_2 = 0.5 * 0.5 * 0.1 = 0.025
Path 3: ∅, ∅, "the"   (→, →, ↑)   P_3 = 0.5 * 0.4 * 0.7 = 0.14
```
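So the sum over all paths to node (t=2,u=1) is

$$\alpha(2,1) = P_1 + P_2 + P_3 = 0.004 + 0.025 + 0.14 = 0.169$$

This means that only 16.9% of the probability mass arrives at this node. The remaining 100%−16.9%=83.1% went elsewhere: either to another lattice node reachable in three steps, such as (t=3,u=0) or (t=1,u=2), or off the training lattice entirely, e.g. by outputting ["then", ∅, ∅] or [∅, "apple", ∅].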
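The arithmetic for this node can be checked by brute force. The sketch below stores the transition probabilities used in the three paths, keyed by lattice node, then enumerates every ordering of two blanks and one token. The dictionary layout and helper names are choices made for this example.

```python
from itertools import permutations

# Transition probabilities from the worked example, keyed by node (t, u).
P_BLANK = {(0, 0): 0.5, (1, 0): 0.4, (0, 1): 0.1, (1, 1): 0.1}
P_TOKEN = {(0, 0): 0.4, (1, 0): 0.5, (2, 0): 0.7}


def path_probability(path):
    """Multiply the transition probabilities along one path starting at (0, 0)."""
    t, u, prob = 0, 0, 1.0
    for step in path:
        if step == "blank":
            prob *= P_BLANK[(t, u)]
            t += 1
        else:
            prob *= P_TOKEN[(t, u)]
            u += 1
    return prob


# Every distinct ordering of two blanks and one token ends at node (t=2, u=1).
paths = sorted(set(permutations(["blank", "blank", "token"])))
alpha_2_1 = sum(path_probability(p) for p in paths)
print(len(paths), round(alpha_2_1, 3))
```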
More generally, we define α, a.k.a. the forward variable, as the sum over all correct paths to a given node:

$$\alpha(t,u) = \sum_{\text{paths } a:\ (0,0) \to (t,u)}\ \prod_{i} P(a_i \mid t_i, u_i)$$
It's also useful to think of this as the total amount of probability mass that flows through the lattice to a given node.
Now if we enumerate all paths to the [end] node and sum the probabilities we will get the full transcript probability:

$$P(y \mid X) = \alpha(T,U)$$
The problem with this is that we will have way too many paths to enumerate. Even for the above small example, with T=8 and U=4 we have 330 potential paths through the lattice.
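The count of 330 can be reproduced. Every path ends with a blank out of the last frame, so the remaining T−1 blanks and U tokens can come in any order, giving a binomial coefficient. The sketch below (function name is illustrative) checks that closed form against a direct count over the lattice.

```python
from math import comb


def count_alignments(T, U):
    """Count monotonic paths from (0, 0) to (T-1, U); the final blank into [end] is forced."""
    counts = [[0] * (U + 1) for _ in range(T)]
    counts[0][0] = 1
    for t in range(T):
        for u in range(U + 1):
            if t > 0:
                counts[t][u] += counts[t - 1][u]  # arrived via a blank
            if u > 0:
                counts[t][u] += counts[t][u - 1]  # arrived via a token
    return counts[T - 1][U]


T, U = 8, 4
by_counting = count_alignments(T, U)
by_formula = comb(T - 1 + U, U)
print(by_counting, by_formula)
```

The count grows quickly with utterance length, which is why summing path by path is not an option for real data.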
To solve this issue, we notice that to get the probability mass that gets to a given node, we only care about the mass that gets to the previous adjacent nodes (i.e. one blank token backwards, or one correct token backwards):
We don't care about individual paths leading up to these predecessor nodes, just the total sum over all possible paths to them - the total probability mass that arrives there. This means we get the following:

$$\alpha(t,u) = \alpha(t-1,u)\,P(\varnothing \mid t-1,u) + \alpha(t,u-1)\,P(y_u \mid t,u-1)$$
with α(0,0)=1 (as the chance of starting at (0,0) is 100%). Each term above says: the mass arriving at (t,u) is the mass at the predecessor node, times the probability of the transition from the predecessor to (t,u). This means that we get to skip enumerating every possible path and just run through each node in the lattice with this sum - all the way to the [end] node.
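A minimal NumPy sketch of this recursion follows, checked against explicit path enumeration on a small random lattice. The array layout is an assumption for this example: `blank[t, u]` holds P(∅∣t,u) and `tok[t, u]` holds the probability of the next correct token at node (t,u). The returned `total` plays the role of α(T,U).

```python
import itertools

import numpy as np


def forward(blank, tok):
    """Forward variables on a T x (U+1) lattice, plus the total probability."""
    T, U = blank.shape[0], blank.shape[1] - 1
    alpha = np.zeros((T, U + 1))
    alpha[0, 0] = 1.0
    for t in range(T):
        for u in range(U + 1):
            if t > 0:  # mass arriving via a blank from (t-1, u)
                alpha[t, u] += alpha[t - 1, u] * blank[t - 1, u]
            if u > 0:  # mass arriving via the correct token from (t, u-1)
                alpha[t, u] += alpha[t, u - 1] * tok[t, u - 1]
    total = alpha[T - 1, U] * blank[T - 1, U]  # final blank into [end]
    return alpha, total


def brute_force(blank, tok):
    """Sum the probability of every path explicitly (only feasible for tiny lattices)."""
    T, U = blank.shape[0], blank.shape[1] - 1
    steps, total = T - 1 + U, 0.0
    for token_steps in itertools.combinations(range(steps), U):
        t, u, prob = 0, 0, 1.0
        for i in range(steps):
            if i in token_steps:
                prob *= tok[t, u]
                u += 1
            else:
                prob *= blank[t, u]
                t += 1
        total += prob * blank[T - 1, U]
    return total


rng = np.random.default_rng(0)
T, U = 5, 3
blank = rng.uniform(0.05, 0.45, size=(T, U + 1))  # made-up probabilities
tok = rng.uniform(0.05, 0.45, size=(T, U))

alpha, total = forward(blank, tok)
print(total, brute_force(blank, tok))
```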
Now that we've efficiently calculated the total probability P(y∣X)=α(T,U), we need to calculate how much each transition affects this final sum - the gradient of P(y∣X) with respect to each transition probability z_{t,u}. This will tell us how much to update the model weights at each step. Specifically, if we make a small change ∂z_{t,u} in a transition probability, what will be the effect ∂P(y∣X) on the total probability?
So let's work this out for a given node; e.g. probability of a blank transition at (t=3,u=1): P(∅∣t=3,u=1). This means that the model has already output the correct first token - e.g. "the" as well as 3 blank tokens ∅ - in some order.
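For some path "Path k" through the lattice that goes through our transition z_{3,1}=∅, we have:

$$P(\text{Path } k) = P(\varnothing \mid t=3,u=1) \prod_{i} P(v_i)$$

where each P(v_i) is one of the other transitions along this path. This means that a small change in our transition probability ∂P(∅∣t=3,u=1) will affect the total path probability by:

$$\frac{\partial P(\text{Path } k)}{\partial P(\varnothing \mid t=3,u=1)} = \prod_{i} P(v_i)$$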
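So, this is the amount that changing our transition probability will affect a path that uses it: just the total probability of the path, not including the given transition. Changing this probability won't affect paths that don't use this transition. We also know that the total probability of the correct transcript is the sum over all possible correct paths:

$$P(y \mid X) = \sum_{k} P(\text{Path } k)$$

so naturally, the effect of changing this transition on P(y∣X) is the sum of the effects it has on each relevant path:

$$\frac{\partial P(y \mid X)}{\partial P(\varnothing \mid t=3,u=1)} = \sum_{k\,:\,z_{3,1} \in \text{Path } k}\ \prod_{i} P(v_i)$$

which we can split into the sum over paths that get to the transition z_{3,1}=∅; the transition itself (which contributes a factor of 1 to the derivative); and the sum over all the paths that leave the transition and get to the [end] state:

$$\frac{\partial P(y \mid X)}{\partial P(\varnothing \mid t=3,u=1)} = \Big(\sum_{\text{paths } (0,0) \to (3,1)}\ \prod_{i} P(v_i)\Big) \cdot \Big(\sum_{\text{paths } (4,1) \to [\text{end}]}\ \prod_{j} P(v_j)\Big)$$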
All this is saying is: the amount that the total probability changes is the amount of probability mass that gets to a given transition × the amount of probability mass that gets from the transition to the [end] state.
But we've already done the maths for the first part! The sum over paths that get to a given node is just α(t,u), and the second part looks very similar - we'll call this the backward variable β(t′,u′).
The backward variable β(t,u) is the mirror image of α(t,u): it represents the total probability mass that flows from node (t,u) to the final state - "how much probability mass will still reach the target from here."
Here, in a symmetric way, we set β(T,U)=1, but now walk backwards through the lattice:

$$\beta(t,u) = P(\varnothing \mid t,u)\,\beta(t+1,u) + P(y_{u+1} \mid t,u)\,\beta(t,u+1)$$
Effectively, "the amount of probability mass that will reach the target from each of the next nodes" × "the probability of getting there from the current node".
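A sketch of the backward pass, under the same assumed array layout as before (`blank[t, u]` for P(∅∣t,u), `tok[t, u]` for the next correct token at (t,u)). A useful sanity check is that the mass leaving the start node, β(0,0), matches the total probability computed by the forward pass.

```python
import numpy as np


def forward_total(blank, tok):
    """Total probability of the transcript via the forward recursion."""
    T, U = blank.shape[0], blank.shape[1] - 1
    alpha = np.zeros((T, U + 1))
    alpha[0, 0] = 1.0
    for t in range(T):
        for u in range(U + 1):
            if t > 0:
                alpha[t, u] += alpha[t - 1, u] * blank[t - 1, u]
            if u > 0:
                alpha[t, u] += alpha[t, u - 1] * tok[t, u - 1]
    return alpha[T - 1, U] * blank[T - 1, U]


def backward(blank, tok):
    """Backward variables; row T holds the [end] node (T, U) and zeros elsewhere."""
    T, U = blank.shape[0], blank.shape[1] - 1
    beta = np.zeros((T + 1, U + 1))
    beta[T, U] = 1.0
    for t in range(T - 1, -1, -1):
        for u in range(U, -1, -1):
            beta[t, u] = blank[t, u] * beta[t + 1, u]     # leave via a blank
            if u < U:
                beta[t, u] += tok[t, u] * beta[t, u + 1]  # leave via the correct token
    return beta


rng = np.random.default_rng(1)
T, U = 6, 3
blank = rng.uniform(0.05, 0.45, size=(T, U + 1))  # made-up probabilities
tok = rng.uniform(0.05, 0.45, size=(T, U))

beta = backward(blank, tok)
total = forward_total(blank, tok)
print(beta[0, 0], total)
```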
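This simplifies life a lot for our example:

$$\frac{\partial P(y \mid X)}{\partial P(\varnothing \mid t=3,u=1)} = \alpha(3,1)\,\beta(4,1)$$

Or more generally:

$$\frac{\partial P(y \mid X)}{\partial P(v \mid t,u)} = \alpha(t,u)\,\beta(t',u')$$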
Where v is a transition at a given node (t,u) pointing to another node (t′,u′).
The forward variable gets the mass to the transition, and the backward variable represents the mass that will eventually arrive at the target from that point. The full loss gradient normalizes by the total likelihood (see the original Graves 2012 paper for the complete derivation):

$$\frac{\partial \mathcal{L}_{\text{RNNT}}}{\partial P(v \mid t,u)} = -\frac{\alpha(t,u)\,\beta(t',u')}{P(y \mid X)}$$
This gives a nice result! Multiplied by the transition's own probability, the magnitude of this gradient is just the proportion of the total probability mass that flows through that transition. Early in training, when P(y∣X) is small, the gradient is still significant for any correct transitions. This also explains why lattice paths tend to collapse to a small number of dominant alignments later in training - the highest-probability paths receive the largest gradients, incentivizing further path concentration.
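The α·β form of the gradient can be checked numerically. The sketch below, which uses an assumed array layout (`blank[t, u]`, `tok[t, u]`) and made-up probabilities, compares the analytic derivative for the blank transition at (t=3,u=1) with a finite difference, and confirms that the mass crossing from one frame to the next always adds up to P(y∣X).

```python
import numpy as np


def forward(blank, tok):
    T, U = blank.shape[0], blank.shape[1] - 1
    alpha = np.zeros((T, U + 1))
    alpha[0, 0] = 1.0
    for t in range(T):
        for u in range(U + 1):
            if t > 0:
                alpha[t, u] += alpha[t - 1, u] * blank[t - 1, u]
            if u > 0:
                alpha[t, u] += alpha[t, u - 1] * tok[t, u - 1]
    return alpha, alpha[T - 1, U] * blank[T - 1, U]


def backward(blank, tok):
    T, U = blank.shape[0], blank.shape[1] - 1
    beta = np.zeros((T + 1, U + 1))
    beta[T, U] = 1.0
    for t in range(T - 1, -1, -1):
        for u in range(U, -1, -1):
            beta[t, u] = blank[t, u] * beta[t + 1, u]
            if u < U:
                beta[t, u] += tok[t, u] * beta[t, u + 1]
    return beta


rng = np.random.default_rng(2)
T, U = 5, 3
blank = rng.uniform(0.05, 0.45, size=(T, U + 1))
tok = rng.uniform(0.05, 0.45, size=(T, U))

alpha, total = forward(blank, tok)
beta = backward(blank, tok)

# dP(y|X)/dP(v|t,u) = alpha(t,u) * beta(t',u') for each kind of transition.
grad_blank = alpha * beta[1:, :]         # blank: (t, u) -> (t+1, u)
grad_tok = alpha[:, :U] * beta[:T, 1:]   # token: (t, u) -> (t, u+1)
loss_grad_blank = -grad_blank / total    # gradient of -log P(y|X)

# Finite-difference check on the blank transition at (t=3, u=1).
eps = 1e-6
bumped = blank.copy()
bumped[3, 1] += eps
numeric = (forward(bumped, tok)[1] - total) / eps

# Share of the total mass flowing through each blank transition.
share_blank = grad_blank * blank / total
print(numeric, grad_blank[3, 1], share_blank.sum(axis=1))
```

Every path takes exactly one blank out of each frame, so the shares in each row of `share_blank` sum to one.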
n.b. This "path collapse" is a key insight behind k2's pruned RNN-T loss, which simplifies the gradient computation significantly by only considering paths near (in time) to the high-probability alignments.
The forward-backward algorithm computes all of this in O(T⋅U) time.
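Training TDT requires modifying the forward-backward algorithm to account for the duration variable. The loss is still the negative log-likelihood −logP(y∣X), but the lattice transitions are now richer. Recall that in TDT, transitions can skip multiple frames: a token transition with duration d∈D moves from (t,u) to (t+d,u+1), and a blank transition with duration d∈D, d≥1, moves from (t,u) to (t+d,u).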
Note the asymmetry: blanks must have d≥1 (you must advance at least one frame when emitting nothing), but tokens can have d=0 if 0 is in D (emitting a token without advancing - useful for fast speech or multi-token emissions at a single frame). If D doesn't include 0, every emission also advances at least one frame.
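We now have two independent distributions predicted from each node: the token distribution P_T(v∣t,u) over the vocabulary plus blank, and the duration distribution P_D(d∣t,u) over D.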
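The forward variable α(t,u) now has a more complex recurrence. At each position (t,u), we must sum over all durations that could have led here:

$$\alpha(t,u) = \sum_{d \in D,\ d \ge 1} \alpha(t-d,u)\,P_T(\varnothing \mid t-d,u)\,P_D(d \mid t-d,u) \;+\; \sum_{d \in D} \alpha(t-d,u-1)\,P_T(y_u \mid t-d,u-1)\,P_D(d \mid t-d,u-1)$$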
The key difference from standard RNN-T: instead of looking back exactly 1 step, we look back d steps for each duration in D. This makes the forward pass O(T⋅U⋅∣D∣) instead of O(T⋅U) - a constant factor increase since ∣D∣ is typically small (4–5 elements).
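Here is a minimal sketch of a duration-aware forward pass. It is not NeMo's implementation: the array layout, the function names, and the termination rule (a path counts only if it lands exactly on t=T with all U tokens emitted) are assumptions made for this example. The result is checked against a plain recursive sum that works forward from each node to the end.

```python
import numpy as np


def tdt_forward(tok, blank, dur, durations):
    """alpha[t, u] for t in 0..T, u in 0..U.

    tok[t, u]    : P_T(next correct token | t, u), shape (T, U)
    blank[t, u]  : P_T(blank | t, u),              shape (T, U+1)
    dur[t, u, k] : P_D(durations[k] | t, u),       shape (T, U+1, len(durations))
    """
    T, U = blank.shape[0], blank.shape[1] - 1
    alpha = np.zeros((T + 1, U + 1))
    alpha[0, 0] = 1.0
    for t in range(T + 1):
        for u in range(U + 1):
            for k, d in enumerate(durations):
                src = t - d  # frame the transition started from
                if src < 0 or src >= T:
                    continue
                if d >= 1:   # blank transitions must advance at least one frame
                    alpha[t, u] += alpha[src, u] * blank[src, u] * dur[src, u, k]
                if u >= 1:   # token transitions may have d = 0
                    alpha[t, u] += alpha[src, u - 1] * tok[src, u - 1] * dur[src, u - 1, k]
    return alpha


def mass_to_end(t, u, tok, blank, dur, durations):
    """Recursive reference: probability of reaching (T, U) from node (t, u)."""
    T, U = blank.shape[0], blank.shape[1] - 1
    if t == T:
        return 1.0 if u == U else 0.0
    total = 0.0
    for k, d in enumerate(durations):
        if t + d > T:
            continue
        if d >= 1:
            total += blank[t, u] * dur[t, u, k] * mass_to_end(t + d, u, tok, blank, dur, durations)
        if u < U:
            total += tok[t, u] * dur[t, u, k] * mass_to_end(t + d, u + 1, tok, blank, dur, durations)
    return total


rng = np.random.default_rng(3)
T, U, durations = 5, 2, [0, 1, 2, 3]
blank = rng.uniform(0.05, 0.45, size=(T, U + 1))  # made-up probabilities
tok = rng.uniform(0.05, 0.45, size=(T, U))
dur = rng.dirichlet(np.ones(len(durations)), size=(T, U + 1))

alpha = tdt_forward(tok, blank, dur, durations)
reference = mass_to_end(0, 0, tok, blank, dur, durations)
print(alpha[T, U], reference)
```

The innermost loop over `durations` is the extra ∣D∣ factor in the complexity.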
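The backward variable β(t,u) follows the same pattern but in reverse:

$$\beta(t,u) = \sum_{d \in D,\ d \ge 1} P_T(\varnothing \mid t,u)\,P_D(d \mid t,u)\,\beta(t+d,u) \;+\; \sum_{d \in D} P_T(y_{u+1} \mid t,u)\,P_D(d \mid t,u)\,\beta(t+d,u+1)$$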
The gradient computation uses both α and β in the standard way, summing over each possible duration for a given token prediction and scaling by the duration probabilities (P_D). For the token logit, the gradient at position (t,u,v) is:
where C_{t,u} is the set of reachable states from (t,u):
In this case, to count the paths affected by the chance of predicting e.g. "fox", we have 4 possible lattice transitions to count, and 3 possible transitions for the blank token ∅.
For the duration logits, the gradient at position (t,u,d) accounts for all transitions that use duration d, either the correct token or a blank transition:
or for d=0 (blank not allowed at zero duration):
This too is somewhat intuitive. It represents the sum over all valid paths that use this duration. Now we're done! This is all the maths required to understand the efficient TDT training mechanics. For the full derivation see the TDT paper.
Working in Log-space: As is usual in machine learning when working with probabilities, we use log-space. Big summations of log probabilities are much more stable than big products of raw probabilities.
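A small NumPy demonstration of why log-space matters. The numbers are made up: 2000 transitions with probability 0.01 each, and two path log-probabilities that would both underflow if exponentiated directly.

```python
import numpy as np

probs = np.full(2000, 0.01)           # 2000 transitions, each with probability 0.01
raw_product = np.prod(probs)          # underflows: float64 cannot represent 1e-4000
log_product = np.sum(np.log(probs))   # the same quantity as a sum of logs stays finite

# Summing two path probabilities that are stored as logs.
log_p1, log_p2 = -1000.0, -1001.0
naive = np.exp(log_p1) + np.exp(log_p2)  # both terms underflow to zero
log_sum = np.logaddexp(log_p1, log_p2)   # stable log(exp(a) + exp(b))

print(raw_product, log_product, naive, log_sum)
```

In log-space the products in the α and β recursions become sums, and the sums over predecessors become `logaddexp` (or log-sum-exp) operations.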
The Sigma Trick - Logit Under-Normalization: Every transition in the lattice, whether blank or token, gets penalized by σ (typically 0.05) in log-space. Since this penalty is applied per transition, paths with more steps accumulate a larger total penalty. This biases the model toward using fewer, larger-duration steps rather than many duration-1 steps.
The Omega Trick - Sampled RNN-T Loss: with probability ω, the loss falls back to the standard RNN-T loss (ignoring durations entirely). This acts as a regularizer, ensuring the token predictions remain well-calibrated even without duration information. This is important for the batched inference case, where we will have to increment the entire batch encoder-frame by the same amount (e.g. the minimum predicted token duration).
TDT has the same memory footprint challenge as standard RNN-T: the joint network output is a 4D tensor of shape (B,T,U,V+1+∣D∣). For large vocabularies and long sequences, this can be enormous. The standard mitigation is fused loss computation - instead of materializing the full joint tensor, compute the loss and gradients in a fused kernel that only materializes one (t,u) slice at a time. It's also typically important to keep the vocab size small - the above example uses full words, but a smaller vocab of sub-words is usually preferable.
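A back-of-the-envelope calculation shows the scale. The batch size, sequence lengths, vocabulary size, and float32 storage below are hypothetical values picked for illustration; the lattice has U+1 label positions, and the last dimension counts the blank explicitly.

```python
def joint_tensor_bytes(batch, frames, target_len, vocab, num_durations, bytes_per_value=4):
    """Size of a fully materialised joint output of shape (B, T, U+1, V+1+|D|)."""
    return batch * frames * (target_len + 1) * (vocab + 1 + num_durations) * bytes_per_value


# Hypothetical setup: 32 utterances of 10 s at 80 ms per frame (125 frames),
# 50 target tokens each, 1024 sub-word units, 5 durations, float32 values.
size = joint_tensor_bytes(batch=32, frames=125, target_len=50, vocab=1024, num_durations=5)
print(f"{size / 1e9:.2f} GB")
```

That is the forward activations alone; the gradient with respect to the same tensor needs as much again unless the loss is fused.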
The choice of duration set D matters. The paper uses {0,1,2,3,4} as the default. Some considerations:
TDT is related to but distinct from the Multi-Blank Transducer, which adds multiple blank symbols (big-blank-2, big-blank-3, etc.) that skip different numbers of frames. The key difference:
| | Multi-Blank | TDT |
|---|---|---|
| Duration prediction | Implicit (via blank type) | Explicit (separate head) |
| Token durations | Always 0 (no frame skip on token) | Variable (tokens can skip frames too) |
| Vocab size increase | ∣D∣ blank symbols | No vocab increase; separate duration head |
| Independence | Token and duration coupled | Token and duration independently normalized |
TDT's independent normalization means the model doesn't need to use vocabulary capacity on multiple blank symbols, and the duration prediction can be more fine-grained.
TDT extends RNN-T by jointly predicting tokens and their durations. The key ideas are:
The result: models that are up to 2.82x faster at inference with comparable or better accuracy than standard transducers - and RNN-T was already fast to begin with. This is how Nvidia's Parakeet-TDT models dominate the RTFx column at the top of the Hugging Face leaderboard.
The NeMo toolkit has a full implementation, and pretrained Parakeet-TDT checkpoints are available on Hugging Face.
References:
