Conversation

@ydshieh
Collaborator

@ydshieh ydshieh commented Jan 28, 2022

What does this PR do?

In TFxxxForConditionalGeneration models, there is

        if inputs["labels"] is not None:
            ...
            inputs["use_cache"] = False

while there is no such change in xxxForConditionalGeneration models.

This causes a more complete PT/TF equivalence test to fail.
I understand why the TF models do this (it might be beneficial for efficiency).

We can do it in another way: add use_cache=False for PyTorch models if labels is passed. Let me know which way HF prefers.
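To make the mismatch concrete, here is a minimal illustration; the checkpoint name and token ids are arbitrary choices for the sketch, not taken from this PR:

    import torch
    from transformers import BartForConditionalGeneration

    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

    input_ids = torch.tensor([[0, 31414, 2]])  # arbitrary example token ids
    labels = torch.tensor([[0, 31414, 2]])

    # Before this PR: the PyTorch model still materializes and returns
    # past_key_values when labels are passed (config.use_cache defaults to True),
    # while the TF counterpart forces use_cache=False and returns no cache.
    pt_out = model(input_ids=input_ids, labels=labels)
    print(pt_out.past_key_values is not None)  # True for the PT model

A field-by-field PT/TF output comparison therefore sees a cache on one side and none on the other.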

@HuggingFaceDocBuilder

HuggingFaceDocBuilder commented Jan 28, 2022

The documentation is not available anymore as the PR was closed or merged.

@ydshieh ydshieh changed the title Remove inputs["use_cache"] = False Remove use_cache=False in TFxxxForForConditionalGeneration Jan 28, 2022
@ydshieh ydshieh changed the title Remove use_cache=False in TFxxxForForConditionalGeneration Remove use_cache=False in TFxxxForConditionalGeneration Jan 28, 2022
@patrickvonplaten
Contributor

Hmm, I think we can leave those statements, @ydshieh, no? They don't do any harm, but they slightly reduce the memory footprint by not creating the past tensors. The reduced memory footprint might however be negligible; in that case, it's OK for me to remove them.

For more context, I think the reasons we've added use_cache=False here are:

  • use_cache can be used to speed up generation by not re-computing the past key value states for previous generation steps. This however does not make sense during training, because there is only one forward pass through the decoder.
  • use_cache is set to True by default, which means that the model outputs a tuple of past key value states that can (but don't have to) be passed to the forward function in the next step to speed up generation. Since we never pass past_key_values to the forward function during training, it's not really an error to leave use_cache as True.
  • The reason we nevertheless set it to False is that this saves some memory, since the past_key_values are never actually returned (see the sketch below this list).
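As a rough sketch of the points above (the checkpoint name is an assumption, and this manual two-step loop only mimics what generate() does internally):

    import torch
    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

    enc = tokenizer("Caching only helps at generation time.", return_tensors="pt")
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

    # First decoding step: with use_cache=True the output carries past_key_values.
    out = model(**enc, decoder_input_ids=decoder_input_ids, use_cache=True)
    past = out.past_key_values

    # Next step: feed only the newly predicted token plus the cached states, so the
    # decoder does not recompute keys/values for positions it has already seen.
    next_token = out.logits[:, -1:].argmax(-1)
    out = model(**enc, decoder_input_ids=next_token, past_key_values=past, use_cache=True)

    # During training there is a single forward pass over the full target sequence,
    # so the returned cache is never reused; use_cache=False simply avoids
    # materializing it.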

@patrickvonplaten
Contributor

Keen to hear @Rocketknight1's and @gante's opinions here.

@ydshieh
Collaborator Author

ydshieh commented Jan 31, 2022

I totally agree. My purpose is to make PT and TF return the same results; I'm not attached to any specific way of achieving that.
Currently, PT returns past_key_values and TF doesn't in this case. This makes a more thorough version of test_pt_tf_model_equivalence almost impossible (or very difficult) to write.

I can do it in another way: keep the TF code as it is, but add use_cache=False for PT models if labels is passed (so PT/TF work the same way).

I believe a thorough test will be beneficial to HF's transformers - I have used it to identify several other issues that require real fixes (unlike this one).

@gante
Contributor

gante commented Jan 31, 2022

Given the mismatches that @ydshieh has found so far, I'm leaning towards approving the PR so we can do more thorough testing.

Can we get some numbers (memory utilization and runtime, before and after the change), so we can make an informed decision? If the difference turns out to be small, then I believe it is a no-brainer :)

@ydshieh
Collaborator Author

ydshieh commented Jan 31, 2022

I could try to measure it, but I'd prefer to go another way: add use_cache=False to the PT models if labels is passed
(regardless of whether there is a gain or not). It's just a few files, and if the cache is never used when labels is present, there is no reason to create it, I think. Would that be OK for you?

@ydshieh
Collaborator Author

ydshieh commented Jan 31, 2022

Hi,

Using batch_size = 16 and seq_len = 128 (for both encoder and decoder) with BART-large, the observed memory difference with/without the cache is about 512 MB.

I calculated the number of float32 values: 12 (layers) * 4 (2 for decoder self-attention, 2 for cross-attention) * 16 (batch) * 16 (heads) * 128 (seq len) * 64 (dim per head) * 4 bytes (float32) = 384 MB. (I'm not sure why this differs from the observed difference above.)

I tried other settings, and the difference scales linearly. So for batch_size = 256 and seq_len = 512, the difference would be around 32 GB.

I am using PT's BART for this experiment.
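For reference, a quick back-of-the-envelope script reproducing the 384 MB figure (the sizes below simply restate the BART-large numbers from this comment; nothing is measured here):

    layers = 12               # BART-large decoder layers
    kv_tensors_per_layer = 4  # key + value for self-attention, key + value for cross-attention
    batch_size = 16
    heads = 16
    seq_len = 128
    head_dim = 64
    bytes_per_float32 = 4

    cache_bytes = (layers * kv_tensors_per_layer * batch_size
                   * heads * seq_len * head_dim * bytes_per_float32)
    print(cache_bytes / 2**20)  # 384.0 MiB; scales linearly with batch_size and seq_len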

@gante
Contributor

gante commented Feb 1, 2022

The savings could be helpful, especially on larger models. I'm in favour of setting use_cache to False when labels is passed, since we can assume we are in a training regime -- but with a warnings.warn if use_cache is True, so we know we are overwriting an input option.

@ydshieh ydshieh force-pushed the fix_missing_cache_in_tf_models branch from d5111ab to 43c8452 February 1, 2022 18:04
@ydshieh
Collaborator Author

ydshieh commented Feb 1, 2022

I reverted the changes to the TF models and added use_cache=False to the related PT models when labels is passed.

@ydshieh
Collaborator Author

ydshieh commented Feb 2, 2022

I think we should give a warning only if use_cache is explicitly passed as True (when labels is provided).
If self.config.use_cache=True but use_cache (as an argument) is not specified, we can safely change use_cache to False without a warning, right?
(This way, we avoid showing repeated warnings due to the default self.config.use_cache=True.)

Or should we not set use_cache=False at all if a user provides use_cache=True explicitly?

Since the changes are in PT models now, cc @LysandreJik @patrickvonplaten @sgugger

use_cache = use_cache if use_cache is not None else self.config.use_cache
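For concreteness, here is a minimal sketch of the behaviour being discussed; resolve_use_cache is a hypothetical helper for illustration, not the actual transformers code (the real change sits at the top of each model's forward()):

    import logging

    logger = logging.getLogger(__name__)

    def resolve_use_cache(use_cache, labels, config_use_cache):
        # Hypothetical helper sketching the proposed behaviour.
        if labels is not None:
            if use_cache:
                # Warn only when the caller explicitly passed use_cache=True.
                logger.warning(
                    "The `use_cache` argument is changed to `False` since `labels` is provided."
                )
            return False
        # Otherwise fall back to the config default, without any warning.
        return use_cache if use_cache is not None else config_use_cache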

@patrickvonplaten
Contributor

Yes, I very much agree, that's a nice insight! If use_cache is passed as True, I think we should still change it to False, but add a warning in that case. If it's not passed, there's no need to throw a warning, so we avoid repeated warnings. I think this PR is then good for merge. Could you also change the title to something like "Force use_cache to be False in PyTorch"?

@ydshieh
Collaborator Author

ydshieh commented Feb 7, 2022

Hi, @patrickvonplaten

I will change the title, but I haven't done anything for the warning (if use_cache=True). Should I do it in this PR, or leave it to another one?

@ydshieh ydshieh changed the title Remove use_cache=False in TFxxxForConditionalGeneration Force use_cache to be False in PyTorch Feb 7, 2022
@patrickvonplaten
Contributor

I think we can do it in this PR.

@ydshieh
Collaborator Author

ydshieh commented Feb 7, 2022

Added the warning. (We can add the same warning to the TF models in a follow-up PR.)

Contributor

@patil-suraj patil-suraj left a comment

LGTM!

@patrickvonplaten
Contributor

Great PR @ydshieh - thanks a lot!

@patrickvonplaten patrickvonplaten merged commit 6a5472a into huggingface:master Feb 8, 2022
@ydshieh ydshieh deleted the fix_missing_cache_in_tf_models branch February 8, 2022 15:43
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request Feb 18, 2022
* use_cache = False for PT models if labels is passed

* Fix for BigBirdPegasusForConditionalGeneration

* add warning if users specify use_cache=True

* Use logger.warning instead of warnings.warn

Co-authored-by: ydshieh <[email protected]>