
Conversation

@sanchit-gandhi (Contributor)
What does this PR do?

Currently, only 2-dimensional convolutional layers are renamed and reshaped in the PyTorch to Flax conversion script. This PR handles the case of 1-dimensional convolution layers in an entirely equivalent way to their 2-dimensional counterparts.
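For context, the reshaping in question comes down to a transpose of the weight tensor. A minimal NumPy sketch (not the actual conversion script; the helper name is hypothetical): PyTorch's `Conv1d` stores its weight as `(out_channels, in_channels, kernel_size)`, while a Flax `nn.Conv` kernel is laid out as `(kernel_size, in_channels, out_channels)`.

```python
import numpy as np

def pt_conv1d_weight_to_flax(pt_weight: np.ndarray) -> np.ndarray:
    """Transpose a PyTorch Conv1d weight into the Flax kernel layout.

    PyTorch layout: (out_channels, in_channels, kernel_size)
    Flax layout:    (kernel_size, in_channels, out_channels)
    """
    assert pt_weight.ndim == 3, "expected a 1D-conv weight"
    return pt_weight.transpose(2, 1, 0)

pt_weight = np.zeros((16, 8, 5))              # (out, in, kernel)
flax_kernel = pt_conv1d_weight_to_flax(pt_weight)
print(flax_kernel.shape)                      # (5, 8, 16) -> (kernel, in, out)
```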

@HuggingFaceDocBuilder commented Feb 4, 2022

The documentation is not available anymore as the PR was closed or merged.

@patil-suraj (Contributor) left a comment


LGTM! Thanks for handling this!

Note for the future: We should try to find a more robust way of detecting Conv layers.

Also cc @patrickvonplaten

@patil-suraj patil-suraj merged commit 854a0d5 into huggingface:master Feb 4, 2022
@patrickvonplaten (Contributor)

I'm a bit surprised that we needed this. We already had 1D conv layers in Flax in Wav2Vec2, and the conversion worked:


```python
# conv1d layer
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 3 and not is_key_or_prefix_key_in_dict(pt_tuple_key):
```

What we do here is equivalent to what is done below for `# linear layer`, so I don't understand why we've added this. Was some code failing before?

`pt_tensor.transpose(2, 1, 0)` is the same as `pt_tensor.T`, and if `pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 3 and not is_key_or_prefix_key_in_dict(pt_tuple_key)` is true, then `pt_tuple_key[-1] == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key)` is also true.
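The first claim is easy to verify in isolation. A quick NumPy sketch (not from the PR itself): for an N-dimensional array, `.T` reverses all axes, so for `ndim == 3` it is exactly `transpose(2, 1, 0)`.

```python
import numpy as np

# Dummy 1D-conv weight in PyTorch layout: (out_channels, in_channels, kernel_size)
pt_tensor = np.arange(2 * 3 * 4).reshape(2, 3, 4)

# .T reverses all axes; for a 3-dimensional array that is transpose(2, 1, 0),
# i.e. both produce the same (kernel_size, in_channels, out_channels) layout.
assert np.array_equal(pt_tensor.transpose(2, 1, 0), pt_tensor.T)
print(pt_tensor.T.shape)  # (4, 3, 2)
```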

We should avoid at all costs adding more complexity to these weight conversion statements and keep things as simple as possible. Remember that all PT<->Flax conversions depend on this code, so we should not clutter it and should be extra careful here. If we add a new statement like this, there has to be at least a test that shows the change is needed. At the moment I cannot think of a single use case where this is the case -> we already had 1D-conv layer conversions working for Flax Wav2Vec2 <-> PT Wav2Vec2.

=> IMO we should revert this PR. The comments could indeed be improved, however, so I'm happy to change the comment `# conv layer` to `# 2D conv layer` and `# linear layer` to `# linear and 1D conv layer`.

Contributor

Completely agree, we actually don't need this. My bad, not my best review. Thanks a lot!

