Conversation

@sanchit-gandhi
Contributor

What does this PR do?

Fixes #15476

  • Adds an adapter to the Flax Wav2Vec2 model that further reduces the time dimension of the extracted feature vectors beyond that of the standard Wav2Vec2 model. Each of the encoder's output hidden states therefore covers a time context window closer to that of a subword token rather than a single character.
  • The shape and values of the Flax output logits match those of the PyTorch model.
  • The Flax model uses all of the PyTorch model's weights, including those of the adapter. Running the script in #15476 was verified to yield identical results (within a 4e-2 tolerance).
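For context on the first point, the adapter is a small stack of strided 1-D convolutions on top of the encoder, so each layer roughly halves the sequence length. A minimal sketch of that length arithmetic, assuming the library's usual defaults of three adapter layers with kernel size 3, stride 2, and padding 1 (treat these numbers as illustrative, not as read from this PR's diff):

```python
def conv_out_length(length: int, kernel: int, stride: int, padding: int = 1) -> int:
    """Standard 1-D convolution output-length formula."""
    return (length + 2 * padding - kernel) // stride + 1


def adapter_out_length(length: int, num_layers: int = 3, kernel: int = 3, stride: int = 2) -> int:
    """Apply the formula once per adapter conv layer."""
    for _ in range(num_layers):
        length = conv_out_length(length, kernel, stride)
    return length


# A 499-frame feature sequence shrinks by roughly 2**num_layers:
print(adapter_out_length(499))  # 63
```

With these assumed defaults, the adapter turns the ~20 ms-per-frame resolution of the standard feature encoder into a much coarser grid, which is what makes each hidden state behave more like a subword token.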

@HuggingFaceDocBuilder

HuggingFaceDocBuilder commented Feb 4, 2022

The documentation is not available anymore as the PR was closed or merged.

@sanchit-gandhi changed the title from Flax wav2vec2 to Add Wav2Vec2 Adapter Weights to Flax on Feb 4, 2022
stas00 and others added 3 commits February 4, 2022 11:15
* Standardize instance segmentation models outputs

* Rename output

* Update src/transformers/modeling_outputs.py

Co-authored-by: NielsRogge <[email protected]>

* Add legacy argument to the config and model forward

* Update src/transformers/models/beit/modeling_beit.py

Co-authored-by: Lysandre Debut <[email protected]>

* Copy fix in Segformer

Co-authored-by: NielsRogge <[email protected]>
Co-authored-by: Lysandre Debut <[email protected]>
* [deepspeed docs] DeepSpeed ZeRO Inference

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <[email protected]>

* tweak

* deal with black

* extra cleanup, better comments

Co-authored-by: Sylvain Gugger <[email protected]>
Contributor

@patrickvonplaten left a comment

Cool PR! Looks more or less good to me. Left some final comments and then I think we can merge :-)

patrickvonplaten and others added 9 commits February 7, 2022 15:35
* [torch_int_div] Correct true division in generation

* up

* up
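The `[torch_int_div]` commit above concerns division semantics: in recent PyTorch, `/` performs true division even on integer tensors, so index arithmetic in generation has to round explicitly (e.g. via `torch.div(..., rounding_mode="floor")`). A plain-Python sketch of the distinction, with an illustrative helper name:

```python
def torch_int_div_sketch(a: int, b: int) -> int:
    """Plain-Python stand-in for torch.div(a, b, rounding_mode='floor')."""
    return a // b


# True division silently produces a float, which is unusable as a tensor index:
print(7 / 2)                       # 3.5
print(torch_int_div_sketch(7, 2))  # 3
```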
* First draft

* Add conversion script

* Improve conversion script

* Improve docs and implement tests

* Define model output class

* Fix tests

* Fix more tests

* Add model to README

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <[email protected]>

* Apply more suggestions from code review

* Apply suggestions from code review

* Rename dims to hidden_sizes

* Fix equivalence test

* Rename gamma to gamma_parameter

* Clean up conversion script

* Add ConvNextFeatureExtractor

* Add corresponding tests

* Implement feature extractor correctly

* Make implementation cleaner

* Add ConvNextStem class

* Improve design

* Update design to also include encoder

* Fix gamma parameter

* Use sample docstrings

* Finish conversion, add center cropping

* Replace nielsr by facebook, make feature extractor tests smaller

* Fix integration test

Co-authored-by: Sylvain Gugger <[email protected]>
* Unused import

* Make `has_length()` torch-independent to use in callbacks

* Update src/transformers/trainer_utils.py

Co-authored-by: Sylvain Gugger <[email protected]>

Co-authored-by: Sylvain Gugger <[email protected]>
* Single-epoch run

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* Infinite dataset

* Trainer fix + distributed benchmark

* Benchmark fix

* unused import

* interleaved splits

* interleaved splits

* has_length util

* Move to research projects

* Leftover Sized checks

* Bump min version

* Unused import

* Revert trainer changes

Co-authored-by: Patrick von Platen <[email protected]>
* Wav2Vec2 models must either throw or deal with add_adapter

Co-authored-by: Patrick von Platen <[email protected]>

* Add pre-add_adapter backwards compatibility

* Add pre-add_adapter backwards compatibility

* Fix issue in tests/test_modeling_wav2vec2.py

Co-authored-by: Patrick von Platen <[email protected]>

Co-authored-by: Patrick von Platen <[email protected]>
* add cross attn to outputs

* add cross attn to outputs for TFLED

* add undo padding

* remove unused import

* fix style

Co-authored-by: ydshieh <[email protected]>
@sanchit-gandhi deleted the flax-wav2vec2 branch February 7, 2022 16:49
@sanchit-gandhi restored the flax-wav2vec2 branch February 7, 2022 16:50
@sanchit-gandhi reopened this Feb 7, 2022
sanchit-gandhi and others added 17 commits February 7, 2022 18:07
Co-authored-by: Patrick von Platen <[email protected]>
* fix outputs

* fix for CTC

* fix doc

* make style

Co-authored-by: ydshieh <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
* 📝 add config section

* 📝 finish first draft

* 📝 add feature extractor and processor

* 🖍 apply feedback from review

* 📝 minor edits

* last review
* Change the way tracing happens, enabling dynamic axes out of the box

* Update the tests and modeling xlnet

* Do not record leaf modules, to avoid recording more values for the methods than will be seen at tracing time (which would otherwise desynchronize the recorded values from the values that need to be fed to the proxies during tracing, causing errors).

* Comments and making tracing work for gpt-j and xlnet

* Refactor things related to num_choices (and batch_size, sequence_length)

* Update fx to work on PyTorch 1.10

* Postpone autowrap_function feature usage for later

* Add copyrights

* Remove unnecessary file

* Fix issue with add_new_model_like

* Apply suggestions
* electra is added to onnx supported model

* add google/electra-base-generator for test onnx module

Co-authored-by: Lewis Tunstall <[email protected]>
* use_cache = False for PT models if labels are passed

* Fix for BigBirdPegasusForConditionalGeneration

* add warning if users specify use_cache=True

* Use logger.warning instead of warnings.warn

Co-authored-by: ydshieh <[email protected]>
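The commit group above decides what to do with `use_cache` when labels are supplied: cached key/values only help autoregressive generation, so during label-based training they are dropped with a warning rather than an error. A hedged sketch of that control flow; the function name and message are illustrative, not the library's exact API:

```python
import logging

logger = logging.getLogger(__name__)


def resolve_use_cache(use_cache, labels):
    """Disable the generation cache when the model is trained on labels."""
    if labels is not None and use_cache:
        logger.warning(
            "`use_cache=True` is incompatible with label-based training; "
            "setting `use_cache=False`."
        )
        return False
    return use_cache


print(resolve_use_cache(True, labels=[1, 2, 3]))  # False
print(resolve_use_cache(True, labels=None))       # True
```

Using `logger.warning` rather than `warnings.warn`, as the last commit in the group notes, keeps the message in the library's normal logging stream instead of Python's one-shot warning machinery.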
@patrickvonplaten
Contributor

I think the commit history is messed up here. The best approach is usually to open a new PR and extract your changes from this one.



Development

Successfully merging this pull request may close these issues.

Add Adapter Weighs to Flax