Rewrites BERT in Flax to the new Linen API #7211
Merged
mfuntowicz pushed a commit that referenced this pull request Oct 5, 2020
* Rewrite Flax HuggingFace PR to Linen
* Some fixes
* Fix tests
mfuntowicz added a commit that referenced this pull request Oct 5, 2020
This reverts commit 23703a5.
mfuntowicz pushed a commit that referenced this pull request Oct 19, 2020
* Rewrite Flax HuggingFace PR to Linen
* Some fixes
* Fix tests
mfuntowicz added a commit that referenced this pull request Oct 19, 2020
This reverts commit 23703a5.
LysandreJik pushed a commit that referenced this pull request Oct 19, 2020
* WIP flax bert
* Initial commit Bert Jax/Flax implementation.
* Embeddings working and equivalent to PyTorch.
* Move embeddings in its own module BertEmbeddings
* Added jax.jit annotation on forward call
* BertEncoder on par with PyTorch ! :D
* Add BertPooler on par with PyTorch !!
* Working Jax+Flax implementation of BertModel with < 1e-5 differences on the last layer.
* Fix pooled output to take only the first token of the sequence.
* Refactoring to use BertConfig from transformers.
* Renamed FXBertModel to FlaxBertModel
* Model is now initialized in FlaxBertModel constructor and reused.
* WIP JaxPreTrainedModel
* Cleaning up the code of FlaxBertModel
* Added ability to load Flax model saved through save_pretrained()
* Added ability to convert Pytorch Bert model to FlaxBert
* FlaxBert can now load every Pytorch Bert model with on-the-fly conversion
* Fix hardcoded shape values in conversion scripts.
* Improve the way we handle LayerNorm conversion from PyTorch to Flax.
* Added positional embeddings as parameter of BertModel with default to np.arange.
* Let's roll FlaxRoberta !
* Fix missing position_ids parameters on predict for Bert
* Flax backend now supports batched inputs Signed-off-by: Morgan Funtowicz <[email protected]>
* Make it possible to load msgpacked model on convert from pytorch in last resort. Signed-off-by: Morgan Funtowicz <[email protected]>
* Moved save_pretrained to Jax base class along with more constructor parameters.
* Use specialized, model dependent conversion functio.
* Expose `is_flax_available` in file_utils.
* Added unittest for Flax models.
* Added run_tests_flax to the CI.
* Introduce FlaxAutoModel
* Added more unittests
* Flax model reference the _MODEL_ARCHIVE_MAP from PyTorch model.
* Addressing review comments.
* Expose seed in both Bert and Roberta
* Fix typo suggested by @stefan-it Co-Authored-By: Stefan Schweter <[email protected]>
* Attempt to make style
* Attempt to make style in tests too
* Added jax & jaxlib to the flax optional dependencies.
* Attempt to fix flake8 warnings ...
* Redo black again and again
* When black and flake8 fight each other for a space ... 💥 💥 💥
* Try removing trailing comma to make both black and flake happy!
* Fix invalid is_<framework>_available call, thanks @LysandreJik 🎉
* Fix another invalid import in flax_roberta test
* Bump and pin flax release to 0.1.0.
* Make flake8 happy, remove unused jax import
* Change the type of the catch for msgpack.
* Remove unused import.
* Put seed as optional constructor parameter.
* trigger ci again
* Fix too much parameters in BertAttention.
* Formatting.
* Simplify Flax unittests to avoid machine crashes.
* Fix invalid number of arguments when raising issue for an unknown model.
* Address @bastings comment in PR, moving jax.jit decorated outside of __call__
* Fix incorrect path to require_flax/require_pytorch functions. Signed-off-by: Morgan Funtowicz <[email protected]>
* Attempt to make style. Signed-off-by: Morgan Funtowicz <[email protected]>
* Correct rebasing of circle-ci dependencies Signed-off-by: Morgan Funtowicz <[email protected]>
* Fix import sorting. Signed-off-by: Morgan Funtowicz <[email protected]>
* Fix unused imports. Signed-off-by: Morgan Funtowicz <[email protected]>
* Again import sorting... Signed-off-by: Morgan Funtowicz <[email protected]>
* Installing missing nlp dependency for flax unittests. Signed-off-by: Morgan Funtowicz <[email protected]>
* Fix laoding of model for Flax implementations. Signed-off-by: Morgan Funtowicz <[email protected]>
* jit the inner function call to make JAX-compatible Signed-off-by: Morgan Funtowicz <[email protected]>
* Format ! Signed-off-by: Morgan Funtowicz <[email protected]>
* Flake one more time 🎶 Signed-off-by: Morgan Funtowicz <[email protected]>
* Rewrites BERT in Flax to the new Linen API (#7211)
* Rewrite Flax HuggingFace PR to Linen
* Some fixes
* Fix tests
* Fix CI with change of name of nlp (#7054)
* nlp -> datasets
* More nlp -> datasets
* Woopsie
* More nlp -> datasets
* One last
* Expose `is_flax_available` in file_utils.
* Added run_tests_flax to the CI.
* Attempt to make style
* trigger ci again
* Fix import sorting. Signed-off-by: Morgan Funtowicz <[email protected]>
* Revert "Rewrites BERT in Flax to the new Linen API (#7211)" This reverts commit 23703a5.
* Remove jnp.lax references Signed-off-by: Morgan Funtowicz <[email protected]>
* Make style. Signed-off-by: Morgan Funtowicz <[email protected]>
* Reintroduce Linen changes ... Signed-off-by: Morgan Funtowicz <[email protected]>
* Make style. Signed-off-by: Morgan Funtowicz <[email protected]>
* Use jax native's gelu function. Signed-off-by: Morgan Funtowicz <[email protected]>
* Renaming BertModel to BertModule to highlight the fact this is the Flax Module object. Signed-off-by: Morgan Funtowicz <[email protected]>
* Rewrite FlaxAutoModel test to not rely on pretrained_model_archive_map Signed-off-by: Morgan Funtowicz <[email protected]>
* Remove unused variable in BertModule. Signed-off-by: Morgan Funtowicz <[email protected]>
* Remove unused variable in BertModule again Signed-off-by: Morgan Funtowicz <[email protected]>
* Attempt to have is_flax_available working again. Signed-off-by: Morgan Funtowicz <[email protected]>
* Introduce JAX TensorType Signed-off-by: Morgan Funtowicz <[email protected]>
* Improve ImportError message when trying to convert to various TensorType format. Signed-off-by: Morgan Funtowicz <[email protected]>
* Makes Flax model jittable. Signed-off-by: Morgan Funtowicz <[email protected]>
* Ensure flax models are jittable in unittests. Signed-off-by: Morgan Funtowicz <[email protected]>
* Remove unused imports. Signed-off-by: Morgan Funtowicz <[email protected]>
* Ensure jax imports are guarded behind is_flax_available. Signed-off-by: Morgan Funtowicz <[email protected]>
* Make style. Signed-off-by: Morgan Funtowicz <[email protected]>
* Make style again Signed-off-by: Morgan Funtowicz <[email protected]>
* Make style again again Signed-off-by: Morgan Funtowicz <[email protected]>
* Make style again again again Signed-off-by: Morgan Funtowicz <[email protected]>
* Update src/transformers/file_utils.py Co-authored-by: Marc van Zee <[email protected]>
* Bump flax to it's latest version Co-authored-by: Marc van Zee <[email protected]>
* Bump jax version to at least 0.2.0 Signed-off-by: Morgan Funtowicz <[email protected]>
* Style. Signed-off-by: Morgan Funtowicz <[email protected]>
* Update the unittest to use TensorType.JAX Signed-off-by: Morgan Funtowicz <[email protected]>
* isort import in tests. Signed-off-by: Morgan Funtowicz <[email protected]>
* Match new flax parameters name "params" Signed-off-by: Morgan Funtowicz <[email protected]>
* Remove unused imports. Signed-off-by: Morgan Funtowicz <[email protected]>
* Add flax models to transformers __init__ Signed-off-by: Morgan Funtowicz <[email protected]>
* Attempt to address all CI related comments. Signed-off-by: Morgan Funtowicz <[email protected]>
* Correct circle.yml indent. Signed-off-by: Morgan Funtowicz <[email protected]>
* Correct circle.yml indent (2) Signed-off-by: Morgan Funtowicz <[email protected]>
* Remove coverage from flax tests Signed-off-by: Morgan Funtowicz <[email protected]>
* Addressing many naming suggestions from comments Signed-off-by: Morgan Funtowicz <[email protected]>
* Simplify for loop logic to interate over layers in FlaxBertLayerCollection Signed-off-by: Morgan Funtowicz <[email protected]>
* use f-string syntax for formatting logs. Signed-off-by: Morgan Funtowicz <[email protected]>
* Use config property from FlaxPreTrainedModel. Signed-off-by: Morgan Funtowicz <[email protected]>
* use "cls_token" instead of "first_token" variable name. Signed-off-by: Morgan Funtowicz <[email protected]>
* use "hidden_state" instead of "h" variable name. Signed-off-by: Morgan Funtowicz <[email protected]>
* Correct class reference in docstring to link to Flax related modules. Signed-off-by: Morgan Funtowicz <[email protected]>
* Added HF + Google Flax team copyright. Signed-off-by: Morgan Funtowicz <[email protected]>
* Make Roberta independent from Bert Signed-off-by: Morgan Funtowicz <[email protected]>
* Move activation functions to flax_utils. Signed-off-by: Morgan Funtowicz <[email protected]>
* Move activation functions to flax_utils for bert. Signed-off-by: Morgan Funtowicz <[email protected]>
* Added docstring for BERT Signed-off-by: Morgan Funtowicz <[email protected]>
* Update import for Bert and Roberta tokenizers Signed-off-by: Morgan Funtowicz <[email protected]>
* Make style. Signed-off-by: Morgan Funtowicz <[email protected]>
* fix-copies Signed-off-by: Morgan Funtowicz <[email protected]>
* Correct FlaxRobertaLayer to match PyTorch. Signed-off-by: Morgan Funtowicz <[email protected]>
* Use the same store_artifact for flax unittest Signed-off-by: Morgan Funtowicz <[email protected]>
* Style. Signed-off-by: Morgan Funtowicz <[email protected]>
* Make sure gradient are disabled only locally for flax unittest using torch equivalence. Signed-off-by: Morgan Funtowicz <[email protected]>
* Use relative imports Signed-off-by: Morgan Funtowicz <[email protected]>

Co-authored-by: Stefan Schweter <[email protected]>
Co-authored-by: Marc van Zee <[email protected]>
Co-authored-by: Sylvain Gugger <[email protected]>
One small caveat: we renamed `param` in Module to `params`, so we may have to update this once we have created a new PyPI wheel, but I will make sure that this is the case. Otherwise it seems good to go!
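To illustrate what the rename in the comment above can mean for code that loads saved weights, here is a minimal sketch in plain Python. It assumes the rename concerns the top-level key under which parameters are stored in a nested dictionary; the helper `normalize_variables` and both constants are hypothetical names for this example and are not part of Flax or transformers.

```python
from typing import Any, Dict

# Assumption: parameters used to live under "param" and now live under "params".
LEGACY_KEY = "param"
CURRENT_KEY = "params"


def normalize_variables(variables: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy with the parameter tree stored under "params".

    Hypothetical helper: it only renames the top-level key and leaves the
    nested parameter tree untouched.
    """
    if LEGACY_KEY in variables and CURRENT_KEY in variables:
        # Both spellings present: ambiguous, so do not guess which one wins.
        raise ValueError("found both 'param' and 'params' keys")
    normalized = dict(variables)
    if LEGACY_KEY in normalized:
        normalized[CURRENT_KEY] = normalized.pop(LEGACY_KEY)
    return normalized


# Toy stand-in for a checkpoint saved before the rename.
old_checkpoint = {"param": {"embeddings": {"weight": [0.1, 0.2]}}}
new_checkpoint = normalize_variables(old_checkpoint)
print(sorted(new_checkpoint))
```

A dictionary that already uses `params` passes through unchanged, so the helper can be applied unconditionally when loading.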