Description
🚀 The feature, motivation and pitch
torch.func.functional_call(module, parameter_and_buffer_dicts, args, kwargs=None, *, tie_weights=True, strict=False)
I do not see why args is a strictly required argument. With Hugging Face Transformers, the dataloader yields each batch as a dictionary, and we usually call the model's forward function like this:
for inputs in eval_dataloader:
    outputs = model(**inputs)
However, we cannot do this directly with torch.func.functional_call, because args has no default value and must always be passed.
I think it is more reasonable to make args an optional argument with a default value of None. If necessary, we could add a post-check to ensure that args and kwargs are not both None. I am not sure whether there is a specific reason for args being required, so I am raising an issue instead of opening a PR.
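To make the proposal concrete, here is a minimal sketch of the suggested behavior written as a user-side wrapper. The names `functional_call_optional_args` and `TinyModel` are hypothetical and exist only for illustration; this is not how PyTorch implements or would implement the change, just one way the default of None and the post-check could look.

```python
import torch
from torch import nn
from torch.func import functional_call


def functional_call_optional_args(module, params_and_buffers, args=None, kwargs=None, **options):
    """Hypothetical wrapper: like functional_call, but args may be omitted."""
    # Post-check suggested in this issue: reject a call with no inputs at all.
    if args is None and kwargs is None:
        raise ValueError("at least one of args and kwargs must be provided")
    if args is None:
        args = ()  # an empty tuple means "no positional inputs"
    # options forwards keyword-only settings such as tie_weights and strict.
    return functional_call(module, params_and_buffers, args, kwargs, **options)


class TinyModel(nn.Module):
    """Toy stand-in for a model whose forward takes keyword inputs."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, input_ids, scale=1.0):
        return self.linear(input_ids) * scale


model = TinyModel()
params = dict(model.named_parameters())
inputs = {"input_ids": torch.randn(3, 4), "scale": 2.0}

# Keyword-only call, mirroring model(**inputs).
outputs = functional_call_optional_args(model, params, kwargs=inputs)
```

With this wrapper the evaluation loop can pass each batch dictionary as kwargs and leave args out entirely.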
Alternatives
Currently, I am using the following workaround:
outputs = torch.func.functional_call(model, params, (), kwargs=inputs)
The empty tuple serves as a placeholder for args.
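For reference, a self-contained version of this workaround might look like the sketch below. `ToyClassifier` and the all-zero parameter dictionary are assumptions made for the example; they stand in for a real Transformers model and for whatever substituted parameters are being evaluated.

```python
import torch
from torch import nn
from torch.func import functional_call


class ToyClassifier(nn.Module):
    """Hypothetical model with a keyword-style forward, as in Transformers."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 2)

    def forward(self, features, labels=None):
        logits = self.proj(features)
        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels)
        return {"logits": logits, "loss": loss}


model = ToyClassifier()

# A batch shaped like what a dataloader would yield: a plain dictionary.
inputs = {"features": torch.randn(3, 4), "labels": torch.tensor([0, 1, 0])}

# Substitute parameters (all zeros here) without modifying the module itself.
zero_params = {name: torch.zeros_like(p) for name, p in model.named_parameters()}

# Empty tuple as the args placeholder; the whole batch goes through kwargs.
outputs = functional_call(model, zero_params, (), kwargs=inputs)
```

Because the substituted weights are all zero, the logits come out as zeros regardless of the input, which is a quick way to confirm that the call used `zero_params` rather than the parameters stored on the module.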
Additional context
I ran into this while adapting the code from the ACL 2024 best paper "Why are Sensitive Functions Hard for Transformers?"