Description
System Info
- transformers version: 4.36.2
- Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.35
- Python version: 3.10.13
- Huggingface_hub version: 0.19.4
- Safetensors version: 0.4.0
- Accelerate version: 0.25.0
- Accelerate config: - compute_environment: LOCAL_MACHINE
- distributed_type: DEEPSPEED
- mixed_precision: bf16
- use_cpu: False
- debug: False
- num_processes: 8
- machine_rank: 0
- num_machines: 1
- rdzv_backend: static
- same_network: True
- main_training_function: main
- deepspeed_config: {'gradient_accumulation_steps': 1, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': True, 'zero3_save_16bit_model': False, 'zero_stage': 3}
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
- PyTorch version (GPU?): 2.1.1+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): 0.7.5 (cpu)
- Jax version: 0.4.21
- JaxLib version: 0.4.21
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: Yes
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Training any text model with gradient checkpointing enabled on PyTorch 2.1 and higher produces this warning:
/scratch/miniconda3/envs/brr/lib/python3.10/site-packages/torch/utils/checkpoint.py:429: Warning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
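A minimal sketch that triggers the warning (the model name and dummy input are illustrative; any text model with gradient checkpointing enabled on torch >= 2.1 behaves the same):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Illustrative model choice; the warning is not specific to gpt2
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    model.gradient_checkpointing_enable()  # use_reentrant is not passed down explicitly
    model.train()

    inputs = tokenizer("hello world", return_tensors="pt")
    outputs = model(**inputs, labels=inputs["input_ids"])
    outputs.loss.backward()  # torch.utils.checkpoint emits the warning during the checkpointed pass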
This can be resolved by manually monkey-patching the model code to pass use_reentrant=True explicitly, e.g. like so:
    # Inside the decoder layer loop of the modeling code, pass use_reentrant explicitly:
    hidden_states, self_attns, decoder_cache = torch.utils.checkpoint.checkpoint(
        create_custom_forward(decoder_layer),
        hidden_states,
        attention_mask,
        position_ids,
        None,
        is_padded_inputs,
        use_reentrant=True,
    )
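As a sketch of an alternative that avoids editing modeling files, recent transformers releases accept a gradient_checkpointing_kwargs argument (assuming it is available in the installed version) through which use_reentrant can be forwarded:

    # Forward use_reentrant through the public API instead of patching modeling code
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": False}
    )

    # Or, when training with the Trainer (assuming the TrainingArguments field of the same name):
    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="out",
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )

Either way, the library's default call path is what produces the warning, so the point below stands.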
This is caused by an upstream change in PyTorch:
Expected behavior
No warning should be emitted when gradient checkpointing is enabled through the library's default settings.