
Issue with FSDP + HuggingFace generate #100069

@dakinggg

🐛 Describe the bug

Calling .generate on a HuggingFace model that has been FSDP-wrapped results in an error. I was able to work around this error by summoning full params without recurse, which just summons the LM head and avoids the issue.

Script with a minimal(ish) repro:

import torch
import transformers

from composer.utils import dist

def _auto_wrap_policy(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
    if recurse:
        return True
    if hasattr(module, '_fsdp_wrap'):
        return bool(module._fsdp_wrap)
    return False

def main():
    # initialize dist
    dist.initialize_dist(None)

    # load base model and tokenizer from Hugging Face
    gpt = transformers.AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neo-125m')
    gptt = transformers.AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-125m')

    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

    # This seems to cause other problems...
    # for module in gpt.modules():
    #     module._fsdp_wrap = True

    gpt._fsdp_wrap = True
    
    # move model to gpu
    gpt.to(torch.cuda.current_device())
    # FSDP wrap
    fsdp_wrapped_gpt = FSDP(gpt, auto_wrap_policy=_auto_wrap_policy, use_orig_params=False)
    print(fsdp_wrapped_gpt)

    # create the input
    input_dict = gptt('hello', return_tensors='pt')
    input_dict['input_ids'] = input_dict['input_ids'].to(torch.cuda.current_device())
    input_dict['attention_mask'] = input_dict['attention_mask'].to(torch.cuda.current_device())

    # THIS CODE IS NECESSARY IN ORDER FOR .generate TO NOT ERROR BELOW (THIS WAS A PREVIOUS WORKAROUND FROM TORCH 1.13 THAT STILL SEEMS TO BE NECESSARY)
    with torch.no_grad():
        fsdp_wrapped_gpt.forward(input_ids=input_dict['input_ids'])

    # call generate
    generation = fsdp_wrapped_gpt.generate(input_ids=input_dict['input_ids'], attention_mask=input_dict['attention_mask'], max_new_tokens=5)
    print(generation)

if __name__ == '__main__':
    main()

Resulting error:

Traceback (most recent call last):
  File "/mnt/workdisk/danielking/github/composer/scripts/fsdp_gen_repro.py", line 49, in <module>
    main()
  File "/mnt/workdisk/danielking/github/composer/scripts/fsdp_gen_repro.py", line 45, in main
    generation = fsdp_wrapped_gpt.generate(input_ids=input_dict['input_ids'], attention_mask=input_dict['attention_mask'], max_new_tokens=5)
  File "/mnt/workdisk/danielking/miniconda3/envs/composer-dev-torch2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/workdisk/danielking/miniconda3/envs/composer-dev-torch2/lib/python3.10/site-packages/transformers/generation/utils.py", line 1437, in generate
    return self.greedy_search(
  File "/mnt/workdisk/danielking/miniconda3/envs/composer-dev-torch2/lib/python3.10/site-packages/transformers/generation/utils.py", line 2248, in greedy_search
    outputs = self(
  File "/mnt/workdisk/danielking/miniconda3/envs/composer-dev-torch2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/workdisk/danielking/miniconda3/envs/composer-dev-torch2/lib/python3.10/site-packages/transformers/models/gpt_neo/modeling_gpt_neo.py", line 741, in forward
    transformer_outputs = self.transformer(
  File "/mnt/workdisk/danielking/miniconda3/envs/composer-dev-torch2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/workdisk/danielking/miniconda3/envs/composer-dev-torch2/lib/python3.10/site-packages/transformers/models/gpt_neo/modeling_gpt_neo.py", line 578, in forward
    inputs_embeds = self.wte(input_ids)
  File "/mnt/workdisk/danielking/miniconda3/envs/composer-dev-torch2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/workdisk/danielking/miniconda3/envs/composer-dev-torch2/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/mnt/workdisk/danielking/miniconda3/envs/composer-dev-torch2/lib/python3.10/site-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: 'weight' must be 2-D

The workaround is to wrap the .generate call in a with FSDP.summon_full_params(self.model, writeback=False, recurse=False): block.
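A minimal sketch of that workaround applied to the repro script above (here fsdp_wrapped_gpt stands in for self.model, and the generate arguments just mirror the repro; this illustrates the workaround rather than a fix):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Summon the full (unsharded) parameters of the root FSDP module only
# (recurse=False), without writing them back, and run generation inside
# that context so the embedding/LM head weights are 2-D again.
with FSDP.summon_full_params(fsdp_wrapped_gpt, writeback=False, recurse=False):
    generation = fsdp_wrapped_gpt.generate(
        input_ids=input_dict['input_ids'],
        attention_mask=input_dict['attention_mask'],
        max_new_tokens=5,
    )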

Versions

Collecting environment information...
PyTorch version: 2.0.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.26.3
Libc version: glibc-2.31

Python version: 3.10.10 (main, Mar 21 2023, 18:45:11) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-137-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB

Nvidia driver version: 515.48.07
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.5.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 64
On-line CPU(s) list: 0-63
Thread(s) per core: 1
Core(s) per socket: 32
Socket(s): 2
NUMA node(s): 8
Vendor ID: AuthenticAMD
CPU family: 25
Model: 1
Model name: AMD EPYC 7513 32-Core Processor
Stepping: 1
Frequency boost: enabled
CPU MHz: 1777.500
CPU max MHz: 2600.0000
CPU min MHz: 1500.0000
BogoMIPS: 5200.14
Virtualization: AMD-V
L1d cache: 2 MiB
L1i cache: 2 MiB
L2 cache: 32 MiB
L3 cache: 256 MiB
NUMA node0 CPU(s): 0-7
NUMA node1 CPU(s): 8-15
NUMA node2 CPU(s): 16-23
NUMA node3 CPU(s): 24-31
NUMA node4 CPU(s): 32-39
NUMA node5 CPU(s): 40-47
NUMA node6 CPU(s): 48-55
NUMA node7 CPU(s): 56-63
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca

Versions of relevant libraries:
[pip3] numpy==1.24.2
[pip3] pytorch-ranger==0.1.1
[pip3] torch==2.0.0+cu117
[pip3] torch-optimizer==0.3.0
[pip3] torchdata==0.6.0
[pip3] torchmetrics==0.11.3
[pip3] torchtext==0.15.1+cpu
[pip3] torchvision==0.15.1+cu117
[pip3] triton==2.0.0
[pip3] vit-pytorch==0.35.8
[conda] numpy 1.24.2 pypi_0 pypi
[conda] pytorch-ranger 0.1.1 pypi_0 pypi
[conda] torch 2.0.0+cu117 pypi_0 pypi
[conda] torch-optimizer 0.3.0 pypi_0 pypi
[conda] torchdata 0.6.0 pypi_0 pypi
[conda] torchmetrics 0.11.3 pypi_0 pypi
[conda] torchtext 0.15.1+cpu pypi_0 pypi
[conda] torchvision 0.15.1+cu117 pypi_0 pypi
[conda] triton 2.0.0 pypi_0 pypi
[conda] vit-pytorch 0.35.8 pypi_0 pypi

cc @zhaojuanmao @mrshenli @rohan-varma @awgu
