Integrate Bert-like model on Flax runtime. #3722
Conversation
LysandreJik
left a comment
As I understand it, Flax is a neural-network library built on top of Jax, which brings automatic differentiation to Python/NumPy operations.
Do you mind walking me through why we use both here, and, more precisely, why we have a JaxPreTrainedModel as a parent of FlaxXXX models? Is the JaxPreTrainedModel's purpose to accommodate models built on top of it from different Jax-based libraries, or is it that it only depends on Jax operations, so there is no need for it to be Flax-based?
Other than that, it looks like a great first approach! I guess we will need to add a few features as we go, but it's impressive that you got it working so fast, with the same output between PT/TF/Flax!
# Models are loaded from Pytorch checkpoints
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
Really nice to be able to load models from their PyTorch checkpoints
def __init__(self, config: BertConfig, state: dict, **kwargs):
    self.config = config
    self.key = PRNGKey(0)
What is this used for?
Its usage is mainly related to stochastic operations, such as Dropout, which I didn't put into the model for now but might push soon.
As you said, Jax is a library that interacts with numpy to provide additional features: autodiff, auto-vectorization (vmap) and auto-parallelization (pmap). Jax is essentially stateless, which is reflected here in the fact that the function to differentiate (the model) doesn't hold the parameters. They have to be referenced somewhere else and fed in somehow.
In that regard, @madisonmay is currently working on a Haiku Bert integration in transformers. My hope is to be able to share as many things as possible between the two implementations (but I can't be sure for now).
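For readers unfamiliar with this statelessness, here is a minimal, hypothetical sketch (the `predict` function and parameter pytree are illustrative, not the PR's actual code): the function being differentiated receives its parameters explicitly, and stochastic operations such as dropout consume an explicit PRNGKey.

```python
import jax
import jax.numpy as jnp
from jax.random import PRNGKey

def predict(params, inputs, dropout_key):
    # parameters are passed in, never stored on the function itself
    hidden = jnp.dot(inputs, params["kernel"]) + params["bias"]
    # dropout needs an explicit source of randomness
    keep = jax.random.bernoulli(dropout_key, p=0.9, shape=hidden.shape)
    return jnp.where(keep, hidden / 0.9, 0.0)

key = PRNGKey(0)
params = {"kernel": jnp.ones((4, 2)), "bias": jnp.zeros((2,))}
inputs = jnp.ones((1, 4))

# autodiff is taken w.r.t. the explicitly passed parameter pytree
grads = jax.grad(lambda p: predict(p, inputs, key).sum())(params)
```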
Alright, that makes sense. Thanks for the explanation.
LysandreJik
left a comment
Cool, it's starting to look really good! I just have a few questions regarding conversion/model architecture.
thomwolf
left a comment
OK, this looks really great (Flax is quite pleasant to read).
Added a couple of remarks and questions; happy to discuss them.
src/transformers/file_utils.py
Outdated
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
- if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
+ if USE_TORCH in ENV_VARS_TRUE_VALUES and USE_TF not in ("1", "ON", "YES"):
Why not the same for the end of the line? (USE_TF)
Actually, why is _torch_available dependent on USE_TF?
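A small, hypothetical sketch of the symmetric form these review comments seem to be asking for (the exact contents of ENV_VARS_TRUE_VALUES in the PR may differ):

```python
import os

# Assumed constant; in the PR it replaces the literal ("1", "ON", "YES", "AUTO").
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "AUTO"}

USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()

# Both sides of each condition reuse the same constant instead of a hard-coded
# tuple; "AUTO" is excluded on the right-hand side to keep the original meaning.
_torch_candidate = USE_TORCH in ENV_VARS_TRUE_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES - {"AUTO"}
_tf_candidate = USE_TF in ENV_VARS_TRUE_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES - {"AUTO"}
```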
src/transformers/file_utils.py
Outdated
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()

- if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
+ if USE_TF in ENV_VARS_TRUE_VALUES and USE_TORCH not in ("1", "ON", "YES"):
Here as well, why not the same for the end of the line?
Same here, why do we test USE_TORCH for _tf_available?
config (:class:`~transformers.PretrainedConfig`):
    The model class to instantiate is selected based on the configuration class:
    - isInstance of `roberta` configuration class: :class:`~transformers.RobertaModel` (RoBERTa model)
FlaxRobertaModel
The model class to instantiate is selected based on the configuration class:
    - isInstance of `roberta` configuration class: :class:`~transformers.RobertaModel` (RoBERTa model)
    - isInstance of `bert` configuration class: :class:`~transformers.BertModel` (Bert model)
FlaxBertModel
self._module = module

# Those are public as their type is generic to every derived classes.
self.key = PRNGKey(0)
I think we should have the PRNGKey seed exposed as a model arg so that users can have (and control) several weight-initialization seeds.
+1. I think it can be really useful to be able to configure the random seed.
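A hypothetical sketch of what exposing the seed could look like (class and argument names are illustrative, not the PR's final API):

```python
from jax.random import PRNGKey

class FlaxPreTrainedSketch:
    def __init__(self, config, module, state, seed: int = 0):
        self._config = config
        self._module = module
        self.params = state
        # the seed is now a constructor argument instead of a hard-coded 0,
        # so users can control weight-initialization / dropout randomness
        self.key = PRNGKey(seed)
```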
super().__init__(config, model_def, state)

@property
def module(self) -> BertModel:
BertModel or FlaxBertModel?
@property
def config(self) -> BertConfig:
    return self._config
Shouldn't these module and config properties be in the base class?
Yes, that was the initial implementation, but in that case we cannot have the correct return type for the config and model, which are model-dependent.
(nit) We could return a PretrainedConfig for the configuration, which would be complete enough imo. This comment does not apply to the module though.
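A self-contained sketch of the trade-off being discussed (placeholder classes stand in for the real transformers ones): keeping the property in the base class gives the generic PretrainedConfig return type, while a subclass override only narrows the annotation.

```python
class PretrainedConfig:  # placeholder for transformers.PretrainedConfig
    pass

class BertConfig(PretrainedConfig):  # placeholder for transformers.BertConfig
    pass

class FlaxPreTrainedBase:
    def __init__(self, config: PretrainedConfig):
        self._config = config

    @property
    def config(self) -> PretrainedConfig:
        # generic return type shared by every derived class
        return self._config

class FlaxBertLike(FlaxPreTrainedBase):
    @property
    def config(self) -> BertConfig:
        # overridden only to expose the narrower, model-specific type
        return self._config
```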
@property
def config(self) -> RobertaConfig:
    return self._config
Same here, shouldn't this be in the base class?
if token_type_ids is None:
    token_type_ids = np.ones_like(input_ids)

if position_ids is None:
    position_ids = np.arange(
        self.config.pad_token_id + 1,
        np.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1
    )
In the FlaxBertModel these parameters are created outside of the jitted predict().
Any reason it's different here?
Should we standardize on one practice to make it easier to read for the user?
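One possible standardization, sketched here with a placeholder forward pass and an example pad_token_id: build the default position_ids before calling the jitted function and pass them in as regular arguments.

```python
import jax
import numpy as np

pad_token_id = 1  # example value only
input_ids = np.array([[5, 6, 7, 8]])

# defaults are materialized outside the traced/jitted function
seq_len = np.atleast_2d(input_ids).shape[-1]
position_ids = np.arange(pad_token_id + 1, seq_len + pad_token_id + 1)
position_ids = np.broadcast_to(position_ids, input_ids.shape)

@jax.jit
def predict(input_ids, position_ids):
    # placeholder for the real forward pass; it only sees ready-made arrays
    return input_ids + position_ids

out = predict(input_ids, position_ids)
```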
config_class = None
pretrained_model_archive_map = {}
base_model_prefix = ""
model_class = None
What is model_class used for?
model_class is the underlying flax.nn.Module class, which is required by Flax/msgpack to allocate all the buffers.
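A hedged sketch of that idea using today's flax.linen and flax.serialization APIs (the PR itself targets the older flax.nn interface, so this is illustrative only): the module class is needed to build a parameter pytree of the right structure, into which the msgpack bytes are restored.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import serialization

class TinyModule(nn.Module):
    features: int = 4

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features)(x)

module = TinyModule()
dummy = jnp.ones((1, 4))
# init() allocates the buffers and defines the pytree structure
template = module.init(jax.random.PRNGKey(0), dummy)

raw = serialization.to_bytes(template)              # roughly what save_pretrained writes
restored = serialization.from_bytes(template, raw)  # restoring needs the template structure
```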
class BertIntermediate(nn.Module):

    def apply(self, hidden_state, output_size: int):
        # TODO: Had ACT2FN reference to change activation function
nit: typo?
vocab_size: int, hidden_size: int, type_vocab_size: int, max_length: int):

    # Embed
    w_emb = BertEmbedding(jnp.atleast_2d(input_ids.astype('i4')), vocab_size, hidden_size, name="word_embeddings")
nit: maybe split these in 2 lines to make it less wide?
def config(self) -> RobertaConfig:
    return self._config

def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
How often is this called? Right now it looks like you @jit the predict graph each time the function is called. Is that the intention? Can you define & compile the predict function elsewhere?
Force-pushed from b259000 to 27e9bc5
return self._config

def __call__(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
    @jax.jit
It looks like you should move this outside of __call__: define everything that you want @jax.jit'ed once, so that the compiled version is cached when it is called multiple times.
Right now this gets compiled on each call to __call__, which should cause a lot of compilation overhead.
Yeah, you definitely don't want to use jit inside a function like this, unless perhaps it's only called once anyway.
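A minimal sketch of the suggested fix, with hypothetical names (this is not the PR's actual wrapper): jit a module-level predict function once, with the apply function marked static, so repeated calls hit the compilation cache.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=0)
def _predict(apply_fn, params, input_ids):
    # compiled once per (apply_fn, input shapes/dtypes) combination, then cached
    return apply_fn(params, input_ids)

class ModelWrapperSketch:
    def __init__(self, apply_fn, params):
        self.apply_fn = apply_fn
        self.params = params

    def __call__(self, input_ids):
        # no @jax.jit defined here, so nothing is re-traced on every call
        return _predict(self.apply_fn, self.params, jnp.asarray(input_ids))

# toy usage
params = {"scale": jnp.asarray(2.0)}
model = ModelWrapperSketch(lambda p, x: x * p["scale"], params)
print(model(jnp.ones((1, 3))))
```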
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
Unstale
Codecov Report
@@ Coverage Diff @@
## master #3722 +/- ##
==========================================
+ Coverage 78.32% 80.88% +2.55%
==========================================
Files 187 165 -22
Lines 37162 30383 -6779
==========================================
- Hits 29107 24575 -4532
+ Misses 8055 5808 -2247
Continue to review full report at Codecov.
Force-pushed from a1bedcd to 23703a5
return jax.lax.tanh(out)


class BertModel(nn.Module):
Comparing your implementation to our lm1b example (https://github.com/google/flax/blob/master/examples/lm1b/models.py), it seems your code contains significantly more nn.Module abstractions. Why did you decide to create them? Is this to make the translation from a PyTorch model easier?
I personally prefer fewer abstractions, since I think it would make the code more concise and easier to digest, but I'd be interested in hearing what your thoughts are!
As you guessed, I chose to follow what is already widely adopted in the library regarding module fragmentation.
This is something our users welcome, because they like how easy it is to tweak one module. It also makes an almost 1-to-1 match with both the PyTorch and TensorFlow implementations, so it might be easier for users who would like to give it a try to make the move.
Does that make sense?
BERT implementation using JAX/Flax as backend
"""

model_class = BertModel
I personally think the names FlaxBertModel and BertModel may be a bit confusing, since it seems BertModel is in fact an nn.Module and in that sense more like a Flax model than the other one. Perhaps you could rename BertModel to BertModule, to clarify that this is in fact the Flax module and the other one is a wrapper around it?
cc @levskaya
src/transformers/file_utils.py
Outdated
from jax.config import config
# TODO(marcvanzee): Flax Linen requires JAX omnistaging. Remove this
# once JAX enables it by default.
config.enable_omnistaging()
If we cut a new release of Flax and pin to newer jax/flax versions, this is no longer necessary, as it's now (recently) the default.
Agreed, I'll add a suggestion.
marcvanzee
left a comment
Looks great!!
LysandreJik
left a comment
Great, thanks for iterating @mfuntowicz!
It looks like a file is missing: Shouldn't the CI have caught this?
Looks like a problem in
Nope, both run the same sub-target: This is with the latest master.
PR with the fix: #7914. The question is: why didn't the CI fail? It reports no problem here. Once I got this fixed, 2 more issues came up; fixed in the same PR.
Unless this is actually a problem, this adds `modeling_flax_utils` to the ignore list; otherwise it currently expects a 'tests/test_modeling_flax_utils.py' to exist for it. For context please see: huggingface#3722 (comment)
* [flax] fix repo_check: unless this is actually a problem, this adds `modeling_flax_utils` to the ignore list; otherwise it currently expects a 'tests/test_modeling_flax_utils.py' to exist for it. For context please see: #3722 (comment)
* fix 2 more issues
* merge #7919
This Pull Request attempts to bring support for the Flax framework to transformers.
The main focus has been put on providing BERT-like models, principally by making it possible to load PyTorch checkpoints and doing the (few) necessary conversions directly on the fly. It also supports loading a msgpack-formatted file produced by Flax.
save_pretrained will save the model in msgpack format to avoid a dependency on torch inside Jax code. Targeted models:
If not too hard
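For illustration, here is a hedged usage sketch of the workflow described above, assuming the checkpoint name and call pattern of the eventual API (details may differ from the exact version in this PR): weights are loaded from the PyTorch checkpoint and converted on the fly, then saved back in msgpack format.

```python
from transformers import BertTokenizer, FlaxBertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# loads the PyTorch checkpoint and converts the weights on the fly
model = FlaxBertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("Hello, Flax!", return_tensors="np")
outputs = model(**inputs)

# writes a msgpack-serialized checkpoint, with no torch dependency
model.save_pretrained("./bert-base-uncased-flax")
```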