Parakeet tdt #44171
Implement Token-and-Duration Transducer (TDT) decoding for Parakeet models, extending the existing CTC-only support. This adds ParakeetForTDT with greedy TDT decoding in generate(), per-token timestamp generation, and full integration with AutoModelForTDT, processors, and ASR pipeline.
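For readers unfamiliar with TDT, the greedy decoding loop described above can be sketched roughly as follows. This is a minimal pure-Python illustration, not the code from this PR: the `joint` callable, the `durations` list, and the function name `greedy_tdt_decode` are hypothetical stand-ins for the encoder, prediction network, and joint network. The sketch assumes the joint returns one score list over tokens and one over duration bins, and that a blank with predicted duration 0 is forced to advance by one frame.

```python
from typing import Callable, List, Sequence, Tuple


def argmax(scores: Sequence[float]) -> int:
    """Index of the largest score (first one wins on ties)."""
    return max(range(len(scores)), key=scores.__getitem__)


def greedy_tdt_decode(
    num_frames: int,
    joint: Callable[[int, int], Tuple[Sequence[float], Sequence[float]]],
    blank_id: int,
    durations: Sequence[int],
    max_symbols_per_step: int = 10,
) -> Tuple[List[int], List[int], List[int]]:
    """Greedy decoding sketch: returns (tokens, start_frames, predicted_durations).

    `joint(t, prev_token)` is a hypothetical callable that returns
    (token_scores, duration_scores) for encoder frame `t`.
    """
    tokens: List[int] = []
    starts: List[int] = []
    spans: List[int] = []
    t = 0
    prev_token = blank_id  # stand-in for the start-of-sequence state
    symbols_at_frame = 0
    while t < num_frames:
        token_scores, duration_scores = joint(t, prev_token)
        token = argmax(token_scores)
        skip = durations[argmax(duration_scores)]
        if token != blank_id:
            # Record the token with the frame it was emitted at and its duration.
            tokens.append(token)
            starts.append(t)
            spans.append(skip)
            prev_token = token
            symbols_at_frame += 1
        # Guard against stalling: a blank with duration 0, or too many
        # symbols emitted at the same frame, forces a one-frame advance.
        if skip == 0 and (token == blank_id or symbols_at_frame >= max_symbols_per_step):
            skip = 1
        if skip > 0:
            symbols_at_frame = 0
        t += skip
    return tokens, starts, spans
```

The key difference from a plain transducer is that each step predicts both a token and how many frames to jump, so the loop can skip several frames at once instead of advancing one frame per blank.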
@lmaksym thank you for putting together the PRs cleanly! I pushed a few changes to adapt to Transformers conventions and added integration tests that compare against the original model from NeMo.
@hainan-xv and @nithinraok, your input could be useful for the TDT decoding, and also the loss computation.
- Use -100 label padding for training (HF convention)
- Fix timestamp recording in inner blank-seeking loop
- Add max_symbols_per_step guard matching NeMo
- Clean up decoding loop
- Add TDT training example to docs
- Use setUpClass for TDT integration tests
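As a small illustration of the -100 label padding convention mentioned above, the sketch below pads variable-length label sequences with -100 and recovers the true target lengths from the padded batch. The helper names `pad_labels` and `target_lengths` are hypothetical and plain lists are used instead of tensors; the actual collation in the PR may differ.

```python
from typing import List, Sequence, Tuple

IGNORE_INDEX = -100  # padding value that the loss is expected to ignore


def pad_labels(
    label_seqs: Sequence[Sequence[int]], ignore_index: int = IGNORE_INDEX
) -> Tuple[List[List[int]], List[int]]:
    """Right-pad each label sequence to the longest one in the batch."""
    lengths = [len(seq) for seq in label_seqs]
    max_len = max(lengths)
    padded = [list(seq) + [ignore_index] * (max_len - len(seq)) for seq in label_seqs]
    return padded, lengths


def target_lengths(
    padded: Sequence[Sequence[int]], ignore_index: int = IGNORE_INDEX
) -> List[int]:
    """Recover per-example target lengths by counting non-padding labels."""
    return [sum(1 for label in row if label != ignore_index) for row in padded]
```

Using a value that can never be a valid token id keeps padding distinguishable from the blank and pad tokens of the vocabulary.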
hainan-xv
left a comment
Left a comment on the loss computation part.
ebezzam
left a comment
@lmaksym thanks for porting the TDT loss! It's nice (1) to not have to depend on torchaudio and (2) to make the TDT loss available in Transformers!
It is functional with this example (single GPU): https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-tdt_training_snippet-py
But quite slow...
I wonder if there is a custom gradient computation in NeMo? As I noticed in the paper (Section 3.1), they say "We derive an analytical solution for the gradient of the TDT loss, since automatic differentiation for transducer loss is highly inefficient."
FYI I can test/fix on my side for multi-GPU compatibility.
I'll look into that
ebezzam
left a comment
@ArthurZucker thanks for your comments! I've made changes to the cache object and done some refactoring in the processor.
For the hub kernel, I can check with Cyril.
from ...utils import ModelOutput


class ParakeetTDTDecoderCache:
Moved cache object to generation
class ParakeetTDTDecoderCache:
    def __init__(self, config):
The config is passed in so that we don't need to pass the LSTM module to update() for lazy initialization.
# Get cached hidden/cell states if available, otherwise initialize with ParakeetTDTDecoderCache
if cache is not None:
    was_initialized = cache.is_initialized
    if not was_initialized:
        cache.lazy_initialization(embeddings)
    hidden_cell_states = (cache.hidden_state, cache.cell_state)
else:
    hidden_cell_states = None
Related to your comment, I tried to make it clearer how the hidden/cell states get initialized. But it's still a bit clunky...
In case the links are not redirecting well:
(Arthur)
If not cache.is_initialized but still passed why are we not passing?
(Eric)
I think the confusing part is that we were actually relying on self.lstm to initialize hidden_state and cell_state (rather than ParakeetTDTDecoderCache.lazy_initialization) and then passing them to ParakeetTDTDecoderCache to set them. I'll try coming up with something that's clearer
if cache is not None:
    mask = ~blank_mask if was_initialized else None
    cache.update(decoder_output, hidden_state, cell_state, mask=mask)
LSTM module no longer passed
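The lazy-initialization and masked-update pattern discussed in this thread can be sketched as follows. This is a simplified toy, not the ParakeetTDTDecoderCache from the PR: `ToyDecoderCache` is a hypothetical name, it takes a hidden size and batch size instead of a config and embeddings, it tracks only a hidden state (no cell state or decoder output), and it uses nested lists instead of tensors. The mask semantics are assumed from the snippet above, where mask = ~blank_mask: True means the row emitted a non-blank token and its state should advance.

```python
from typing import List, Optional, Sequence


class ToyDecoderCache:
    """Toy cache holding one hidden-state row per batch element."""

    def __init__(self, hidden_size: int) -> None:
        self.hidden_size = hidden_size
        self.hidden_state: Optional[List[List[float]]] = None
        self.is_initialized = False

    def lazy_initialization(self, batch_size: int) -> None:
        # Allocate zero states once the batch size is known.
        self.hidden_state = [[0.0] * self.hidden_size for _ in range(batch_size)]
        self.is_initialized = True

    def update(
        self,
        new_hidden: Sequence[Sequence[float]],
        mask: Optional[Sequence[bool]] = None,
    ) -> None:
        if mask is None:
            # No mask: overwrite the state of every row.
            self.hidden_state = [list(row) for row in new_hidden]
            return
        # Masked update: only rows that emitted a non-blank token advance;
        # rows that predicted blank keep their previous state.
        for i, advance in enumerate(mask):
            if advance:
                self.hidden_state[i] = list(new_hidden[i])
```

Keeping the previous state for blank rows matters because the prediction network should only be conditioned on tokens that were actually emitted.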
)

if use_decoder_cache and decoder_cache is None:
    decoder_cache = ParakeetTDTDecoderCache(self.config)
passing config to help with initialization (instead of using LSTM module)
stream = DecodeStream(skip_special_tokens=True)
timestamp_dict = []
for i, token_id in enumerate(batch_ids):
    if int(token_id) in skip_ids:
        continue
    chunk = stream.step(self.tokenizer._tokenizer, int(token_id))
    if chunk is not None:
        timestamp_dict.append(
            {
                "token": chunk,
                "start": int(batch_timestamps[i]),
                "end": int(batch_timestamps[i] + batch_durations[i]),
            }
        )
Refactored to use DecodeStream, but I think we still need the inner loop to associate each token with its timestamps?
Re this comment:
(Arthur) all of this can and potentially should be done in a single for loop, no?
(Eric)
hmm from what I understand of DecodeStream, its step() method must be called per-token to get per-token decoded text for timestamp pairing. Passing a list merges the output into one chunk, losing the 1:1 token-to-timestamp mapping.
Yeah, the stream just simplifies the handling of invalid tokens 😉
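To make the 1:1 token-to-timestamp pairing discussed in this thread concrete, here is a minimal sketch that does not depend on the tokenizers library. The `id_to_text` dictionary is a hypothetical stand-in for the per-token text that DecodeStream's step() returns, and `pair_tokens_with_timestamps` is an illustrative name rather than a function from the PR. Start and end are expressed in frames, with end computed as start plus duration, as in the snippet above.

```python
from typing import Collection, Dict, List, Mapping, Sequence, Union


def pair_tokens_with_timestamps(
    token_ids: Sequence[int],
    timestamps: Sequence[int],
    durations: Sequence[int],
    id_to_text: Mapping[int, str],
    skip_ids: Collection[int] = (),
) -> List[Dict[str, Union[str, int]]]:
    """Pair each decoded token with its start and end frame."""
    paired: List[Dict[str, Union[str, int]]] = []
    for token_id, start, duration in zip(token_ids, timestamps, durations):
        if token_id in skip_ids:
            continue  # e.g. pad or blank ids
        text = id_to_text.get(token_id)
        if text is None:
            continue  # nothing decodable for this id
        paired.append({"token": text, "start": int(start), "end": int(start + duration)})
    return paired
```

Decoding one token per iteration is what preserves the mapping; decoding the whole list at once would yield a single merged string with no per-token boundaries.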
ArthurZucker
left a comment
One nit: let's avoid modifying the general code for conversion.
Yeah, but I don't think the problem is here; maybe it's more in ParakeetConverter(model_files["tokenizer_model_file"]).converted()?
run-slow: lasr, parakeet
This comment contains models: ["models/lasr", "models/parakeet"]
[For maintainers] Suggested jobs to run (before merge): run-slow: auto, encodec, lasr, parakeet
* parakeet tdt integration
* Add TDT decoder support for Parakeet ASR models

  Implement Token-and-Duration Transducer (TDT) decoding for Parakeet models, extending the existing CTC-only support. This adds ParakeetForTDT with greedy TDT decoding in generate(), per-token timestamp generation, and full integration with AutoModelForTDT, processors, and ASR pipeline.
* Add expected outputs for TDT, small fixes.
* Separate CTC and TDT generate outputs.
* Work with auto device, better init,
* Test timestamps and expose token duration.
* Add reproducer link.
* fix: align TDT training and decoding with NeMo implementation
  - Use -100 label padding for training (HF convention)
  - Fix timestamp recording in inner blank-seeking loop
  - Add max_symbols_per_step guard matching NeMo
  - Clean up decoding loop
  - Add TDT training example to docs
  - Use setUpClass for TDT integration tests
* revert: restore lasr generated files to original state
* warn: torchaudio rnnt_loss does not train duration head
* Relax timestamp test, and test nits.
* feat: TDT training
* chore: for cuda detection and run without patching
* Equivalent timestamp processing as Nemo, and various nits/cleanup.
* Simplify durations config.
* Update training examples.
* chore: enable parallelism
* chore: performance optimization
* fix: formatting
* Doc and testing nits
* Use active mask from current step, and nits.
* Better pre-allocate.
* TDT has separate pad token and blank token.
* Regenerate lasr.
* Style checks and nits
* Nits, put back ctc loss test
* More standard model output.
* Style
* Remove compute_loss flag and allow monkey patching to tdt loss
* Update src/transformers/models/parakeet/modular_parakeet.py (Co-authored-by: eustlb)
* Address various comments.
* More compatible with Transformers forward/generate approach
* compile option for generation and decoder cache
* Cleaner, better conventions.
* Update with main.
* doc nits
* Imitate whisper for encoder outputs as input
* Address tests and nits.
* Inherit from GenerateMixIn for get_compiled_call
* Comment nit
* forward cleanup
* generate cleanup + separate generation file
* generate: add _supported_generation_modes
* automatic init of the loss
* modular cleanups
* use is_encoder_decoder
* timestamp processing fully from tokens + durations
* conversion script update
* test update
* make
* test update
* test update
* ensure correct loss computation
* kernel loss
* test loss integration
* push to hub pr
* integration tests to rely fully on transcripts
* update fixtures
* we don't need to monkey patch with numba anymore!
* fix pipeline usage
* nit
* fix usage
* Pass through tests and examples: improve kernel fallback, update with nvidia checkpoint, style checks.
* Update checkpoint
* Add TDT to mapping after merge.
* Fix lasr generate test.
* Output attention mask if labels provided for computing loss.
* Apply suggestion from @ArthurZucker (Co-authored-by: Arthur)
* Improve ParakeetTDTDecoderCache definition and usage.
* Remove tuple parsing.
* processor refactor
* Update conversion.
* Remove kernel to address in separate PR, modular nit.
* Lasr modular
* Add attention mask docstring.

Co-authored-by: Hainan Xu
Co-authored-by: Eric B
Co-authored-by: Eric Bezzam
Co-authored-by: eustlb
Co-authored-by: Arthur
Co-authored-by: Tarek Ziade
What does this PR do?
This PR adds TDT decoder support for Parakeet ASR models, extending the existing CTC-only implementation.
It incorporates the initial TDT integration work from #41545 by @hainan-xv (which was not merged) and addresses all review feedback from both #41545 and #43357.
Changes
- ParakeetForTDT model with greedy TDT decoding in generate()
- ParakeetTDTDecoder (LSTM prediction network) and ParakeetTDTJointNetwork as nn.Module subclasses
- return_timestamps=True
- AutoModelForTDT auto class with pipeline, processor, and tokenizer integration
- ParakeetTDTConfig matching the CTC pattern (no nested decoder/joint configs)
- ParakeetPreTrainedModel base between CTC and TDT (no separate TDT base class)
Validation
- make check-repo passes
@ebezzam and @hainan-xv please review