Skip to content

Parakeet tdt#44171

Merged
ArthurZucker merged 85 commits into
huggingface:mainfrom
lmaksym:parakeet-tdt
May 19, 2026
Merged

Parakeet tdt#44171
ArthurZucker merged 85 commits into
huggingface:mainfrom
lmaksym:parakeet-tdt

Conversation

@lmaksym
Copy link
Copy Markdown
Contributor

@lmaksym lmaksym commented Feb 20, 2026

What does this PR do?

This PR adds TDT decoder support for Parakeet ASR models, extending the existing CTC-only implementation.
It incorporates the initial TDT integration work from #41545 by @hainan-xv (was not merged) and and addresses all review feedback from both #41545 and #43357.

Changes

  • ParakeetForTDT model with greedy TDT decoding in generate()
  • ParakeetTDTDecoder (LSTM prediction network) and ParakeetTDTJointNetwork as nn.Module subclasses
  • Per-token timestamp generation via return_timestamps=True
  • AutoModelForTDT auto class with pipeline, processor, and tokenizer integration
  • Flat ParakeetTDTConfig matching the CTC pattern (no nested decoder/joint configs)
  • Shared ParakeetPreTrainedModel base between CTC and TDT (no separate TDT base class)
  • NeMo-to-HF weight conversion script for TDT models
  • Documentation and tests following existing CTC patterns

Validation

  • 278 unit tests pass, make check-repo passes
  • CTC model unaffected by changes
  • LibriSpeech test-clean: 2.09% WER (matches NVIDIA published ~2-3%)
  • Timestamps validated against commercial ASR (94.3% within 2 frames)
  • Model: MaksL/parakeet-tdt-0.6b-v3

Before submitting

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ebezzam and @hainan-xv please review

-->

Hainan Xu and others added 2 commits February 20, 2026 09:45
Implement Token-and-Duration Transducer (TDT) decoding for Parakeet models,
extending the existing CTC-only support. This adds ParakeetForTDT with greedy
TDT decoding in generate(), per-token timestamp generation, and full
integration with AutoModelForTDT, processors, and ASR pipeline.
Copy link
Copy Markdown
Contributor

@ebezzam ebezzam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lmaksym thank you putting together the PRs cleanly! I pushed a few changes for adapting to Transformers convention and added integration tests to compare with the original model from NeMo.

@hainan-xv and @nithinraok, your input could be useful for the TDT decoding, and also the loss computation.

Comment thread src/transformers/models/parakeet/modular_parakeet.py Outdated
Comment thread src/transformers/models/parakeet/modular_parakeet.py Outdated
- Use -100 label padding for training (HF convention)
- Fix timestamp recording in inner blank-seeking loop
- Add max_symbols_per_step guard matching NeMo
- Clean up decoding loop
- Add TDT training example to docs
- Use setUpClass for TDT integration tests
Copy link
Copy Markdown

@hainan-xv hainan-xv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a comment on the loss computation part.

Comment thread src/transformers/models/parakeet/modeling_parakeet.py Outdated
Comment thread tests/models/parakeet/test_modeling_parakeet.py Outdated
Copy link
Copy Markdown
Contributor

@ebezzam ebezzam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lmaksym thanks for porting the TDT loss! it's nice (1) to not have to depend on torchaudio and (2) to make the TDT loss available in Transformers!

It is functional with this example (single GPU): https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-tdt_training_snippet-py
But quite slow...

I wonder if there is a custom gradient computation in NeMo? As I noticed in the paper (Section 3.1), they say "We derive an analytical solution for the gradient of the TDT loss, since automatic differentiation for transducer loss is highly inefficient."

FYI I can test/fix on my side for multi-GPU compatibility.

Comment thread src/transformers/models/parakeet/modular_parakeet.py Outdated
@lmaksym
Copy link
Copy Markdown
Contributor Author

lmaksym commented Mar 3, 2026

@lmaksym thanks for porting the TDT loss! it's nice (1) to not have to depend on torchaudio and (2) to make the TDT loss available in Transformers!

It is functional with this example (single GPU): https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-tdt_training_snippet-py But quite slow...

I wonder if there is a custom gradient computation in NeMo? As I noticed in the paper (Section 3.1), they say "We derive an analytical solution for the gradient of the TDT loss, since automatic differentiation for transducer loss is highly inefficient."

FYI I can test/fix on my side for multi-GPU compatibility.

I'll look into that

Copy link
Copy Markdown
Contributor

@ebezzam ebezzam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ArthurZucker thanks for your comments! I've made changes to the cache object and some refactoring in the processor.

For the hub kernel, I can see with Cyril

from ...utils import ModelOutput


class ParakeetTDTDecoderCache:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved cache object to generation



class ParakeetTDTDecoderCache:
def __init__(self, config):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

config passed so we don't need to pass lstm module to update for lazy initialization

