Conversation

@mfuntowicz (Member) commented Apr 9, 2020

This Pull Request attempts to bring support for the Flax framework to transformers.

The main focus has been put on providing BERT-like models, principally by making it possible to load PyTorch checkpoints and to do the (few) necessary conversions directly on the fly. Loading a msgpack-formatted file produced by Flax is also supported.

save_pretrained will save the model in msgpack format to avoid a dependency on torch inside JAX code.
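
For illustration, a minimal sketch of the intended workflow (class and method names as discussed later in this thread; the exact arguments may differ in the final API):

from transformers import FlaxBertModel

# Loading from the PyTorch checkpoint converts the weights on the fly.
model = FlaxBertModel.from_pretrained("bert-base-uncased")

# Saving re-serializes the weights as msgpack, so torch is not needed afterwards.
model.save_pretrained("./bert-base-uncased-flax")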

Targeted models:

  • Bert
  • RoBERTa
  • DistilBERT
  • DistilRoBERTa

If not too hard:

  • CamemBERT

@LysandreJik (Member) left a comment

As I understand it, Flax is a neural-network library built on top of JAX, which brings automatic differentiation to Python/NumPy operations.

Do you mind walking me through why we use both here, and, more precisely, why we have a JaxPreTrainedModel as a parent of FlaxXXX models? Is the JaxPreTrainedModel's purpose to be able to accommodate models built on top of it from different jax-based libraries, or is it that it only depends on Jax operations so there is no need for it to be Flax-based?

Other than that, it looks like a great first approach! I guess we would need to add a few features as we go on - but it's impressive that you got it working so fast, with the same output between PT/TF/Flax!


# Models are loaded from Pytorch checkpoints
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",

Member:

Really nice to be able to load models from their PyTorch checkpoints


def __init__(self, config: BertConfig, state: dict, **kwargs):
    self.config = config
    self.key = PRNGKey(0)

Member:

What is this used for?

Member Author:

Its main usage is for stochastic operations, such as Dropout, which I didn't put into the model for now but might push soon.
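
For context, a generic sketch of how such a key is typically consumed by stochastic operations in JAX (illustrative only; Dropout is not part of this PR yet):

import jax
from jax.random import PRNGKey

key = PRNGKey(0)
# Each stochastic operation should consume a fresh subkey split from the model key.
key, dropout_key = jax.random.split(key)
# e.g. a dropout keep-mask drawn from that subkey (keep probability 0.9)
keep_mask = jax.random.bernoulli(dropout_key, p=0.9, shape=(2, 8))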

@mfuntowicz (Member Author)

As you said, JAX is a library that interacts with NumPy to provide additional features: autodiff, auto-vectorization (vmap) and auto-parallelization (pmap).

JAX is essentially stateless, which is reflected here in the fact that the function to differentiate (the model) doesn't hold the parameters; they have to be stored somewhere else and fed in explicitly.

JaxPreTrainedModel is introduced here mainly to handle the serialization of such models and to provide the conversions. Also, one specificity of JAX is that many different neural-network libraries are currently being implemented on top of it.

In that respect, @madisonmay is currently working on a Haiku BERT integration in transformers. My hope is to share as many things as possible between the two implementations (but I can't be sure for now).
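
To make the statelessness concrete, here is a generic sketch (not code from this PR) of how parameters live outside the function that JAX transforms:

import jax
import jax.numpy as jnp

# Parameters are plain data (a pytree) held outside the model function.
params = {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))}

def forward(params, x):
    # The "model" is a pure function of (params, inputs); it holds no state itself.
    return jnp.dot(x, params["w"]) + params["b"]

# Transformations such as grad/jit act on the pure function; params are fed explicitly.
loss = lambda p, x: forward(p, x).sum()
grads = jax.grad(loss)(params, jnp.ones((3, 4)))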

@LysandreJik (Member)

Alright, that makes sense. Thanks for the explanation.

@mfuntowicz mfuntowicz marked this pull request as ready for review April 17, 2020 13:25

@LysandreJik (Member) left a comment

Cool, it's starting to look really good! I just have a few questions regarding conversion/model architecture.

@thomwolf (Member) left a comment

OK, this looks really great (Flax is quite pleasant to read).

Added a couple of remarks and questions. Happy to discuss them.

USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
if USE_TORCH in ENV_VARS_TRUE_VALUES and USE_TF not in ("1", "ON", "YES"):

Member:

Why not the same for the end of the line? (USE_TF)

Member:

Actually, why is _torch_available dependent on USE_TF?

USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()

if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
if USE_TF in ENV_VARS_TRUE_VALUES and USE_TORCH not in ("1", "ON", "YES"):

Member:

Here as well, why not the same for the end of the line?

Member:

Same here, why do we test USE_TORCH for _tf_available?

config (:class:`~transformers.PretrainedConfig`):
The model class to instantiate is selected based on the configuration class:
- isInstance of `roberta` configuration class: :class:`~transformers.RobertaModel` (RoBERTa model)

Member:

FlaxRobertaModel

The model class to instantiate is selected based on the configuration class:
- isInstance of `roberta` configuration class: :class:`~transformers.RobertaModel` (RoBERTa model)
- isInstance of `bert` configuration class: :class:`~transformers.BertModel` (Bert model)

Member:

FlaxBertModel

self._module = module

# Those are public as their type is generic to every derived classes.
self.key = PRNGKey(0)

Member:

I think we should have the PRNGKey seed exposed as a model argument so that users can have (and control) different weight-initialization seeds.

Contributor:

+1. I think it can be really useful to be able to configure the random seed.
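
A hypothetical sketch of what this suggestion could look like (the seed argument is not part of this PR; it only illustrates the idea):

from jax.random import PRNGKey

class JaxPreTrainedModel:
    # Hypothetical signature: `seed` lets users control weight-initialization
    # randomness instead of the hardcoded PRNGKey(0).
    def __init__(self, config, module, state, seed: int = 0):
        self._config = config
        self._module = module
        self._state = state
        self.key = PRNGKey(seed)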

super().__init__(config, model_def, state)

@property
def module(self) -> BertModel:

Member:

BertModel or FlaxBertModel?


@property
def config(self) -> BertConfig:
    return self._config

Member:

Shouldn't these module and config properties be in the base class?

Member Author:

Yes, that was the initial implementation, but in that case we cannot have the correct return type for the config and model, which are model-dependent.

Member:

(nit) We could return a PreTrainedConfig for the configuration, which would be complete enough imo. This comment does not apply to the module though.

@property
def config(self) -> RobertaConfig:
    return self._config

Member:

Same here, shouldn't this be in the base class?

Comment on lines 30 to 390
if token_type_ids is None:
    token_type_ids = np.ones_like(input_ids)

if position_ids is None:
    position_ids = np.arange(
        self.config.pad_token_id + 1,
        np.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1
    )

Member:

In FlaxBertModel these parameters are created outside of the jitted predict().
Any reason it's different here?
Should we standardize on one practice to make it easier to read for the user?
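
As one possible way to standardize, a generic sketch (not the PR's final code) that builds the defaults eagerly, outside of any jitted code:

import numpy as np

def prepare_inputs(config, input_ids, token_type_ids=None, position_ids=None):
    # Build default tensors in plain NumPy before entering the jitted predict function.
    input_ids = np.atleast_2d(input_ids)
    if token_type_ids is None:
        token_type_ids = np.ones_like(input_ids)
    if position_ids is None:
        position_ids = np.arange(
            config.pad_token_id + 1,
            input_ids.shape[-1] + config.pad_token_id + 1,
        )
    return input_ids, token_type_ids, position_ids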

config_class = None
pretrained_model_archive_map = {}
base_model_prefix = ""
model_class = None

Member:

What is model_class used for?

Member Author:

model_class is the underlying flax.nn.Module class, which is required by Flax/msgpack to allocate all the buffers.
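
For context, a generic sketch of why an allocated target is needed when restoring msgpack weights with Flax (illustrative, not the PR's exact code):

import jax.numpy as jnp
from flax import serialization

# Flax's msgpack (de)serialization works against a "target" pytree, so the module
# class is needed to allocate parameter buffers of the right structure first.
params = {"dense": {"kernel": jnp.ones((4, 2)), "bias": jnp.zeros((2,))}}
raw = serialization.to_bytes(params)                  # what save_pretrained would write
template = {"dense": {"kernel": jnp.zeros((4, 2)), "bias": jnp.zeros((2,))}}  # buffers allocated via model_class
restored = serialization.from_bytes(template, raw)    # bytes are loaded into the allocated buffers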

class BertIntermediate(nn.Module):

    def apply(self, hidden_state, output_size: int):
        # TODO: Had ACT2FN reference to change activation function

Contributor:

nit: typo?

vocab_size: int, hidden_size: int, type_vocab_size: int, max_length: int):

# Embed
w_emb = BertEmbedding(jnp.atleast_2d(input_ids.astype('i4')), vocab_size, hidden_size, name="word_embeddings")

Contributor:

nit: maybe split these in 2 lines to make it less wide?

def config(self) -> RobertaConfig:
    return self._config

def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):

Contributor:

How often is this called? Right now it looks like you @jit the predict graph each time the function is called. Is that the intention? Can you define and compile the predict function elsewhere?

    return self._config

def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
    @jax.jit

Contributor:

It looks like you should move this outside of the call, i.e. everything that you want to have @jax.jit'ed, so that when it is called multiple times the compiled version is cached.
Right now this gets recompiled on each call to __call__, which should cause a lot of compilation overhead.


yeah you definitely don't want to use jit inside a function like this, unless perhaps it's only called once anyway.
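
To illustrate the suggested fix, a generic sketch (not the PR's code) where the jitted function is created once and reused, so compilation is cached across calls:

import jax
import jax.numpy as jnp

class JittedPredictor:
    def __init__(self, apply_fn, params):
        self.params = params
        # Wrap the pure apply function once; jax.jit caches the compiled version
        # per input shape/dtype, so repeated calls do not re-trace or re-compile.
        self._predict = jax.jit(apply_fn)

    def __call__(self, input_ids):
        return self._predict(self.params, jnp.atleast_2d(jnp.asarray(input_ids)))

# Example pure apply function standing in for the Flax module's forward pass.
def toy_apply(params, input_ids):
    return params["embedding"][input_ids]

predictor = JittedPredictor(toy_apply, {"embedding": jnp.ones((100, 8))})
out = predictor(jnp.array([1, 2, 3]))  # compiles on the first call, cached afterwards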

stale bot commented Jul 29, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix label Jul 29, 2020

@LysandreJik (Member)

Unstale

codecov bot commented Sep 7, 2020

Codecov Report

Merging #3722 into master will increase coverage by 2.55%.
The diff coverage is 90.11%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #3722      +/-   ##
==========================================
+ Coverage   78.32%   80.88%   +2.55%     
==========================================
  Files         187      165      -22     
  Lines       37162    30383    -6779     
==========================================
- Hits        29107    24575    -4532     
+ Misses       8055     5808    -2247     
Impacted Files Coverage Δ
src/transformers/__init__.py 99.30% <ø> (-0.11%) ⬇️
src/transformers/modeling_flax_auto.py 60.86% <60.86%> (ø)
src/transformers/modeling_flax_utils.py 83.60% <83.60%> (ø)
src/transformers/file_utils.py 82.92% <92.85%> (-0.05%) ⬇️
src/transformers/modeling_flax_roberta.py 94.11% <94.11%> (ø)
src/transformers/modeling_flax_bert.py 96.50% <96.50%> (ø)
src/transformers/testing_utils.py 67.66% <100.00%> (+0.38%) ⬆️
src/transformers/modeling_tf_mobilebert.py 24.55% <0.00%> (-72.40%) ⬇️
src/transformers/modeling_tf_flaubert.py 24.53% <0.00%> (-65.14%) ⬇️
src/transformers/trainer.py 55.11% <0.00%> (-9.71%) ⬇️
... and 156 more

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 60de910...c0d1c81.

return jax.lax.tanh(out)


class BertModel(nn.Module):

Contributor:

Comparing your implementation to our lm1b example (https://github.com/google/flax/blob/master/examples/lm1b/models.py), it seems your code contains significantly more nn.Module abstractions. Why did you decide to create them? Is this to make the translation from a PyTorch model easier?

I personally prefer fewer abstractions, since I think it would make the code more concise and easier to digest, but I'd be interested in hearing your thoughts!

Member Author:

As you guessed, I chose to follow what is already widely adopted in the library regarding module fragmentation.

This is something our users welcome, because they like how easy it is to tweak a single module. It also gives an almost 1-to-1 match with both the PyTorch and TensorFlow implementations, so it should be easier for users who would like to give Flax a try to make the move.

Does that make sense?

BERT implementation using JAX/Flax as backend
"""

model_class = BertModel

Contributor:

I personally think the names FlaxBertModel and BertModel may be a bit confusing, since it seems BertModel is actually an nn.Module and, in that sense, more like a Flax model than the other one. Perhaps you could rename BertModel to BertModule to clarify that it is in fact the Flax module, and the other one is a wrapper around it?

@mfuntowicz (Member Author)

cc @levskaya

from jax.config import config
# TODO(marcvanzee): Flax Linen requires JAX omnistaging. Remove this
# once JAX enables it by default.
config.enable_omnistaging()


If we cut a new release of flax and pin to newer jax/flax versions this is no longer necessary as it's now (recently) the default.

Contributor:

Agreed, I'll add a suggestion.

@marcvanzee (Contributor) left a comment

Looks great!!

Signed-off-by: Morgan Funtowicz <[email protected]>

@LysandreJik (Member) left a comment

Great, thanks for iterating @mfuntowicz!

@LysandreJik LysandreJik merged commit 8f8f8d9 into master Oct 19, 2020
@LysandreJik LysandreJik deleted the jax-bert branch October 19, 2020 13:55

@stas00 (Contributor) commented Oct 19, 2020

It looks like a file is missing:

$ make fixup
[...]
Checking all models are properly tested.
Traceback (most recent call last):
  File "utils/check_repo.py", line 327, in <module>
    check_repo_quality()
  File "utils/check_repo.py", line 321, in check_repo_quality
    check_all_models_are_tested()
  File "utils/check_repo.py", line 212, in check_all_models_are_tested
    new_failures = check_models_are_tested(module, test_file)
  File "utils/check_repo.py", line 182, in check_models_are_tested
    tested_models = find_tested_models(test_file)
  File "utils/check_repo.py", line 163, in find_tested_models
    with open(os.path.join(PATH_TO_TESTS, test_file)) as f:
FileNotFoundError: [Errno 2] No such file or directory: 'tests/test_modeling_flax_utils.py'
Makefile:25: recipe for target 'extra_quality_checks' failed
make: *** [extra_quality_checks] Error 1

Shouldn't the CI have caught this?

@sgugger (Collaborator) commented Oct 19, 2020

Looks like a problem in make fixup, make quality runs fine (and that's what the CI runs).

@stas00 (Contributor) commented Oct 19, 2020

Nope, both run the same sub-target: extra_quality_checks

$ make quality
[...]
python utils/check_copies.py
python utils/check_dummies.py
python utils/check_repo.py
2020-10-19 12:11:26.345843: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
Checking all models are properly tested.
Traceback (most recent call last):
  File "utils/check_repo.py", line 327, in <module>
    check_repo_quality()
  File "utils/check_repo.py", line 321, in check_repo_quality
    check_all_models_are_tested()
  File "utils/check_repo.py", line 212, in check_all_models_are_tested
    new_failures = check_models_are_tested(module, test_file)
  File "utils/check_repo.py", line 182, in check_models_are_tested
    tested_models = find_tested_models(test_file)
  File "utils/check_repo.py", line 163, in find_tested_models
    with open(os.path.join(PATH_TO_TESTS, test_file)) as f:
FileNotFoundError: [Errno 2] No such file or directory: 'tests/test_modeling_flax_utils.py'
Makefile:25: recipe for target 'extra_quality_checks' failed
make: *** [extra_quality_checks] Error 1

This is with the latest master.

@stas00 (Contributor) commented Oct 19, 2020

PR with fix #7914

The question is: why didn't the CI fail? It reports no problem here:
https://app.circleci.com/pipelines/github/huggingface/transformers/14040/workflows/6cd2b931-ce7e-4e99-b313-4a34326fcece/jobs/101513

Once I got this fixed, 2 more issues came up:

python utils/check_repo.py
2020-10-19 12:22:10.636984: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
Checking all models are properly tested.
Traceback (most recent call last):
  File "utils/check_repo.py", line 328, in <module>
    check_repo_quality()
  File "utils/check_repo.py", line 322, in check_repo_quality
    check_all_models_are_tested()
  File "utils/check_repo.py", line 217, in check_all_models_are_tested
    raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
Exception: There were 2 failures:
test_modeling_flax_bert.py should define `all_model_classes` to apply common tests to the models it tests. If this intentional, add the test filename to `TEST_FILES_WITH_NO_COMMON_TESTS` in the file `utils/check_repo.py`.
test_modeling_flax_roberta.py should define `all_model_classes` to apply common tests to the models it tests. If this intentional, add the test filename to `TEST_FILES_WITH_NO_COMMON_TESTS` in the file `utils/check_repo.py`.
Makefile:25: recipe for target 'extra_quality_checks' failed

Fixed in the same PR.

stas00 added a commit to stas00/transformers that referenced this pull request Oct 19, 2020
Unless, this is actually a problem, this adds `modeling_flax_utils` to ignore list. otherwise currently it expects to have a 'tests/test_modeling_flax_utils.py' for it.
for context please see: huggingface#3722 (comment)
@stas00 stas00 mentioned this pull request Oct 19, 2020
LysandreJik pushed a commit that referenced this pull request Oct 20, 2020
* [flax] fix repo_check

Unless, this is actually a problem, this adds `modeling_flax_utils` to ignore list. otherwise currently it expects to have a 'tests/test_modeling_flax_utils.py' for it.
for context please see: #3722 (comment)

* fix 2 more issues

* merge #7919