@marcvanzee
Contributor

One small caveat: we renamed `param` in Module to `params`, so we may have to update this once we have created a new PyPI wheel, but I will make sure that this is taken care of.

Otherwise it seems good to go!
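
To illustrate the caveat above, here is a minimal sketch of how calling code could tolerate the rename. It assumes the rename affects the top-level key under which a Module's parameters are stored in a nested dict; the helper name `rename_top_level_collection` and the example tree are hypothetical and are not part of Flax or of this PR.

```python
from typing import Any, Dict


def rename_top_level_collection(
    variables: Dict[str, Any], old: str = "param", new: str = "params"
) -> Dict[str, Any]:
    """Return a shallow copy of `variables` with the top-level key `old` renamed to `new`.

    Hypothetical helper: assumes parameters live in a plain nested dict
    keyed by collection name at the top level.
    """
    if old not in variables:
        # Nothing to migrate; hand back a copy so callers can mutate it safely.
        return dict(variables)
    if new in variables:
        raise ValueError(f"Both '{old}' and '{new}' are present; refusing to overwrite.")
    renamed = {key: value for key, value in variables.items() if key != old}
    renamed[new] = variables[old]
    return renamed


# Example tree using the old key name (values are placeholders).
legacy = {"param": {"dense": {"kernel": [[1.0, 0.0]], "bias": [0.0]}}}
updated = rename_top_level_collection(legacy)
```

The helper leaves the input untouched and only rewrites the top-level key, so it could be dropped once every caller uses the new name.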

@mfuntowicz mfuntowicz merged commit 23703a5 into huggingface:jax-bert Sep 18, 2020
mfuntowicz pushed a commit that referenced this pull request Oct 5, 2020
* Rewrite Flax HuggingFace PR to Linen

* Some fixes

* Fix tests
mfuntowicz added a commit that referenced this pull request Oct 5, 2020
mfuntowicz pushed a commit that referenced this pull request Oct 19, 2020
* Rewrite Flax HuggingFace PR to Linen

* Some fixes

* Fix tests
mfuntowicz added a commit that referenced this pull request Oct 19, 2020
LysandreJik pushed a commit that referenced this pull request Oct 19, 2020
* WIP flax bert

* Initial commit Bert Jax/Flax implementation.

* Embeddings working and equivalent to PyTorch.

* Move embeddings into their own module, BertEmbeddings

* Added jax.jit annotation on forward call

* BertEncoder on par with PyTorch ! :D

* Add BertPooler on par with PyTorch !!

* Working Jax+Flax implementation of BertModel with < 1e-5 differences on the last layer.

* Fix pooled output to take only the first token of the sequence.

* Refactoring to use BertConfig from transformers.

* Renamed FXBertModel to FlaxBertModel

* Model is now initialized in FlaxBertModel constructor and reused.

* WIP JaxPreTrainedModel

* Cleaning up the code of FlaxBertModel

* Added ability to load Flax model saved through save_pretrained()

* Added ability to convert Pytorch Bert model to FlaxBert

* FlaxBert can now load every Pytorch Bert model with on-the-fly conversion

* Fix hardcoded shape values in conversion scripts.

* Improve the way we handle LayerNorm conversion from PyTorch to Flax.

* Added positional embeddings as parameter of BertModel with default to np.arange.

* Let's roll FlaxRoberta !

* Fix missing position_ids parameters on predict for Bert

* Flax backend now supports batched inputs

Signed-off-by: Morgan Funtowicz <[email protected]>

* Make it possible to load msgpacked model on convert from pytorch in last resort.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Moved save_pretrained to Jax base class along with more constructor parameters.

* Use specialized, model dependent conversion function.

* Expose `is_flax_available` in file_utils.

* Added unittest for Flax models.

* Added run_tests_flax to the CI.

* Introduce FlaxAutoModel

* Added more unittests

* Flax models reference the _MODEL_ARCHIVE_MAP from PyTorch model.

* Addressing review comments.

* Expose seed in both Bert and Roberta

* Fix typo suggested by @stefan-it

Co-Authored-By: Stefan Schweter <[email protected]>

* Attempt to make style

* Attempt to make style in tests too

* Added jax & jaxlib to the flax optional dependencies.

* Attempt to fix flake8 warnings ...

* Redo black again and again

* When black and flake8 fight each other for a space ... 💥 💥 💥

* Try removing trailing comma to make both black and flake happy!

* Fix invalid is_<framework>_available call, thanks @LysandreJik 🎉

* Fix another invalid import in flax_roberta test

* Bump and pin flax release to 0.1.0.

* Make flake8 happy, remove unused jax import

* Change the type of the catch for msgpack.

* Remove unused import.

* Put seed as optional constructor parameter.

* trigger ci again

* Fix too many parameters in BertAttention.

* Formatting.

* Simplify Flax unittests to avoid machine crashes.

* Fix invalid number of arguments when raising issue for an unknown model.

* Address @bastings comment in PR, moving jax.jit decorated outside of __call__

* Fix incorrect path to require_flax/require_pytorch functions.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Attempt to make style.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Correct rebasing of circle-ci dependencies

Signed-off-by: Morgan Funtowicz <[email protected]>

* Fix import sorting.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Fix unused imports.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Again import sorting...

Signed-off-by: Morgan Funtowicz <[email protected]>

* Installing missing nlp dependency for flax unittests.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Fix loading of model for Flax implementations.

Signed-off-by: Morgan Funtowicz <[email protected]>

* jit the inner function call to make JAX-compatible

Signed-off-by: Morgan Funtowicz <[email protected]>

* Format !

Signed-off-by: Morgan Funtowicz <[email protected]>

* Flake one more time 🎶

Signed-off-by: Morgan Funtowicz <[email protected]>

* Rewrites BERT in Flax to the new Linen API (#7211)

* Rewrite Flax HuggingFace PR to Linen

* Some fixes

* Fix tests

* Fix CI with change of name of nlp (#7054)

* nlp -> datasets

* More nlp -> datasets

* Woopsie

* More nlp -> datasets

* One last

* Expose `is_flax_available` in file_utils.

* Added run_tests_flax to the CI.

* Attempt to make style

* trigger ci again

* Fix import sorting.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Revert "Rewrites BERT in Flax to the new Linen API (#7211)"

This reverts commit 23703a5.

* Remove jnp.lax references

Signed-off-by: Morgan Funtowicz <[email protected]>

* Make style.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Reintroduce Linen changes ...

Signed-off-by: Morgan Funtowicz <[email protected]>

* Make style.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Use jax native's gelu function.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Renaming BertModel to BertModule to highlight the fact this is the Flax Module object.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Rewrite FlaxAutoModel test to not rely on pretrained_model_archive_map

Signed-off-by: Morgan Funtowicz <[email protected]>

* Remove unused variable in BertModule.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Remove unused variable in BertModule again

Signed-off-by: Morgan Funtowicz <[email protected]>

* Attempt to have is_flax_available working again.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Introduce JAX TensorType

Signed-off-by: Morgan Funtowicz <[email protected]>

* Improve ImportError message when trying to convert to various TensorType format.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Makes Flax model jittable.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Ensure flax models are jittable in unittests.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Remove unused imports.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Ensure jax imports are guarded behind is_flax_available.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Make style.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Make style again

Signed-off-by: Morgan Funtowicz <[email protected]>

* Make style again again

Signed-off-by: Morgan Funtowicz <[email protected]>

* Make style again again again

Signed-off-by: Morgan Funtowicz <[email protected]>

* Update src/transformers/file_utils.py

Co-authored-by: Marc van Zee <[email protected]>

* Bump flax to its latest version

Co-authored-by: Marc van Zee <[email protected]>

* Bump jax version to at least 0.2.0

Signed-off-by: Morgan Funtowicz <[email protected]>

* Style.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Update the unittest to use TensorType.JAX

Signed-off-by: Morgan Funtowicz <[email protected]>

* isort import in tests.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Match new flax parameters name "params"

Signed-off-by: Morgan Funtowicz <[email protected]>

* Remove unused imports.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Add flax models to transformers __init__

Signed-off-by: Morgan Funtowicz <[email protected]>

* Attempt to address all CI related comments.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Correct circle.yml indent.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Correct circle.yml indent (2)

Signed-off-by: Morgan Funtowicz <[email protected]>

* Remove coverage from flax tests

Signed-off-by: Morgan Funtowicz <[email protected]>

* Addressing many naming suggestions from comments

Signed-off-by: Morgan Funtowicz <[email protected]>

* Simplify for loop logic to iterate over layers in FlaxBertLayerCollection

Signed-off-by: Morgan Funtowicz <[email protected]>

* use f-string syntax for formatting logs.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Use config property from FlaxPreTrainedModel.

Signed-off-by: Morgan Funtowicz <[email protected]>

* use "cls_token" instead of "first_token" variable name.

Signed-off-by: Morgan Funtowicz <[email protected]>

* use "hidden_state" instead of "h" variable name.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Correct class reference in docstring to link to Flax related modules.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Added HF + Google Flax team copyright.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Make Roberta independent from Bert

Signed-off-by: Morgan Funtowicz <[email protected]>

* Move activation functions to flax_utils.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Move activation functions to flax_utils for bert.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Added docstring for BERT

Signed-off-by: Morgan Funtowicz <[email protected]>

* Update import for Bert and Roberta tokenizers

Signed-off-by: Morgan Funtowicz <[email protected]>

* Make style.

Signed-off-by: Morgan Funtowicz <[email protected]>

* fix-copies

Signed-off-by: Morgan Funtowicz <[email protected]>

* Correct FlaxRobertaLayer to match PyTorch.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Use the same store_artifact for flax unittest

Signed-off-by: Morgan Funtowicz <[email protected]>

* Style.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Make sure gradients are disabled only locally for flax unittest using torch equivalence.

Signed-off-by: Morgan Funtowicz <[email protected]>

* Use relative imports

Signed-off-by: Morgan Funtowicz <[email protected]>

Co-authored-by: Stefan Schweter <[email protected]>
Co-authored-by: Marc van Zee <[email protected]>
Co-authored-by: Sylvain Gugger <[email protected]>