Comment on lines +704 to +711
# Get cached hidden/cell states if available, otherwise initialize with ParakeetTDTDecoderCache
if cache is not None:
was_initialized = cache.is_initialized
if not was_initialized:
cache.lazy_initialization(embeddings)
hidden_cell_states = (cache.hidden_state, cache.cell_state)
else:
hidden_cell_states = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to your comment, I tried to make it clearer how hidden/cell state get initialized. But still a bit clunky...

In case links not redirecting well:

(Arthur)
If not cache.is_initialized but still passed why are we not passing?

(Eric)
I think the confusing part is that we were actually relying on self.lstm to initialize hidden_state and cell_state (rather than ParakeetTDTDecoderCache.lazy_initialization) and then passing them to ParakeetTDTDecoderCache to set them. I'll try coming up with something that's clearer


if cache is not None:
mask = ~blank_mask if was_initialized else None
cache.update(decoder_output, hidden_state, cell_state, mask=mask)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LSTM module no longer passed

)

if use_decoder_cache and decoder_cache is None:
decoder_cache = ParakeetTDTDecoderCache(self.config)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

passing config to help with initialization (instead of using LSTM module)

Comment on lines +137 to +150
stream = DecodeStream(skip_special_tokens=True)
timestamp_dict = []
for i, token_id in enumerate(batch_ids):
if int(token_id) in skip_ids:
continue
chunk = stream.step(self.tokenizer._tokenizer, int(token_id))
if chunk is not None:
timestamp_dict.append(
{
"token": chunk,
"start": int(batch_timestamps[i]),
"end": int(batch_timestamps[i] + batch_durations[i]),
}
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactored to use DecodeStream but I think we still need the inner loop to associate the individual token to its timestamps?

re this comment

(Arthur) all of this can and potetntially should be done in a single for loop no?

(Eric)
hmm from what I understand of DecodeStream, its step() method must be called per-token to get per-token decoded text for timestamp pairing. Passing a list merges the output into one chunk, losing the 1:1 token-to-timestamp mapping.

Copy link
Copy Markdown
Collaborator

@ArthurZucker ArthurZucker May 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah , the stream just simplifies the non valid token handling 😉

Copy link
Copy Markdown
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On nit, let's avoid modifying the general code for conversion

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but I don't think the PB is here, maybe more in ParakeetConverter(model_files["tokenizer_model_file"]).converted() ?

@ebezzam ebezzam mentioned this pull request May 19, 2026
@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ebezzam
Copy link
Copy Markdown
Contributor

ebezzam commented May 19, 2026

run-slow: lasr, parakeet

@github-actions
Copy link
Copy Markdown
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/lasr", "models/parakeet"]
quantizations: []

@github-actions
Copy link
Copy Markdown
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN 5b0b54bf workflow commit (merge commit)
PR c2e22872 branch commit (from PR)
main c0008d14 base commit (on main)

✅ No failing test specific to this PR 🎉 👏 !

@ebezzam ebezzam enabled auto-merge May 19, 2026 04:54
@ebezzam ebezzam added this pull request to the merge queue May 19, 2026
@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks May 19, 2026
@tarekziade tarekziade added this pull request to the merge queue May 19, 2026
@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks May 19, 2026
@github-actions
Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, encodec, lasr, parakeet

@tarekziade tarekziade enabled auto-merge May 19, 2026 06:05
@ebezzam ebezzam disabled auto-merge May 19, 2026 06:15
@github-actions
Copy link
Copy Markdown
Contributor

@ArthurZucker ArthurZucker merged commit 38a8b55 into huggingface:main May 19, 2026
27 of 29 checks passed
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request May 28, 2026
* parakeet tdt intergration

* Add TDT decoder support for Parakeet ASR models

Implement Token-and-Duration Transducer (TDT) decoding for Parakeet models,
extending the existing CTC-only support. This adds ParakeetForTDT with greedy
TDT decoding in generate(), per-token timestamp generation, and full
integration with AutoModelForTDT, processors, and ASR pipeline.

* Add expected outputs for TDT, small fixes.

* Separate CTC and TDT generate outputs.

* Work with auto device, better init,

* Test timestamps and expose token duration.

* Add reproducer link.

* fix: align TDT training and decoding with NeMo implementation

- Use -100 label padding for training (HF convention)
- Fix timestamp recording in inner blank-seeking loop
- Add max_symbols_per_step guard matching NeMo
- Clean up decoding loop
- Add TDT training example to docs
- Use setUpClass for TDT integration tests

* revert: restore lasr generated files to original state

* warn: torchaudio rnnt_loss does not train duration head

* Relax timestamp test, and test nits.

* feat: TDT training

* chore: for cuda detection and run without patching

* Equivalent timestamp processing as Nemo, and various nits/cleanup.

* Simplify durations config.

* Update training examples.

* chore: enable parralelism

* chore: performance optimization

* fix: formatting

* Doc and testing nits

* Use active mask from current step, and nits.

* Better pre-allocate.

* TDT has separate pad token and blank token.

* Regenerate lasr.

* Style checks and nits

* Nits, put back ctc loss test

* More standard model output.

* Style

* Remove compute_loss flag and allow monkey patching to tdt loss

* Update src/transformers/models/parakeet/modular_parakeet.py

Co-authored-by: eustlb <[email protected]>

* Address various comments.

* More compatible with Transformers forward/generate approach

* compile option for generation and decoder cache

* Cleaner, better conventions.

* Update with main.

* doc nits

* Imitate whisper for encoder outputs as input

* Address tests and nits.

* Inherit from GenerateMixIn for get_compiled_call

* Comment nit

* forward cleanup

* generate cleanup + separate generation file

* generate: add _supported_generation_modes

* automatic init of the loss

* modular cleanups

* use is_encoder_decoder

* timestamp processing fully from tokens + durations

* convertion script update

* test update

* make

* test update

* test update

* ensure correct loss computation

* kernel loss

* test loss integration

* push to hub pr

* integration tests to rely fully on transcripts

* udpate fixtures

* we don't need to monkey patch with numba anymore!

* fix pipeline usage

* nit

* fix usage

* Pass through tests and examples: improve kernel fallback, update with nvidia checkpoint, style checks.

* Update checkpoint

* Add TDT to mapping after merge.

* Fix lasr generate test.

* Output attention mask if labels provided for computing loss.

* Apply suggestion from @ArthurZucker

Co-authored-by: Arthur <[email protected]>

* Improve ParakeetTDTDecoderCache definition and usage.

* Remove tuple parsing.

* processor refactor

* Update conversion.

* Remove kernel to address in separate PR, modular nit.

* Lasr modular

* Add attention mask docstring.

---------

Co-authored-by: Hainan Xu <[email protected]>
Co-authored-by: Eric B <[email protected]>
Co-authored-by: Eric Bezzam <[email protected]>
Co-authored-by: eustlb <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Tarek Ziade <[email protected]>
kashif pushed a commit to kashif/transformers that referenced this pull request Jun 1, 2026
* parakeet tdt intergration

* Add TDT decoder support for Parakeet ASR models

Implement Token-and-Duration Transducer (TDT) decoding for Parakeet models,
extending the existing CTC-only support. This adds ParakeetForTDT with greedy
TDT decoding in generate(), per-token timestamp generation, and full
integration with AutoModelForTDT, processors, and ASR pipeline.

* Add expected outputs for TDT, small fixes.

* Separate CTC and TDT generate outputs.

* Work with auto device, better init,

* Test timestamps and expose token duration.

* Add reproducer link.

* fix: align TDT training and decoding with NeMo implementation

- Use -100 label padding for training (HF convention)
- Fix timestamp recording in inner blank-seeking loop
- Add max_symbols_per_step guard matching NeMo
- Clean up decoding loop
- Add TDT training example to docs
- Use setUpClass for TDT integration tests

* revert: restore lasr generated files to original state

* warn: torchaudio rnnt_loss does not train duration head

* Relax timestamp test, and test nits.

* feat: TDT training

* chore: for cuda detection and run without patching

* Equivalent timestamp processing as Nemo, and various nits/cleanup.

* Simplify durations config.

* Update training examples.

* chore: enable parralelism

* chore: performance optimization

* fix: formatting

* Doc and testing nits

* Use active mask from current step, and nits.

* Better pre-allocate.

* TDT has separate pad token and blank token.

* Regenerate lasr.

* Style checks and nits

* Nits, put back ctc loss test

* More standard model output.

* Style

* Remove compute_loss flag and allow monkey patching to tdt loss

* Update src/transformers/models/parakeet/modular_parakeet.py

Co-authored-by: eustlb <[email protected]>

* Address various comments.

* More compatible with Transformers forward/generate approach

* compile option for generation and decoder cache

* Cleaner, better conventions.

* Update with main.

* doc nits

* Imitate whisper for encoder outputs as input

* Address tests and nits.

* Inherit from GenerateMixIn for get_compiled_call

* Comment nit

* forward cleanup

* generate cleanup + separate generation file

* generate: add _supported_generation_modes

* automatic init of the loss

* modular cleanups

* use is_encoder_decoder

* timestamp processing fully from tokens + durations

* convertion script update

* test update

* make

* test update

* test update

* ensure correct loss computation

* kernel loss

* test loss integration

* push to hub pr

* integration tests to rely fully on transcripts

* udpate fixtures

* we don't need to monkey patch with numba anymore!

* fix pipeline usage

* nit

* fix usage

* Pass through tests and examples: improve kernel fallback, update with nvidia checkpoint, style checks.

* Update checkpoint

* Add TDT to mapping after merge.

* Fix lasr generate test.

* Output attention mask if labels provided for computing loss.

* Apply suggestion from @ArthurZucker

Co-authored-by: Arthur <[email protected]>

* Improve ParakeetTDTDecoderCache definition and usage.

* Remove tuple parsing.

* processor refactor

* Update conversion.

* Remove kernel to address in separate PR, modular nit.

* Lasr modular

* Add attention mask docstring.

---------

Co-authored-by: Hainan Xu <[email protected]>
Co-authored-by: Eric B <[email protected]>
Co-authored-by: Eric Bezzam <[email protected]>
Co-authored-by: eustlb <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: Tarek Ziade <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.