Handle PyTorch to Flax conversion of 1D convolutions #15519
Conversation
The documentation is not available anymore as the PR was closed or merged.
patil-suraj left a comment:
LGTM! Thanks for handling this!
Note for the future: We should try to find a more robust way of detecting Conv layers.
Also cc @patrickvonplaten
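One purely illustrative way to make the detection more robust, sketched here with hypothetical names rather than the actual transformers API, would be to validate a candidate weight against the shape of the matching Flax parameter instead of relying on tensor rank alone:

```python
def looks_like_conv_weight(pt_tuple_key, pt_tensor, flax_params):
    """Hypothetical check: treat a PyTorch weight as a conv kernel only if
    moving its spatial axes to the front yields exactly the shape of the
    corresponding Flax "kernel" parameter."""
    flax_key = pt_tuple_key[:-1] + ("kernel",)
    if flax_key not in flax_params or pt_tensor.ndim < 3:
        return False
    # PyTorch conv weight: (out_channels, in_channels, *spatial)
    # Flax conv kernel:    (*spatial, in_channels, out_channels)
    expected = tuple(pt_tensor.shape[2:]) + (pt_tensor.shape[1], pt_tensor.shape[0])
    return tuple(flax_params[flax_key].shape) == expected
```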
I'm a bit surprised that we needed that. We already had 1D Conv layers in Flax in Wav2Vec2 and the conversion worked.
```python
# conv1d layer
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 3 and not is_key_or_prefix_key_in_dict(pt_tuple_key):
```
What we do here is equivalent to what is done below for `# linear layer`; I don't understand why we've added this. Was some code failing before?
`pt_tensor.transpose(2, 1, 0)` is the same as `pt_tensor.T`, and if `pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 3 and not is_key_or_prefix_key_in_dict(pt_tuple_key)` is true, then `pt_tuple_key[-1] == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key)` is also true.
We should avoid at all costs adding more complexity to these weight conversion statements and keep things as simple as possible. Remember that all PT<->Flax conversions depend on this code, so we should not clutter it and should be extra careful here. If we add a new statement like this, there has to be at least a test that shows the change is needed. At the moment I cannot think of a single use case where that is true: we already had 1D conv layer conversions working for Flax Wav2Vec2 <-> PT Wav2Vec2.
=> IMO we should revert this PR. The comments could indeed be improved, however, so I'm happy to change the comment `# conv layer` to `# 2D - conv layer` and `# linear layer` to `# linear and 1D - conv layer`.
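To make the equivalence argument above concrete, here is a small self-contained NumPy check (illustrative only, not code from the PR):

```python
import numpy as np

# PyTorch Conv1d weight layout: (out_channels, in_channels, kernel_size)
pt_weight = np.arange(64 * 32 * 3).reshape(64, 32, 3)

# For a 3D array, .T reverses all axes, so it equals transpose(2, 1, 0);
# the generic "linear layer" branch therefore already produced the
# (kernel_size, in_channels, out_channels) layout that Flax expects.
assert np.array_equal(pt_weight.transpose(2, 1, 0), pt_weight.T)
assert pt_weight.T.shape == (3, 32, 64)
```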
Completely agree, we actually don't need this. My bad, not my best review. Thanks a lot!
Revert "Handle PyTorch to Flax conversion of 1D convolutions (huggingface#15519)" (huggingface#15540)
This reverts commit 854a0d5.
What does this PR do?
Currently, only 2-dimensional convolutional layers are renamed and reshaped in the PyTorch to Flax conversion script. This PR handles 1-dimensional convolutional layers in an entirely equivalent way to their 2-dimensional counterparts.
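For context, here is a simplified sketch of the renaming logic being discussed; it paraphrases the conversion branches described in the review above and is not the exact transformers source (`flax_keys` stands in for the real Flax state-dict membership check):

```python
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, flax_keys):
    # Sketch only: the real function in transformers handles more cases.
    renamed_key = pt_tuple_key[:-1] + ("kernel",)

    # 2D conv layer: PyTorch (out, in, kh, kw) -> Flax (kh, kw, in, out)
    if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in flax_keys:
        return renamed_key, pt_tensor.transpose(2, 3, 1, 0)

    # Linear layer; because .T reverses all axes, this branch also converts
    # 1D conv weights: PyTorch (out, in, k) -> Flax (k, in, out)
    if pt_tuple_key[-1] == "weight":
        return renamed_key, pt_tensor.T

    return pt_tuple_key, pt_tensor
```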