
Make the arguments in torch.func.functional_call optional #134408

@why-in-Shanghaitech


🚀 The feature, motivation and pitch

torch.func.functional_call(module, parameter_and_buffer_dicts, args, kwargs=None, *, tie_weights=True, strict=False)

I do not see why args must be a required argument. In Hugging Face Transformers, the dataloader yields each batch as a dictionary, and we usually call the model's forward function like this:

for inputs in eval_dataloader:
    outputs = model(**inputs)

However, there is no direct equivalent with torch.func.functional_call, because args always has to be passed.

I think it would be more reasonable to make args optional, with a default value of None. If necessary, a post-check could ensure that args and kwargs are not both None. I am not sure whether there is a specific reason for making args required, so I am opening an issue instead of a PR.
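A minimal sketch of the proposed behaviour, written as a wrapper around an existing call rather than as a change to PyTorch itself. The names with_optional_args and fake_functional_call are hypothetical and used only for illustration; the stand-in mimics the current signature so the sketch runs without PyTorch. The wrapper treats args=None as an empty tuple and applies the suggested post-check.

```python
import functools
from typing import Any, Callable, Mapping, Optional


def with_optional_args(call: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical helper (not part of torch): make `args` default to None."""

    @functools.wraps(call)
    def wrapper(module, parameter_and_buffer_dicts, args=None,
                kwargs: Optional[Mapping[str, Any]] = None, **options):
        # Post-check from the proposal: at least one of args/kwargs is required.
        if args is None and kwargs is None:
            raise ValueError("args and kwargs cannot both be None")
        if args is None:
            args = ()  # same placeholder as the current workaround
        # Keyword-only options such as tie_weights/strict are forwarded as-is.
        return call(module, parameter_and_buffer_dicts, args, kwargs, **options)

    return wrapper


# Stand-in for torch.func.functional_call, used only to show the forwarding.
# It ignores the parameter dict and simply calls the module.
def fake_functional_call(module, params, args, kwargs=None, *,
                         tie_weights=True, strict=False):
    return module(*args, **(kwargs or {}))


call = with_optional_args(fake_functional_call)
print(call(lambda x, y=0: x + y, {}, kwargs={"x": 1, "y": 2}))  # prints 3
```

With PyTorch installed, the same wrapper could be applied as with_optional_args(torch.func.functional_call), after which the loop could call it as functional_call(model, params, kwargs=inputs). Note that the post-check means a forward taking no inputs at all would still need args=() passed explicitly.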

Alternatives

Currently, I am using the following workaround:

outputs = torch.func.functional_call(model, params, (), kwargs=inputs)

This passes an empty tuple as a placeholder for args.
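A small runnable sketch of the workaround, assuming a hypothetical ToyModel whose forward takes keyword inputs (standing in for a Hugging Face model) and a dict batch like the one a dataloader would yield. It checks that calling functional_call with an empty args tuple gives the same result as a regular model(**inputs) call.

```python
import torch
from torch import nn
from torch.func import functional_call


# Hypothetical toy model whose forward is driven by keyword arguments.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x, mask=None):
        hidden = self.linear(x)
        if mask is not None:
            # Broadcast a per-example mask of shape (batch,) over the features.
            hidden = hidden * mask.unsqueeze(-1)
        return hidden


model = ToyModel()
# Parameter dict passed explicitly; keys are "linear.weight" and "linear.bias".
params = {name: p.detach() for name, p in model.named_parameters()}

# A batch as a dataloader might yield it: a dict of keyword arguments.
inputs = {"x": torch.randn(3, 4), "mask": torch.ones(3)}

# The workaround: an empty tuple stands in for the positional args.
outputs = functional_call(model, params, (), kwargs=inputs)
expected = model(**inputs)
```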

Additional context

This issue came up while I was adapting the code from the ACL 2024 best paper "Why are Sensitive Functions Hard for Transformers?"

cc @zou3519 @Chillee @samdow @kshitij12345 @janeyx99

Labels: module: functorch, triaged
