Skip to content

Conversation

@sanchit-gandhi
Copy link
Contributor

@sanchit-gandhi sanchit-gandhi commented Mar 9, 2022

This PR makes two changes to each of FlaxEncoderDecoderModel and FlaxSpeechEncoderDecoderModel:

  1. Amends the input docstrings to remove incorrect information about the model "shifting tokens right for denoising". In Flax, decoder_input_ids are obtained by shifting the target labels right outside of the seq2seq model, not within as stated in the docstrings.
  2. Raises a ValueError if decoder_input_ids are not provided. The current behaviour allows for decoder_input_ids to be omitted, in which case they default to None. This causes errors when decoder_input_ids=None is manipulated with JAX functions to build the decoder_attention_mask and decoder_position_ids should they be omitted from the arguments too.

The following code snippet throws the error aforementioned in 2:

from transformers import FlaxSpeechEncoderDecoderModel
import jax.numpy as jnp
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained('hf-internal-testing/tiny-random-wav2vec2', 'hf-internal-testing/tiny-random-gpt2', encoder_from_pt=True, decoder_from_pt=True)
inputs = jnp.ones((2, 5000), dtype=jnp.float32)
outputs = model(inputs)

Output:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/sanchitgandhi/transformers/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py", line 688, in __call__
    decoder_attention_mask = jnp.ones_like(decoder_input_ids)
  File "/Users/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 3706, in ones_like
    _check_arraylike("ones_like", a)
  File "/Users/sanchitgandhi/venv/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 570, in _check_arraylike
    raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: ones_like requires ndarray or scalar arguments, got <class 'NoneType'> at position 0.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Mar 9, 2022

The documentation is not available anymore as the PR was closed or merged.

@mishig25 mishig25 mentioned this pull request Mar 9, 2022
5 tasks
Copy link
Contributor

@patil-suraj patil-suraj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for fixing this!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe write it this way

Suggested change
"For sequence to sequence training, `decoder_position_ids` must be specified as an input argument."
"`decoder_input_ids` can not be `None`. For sequence to sequence training, `decoder_input_ids` must be specified as an input argument."

Comment on lines -111 to -112
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think we should keep this comment but just mention that this needs to be done outside of the model and does not happen automatically. Same for the FlaxEncoderDecoder model.

Comment on lines 684 to 688
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment as above.

@sanchit-gandhi sanchit-gandhi merged commit 741e493 into huggingface:master Mar 10, 2022
@sanchit-gandhi sanchit-gandhi deleted the flax-enc-dec branch March 10, 2022 16:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants