
cpp extensions should use full schema string #21416

@nairbv

The extension backends generated by function_wrapper.py build each schema string from only the argument types and names, not from the full schema string. schema_args in create_extension_backend ignores is_nullable, so it should probably use the full schema string instead. As a result, optional arguments aren't properly set or handled.

For example, the generated MSNPUType.cpp contains this function:

```cpp
Tensor MSNPUType::norm(const Tensor & self, c10::optional<Scalar> p, ScalarType dtype) const {
    return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, c10::optional<Scalar>, ScalarType)>("norm(Tensor self, Scalar p, ScalarType dtype) -> Tensor")(self, p, dtype);
}
```

Since p is optional, the generated schema string should declare it as something like `Scalar? p` (or `optional<Scalar> p`), not simply `Scalar p`. This mismatch can result in parameters being set to invalid values (e.g., values outside the range of a valid enumeration).
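As a rough illustration of the fix, here is how the schema-building step could honor nullability. The real generator is Python (function_wrapper.py); this C++ sketch only mirrors the string-building logic. The `Arg` struct and `build_schema` function are hypothetical names, assuming each argument carries a type, a name, and an is_nullable flag as described above.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical per-argument description (not PyTorch's actual data model),
// loosely mirroring the type / name / is_nullable information the generator has.
struct Arg {
  std::string type;  // schema type, e.g. "Tensor", "Scalar", "ScalarType"
  std::string name;  // argument name, e.g. "self", "p"
  bool is_nullable;  // true if the C++ parameter is c10::optional<...>
};

// Builds a schema string such as
// "norm(Tensor self, Scalar? p, ScalarType dtype) -> Tensor".
// Nullable arguments get a trailing '?' on their type instead of being
// emitted as plain, non-optional types.
std::string build_schema(const std::string& op_name,
                         const std::vector<Arg>& args,
                         const std::string& return_type) {
  std::string schema = op_name + "(";
  for (std::size_t i = 0; i < args.size(); ++i) {
    assert(!args[i].type.empty() && !args[i].name.empty());
    if (i > 0) schema += ", ";
    schema += args[i].type;
    if (args[i].is_nullable) schema += "?";  // the step that is currently missing
    schema += " " + args[i].name;
  }
  schema += ") -> " + return_type;
  return schema;
}
```

With `{"Scalar", "p", true}` in the argument list, the `norm` example above produces `norm(Tensor self, Scalar? p, ScalarType dtype) -> Tensor`.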

To test this, we should be able to register a function with an optional argument in the msnpu_extensions.cpp tests and run them with `python run_test.py -i cpp_extensions`.
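One way such a test could catch the bug is to compare the nullability of each C++ parameter type in the `get_function<...>` signature against the schema string passed alongside it. The helpers below (`schema_arg_types`, `nullability_matches`) are hypothetical, not part of PyTorch or its test suite, and assume simple `Type name` arguments separated by `", "` with no default values or nested parentheses.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Extracts the argument types from a schema string such as
// "norm(Tensor self, Scalar? p, ScalarType dtype) -> Tensor".
// Assumes "Type name" arguments separated by ", " (no defaults, no nesting).
std::vector<std::string> schema_arg_types(const std::string& schema) {
  std::vector<std::string> types;
  std::size_t open = schema.find('(');
  std::size_t close = schema.find(')', open);
  assert(open != std::string::npos && close != std::string::npos);
  std::string args = schema.substr(open + 1, close - open - 1);
  std::size_t start = 0;
  while (start < args.size()) {
    std::size_t end = args.find(", ", start);
    if (end == std::string::npos) end = args.size();
    std::string arg = args.substr(start, end - start);
    types.push_back(arg.substr(0, arg.rfind(' ')));  // drop the argument name
    start = end + 2;
  }
  return types;
}

// Returns true if every c10::optional<...> C++ parameter is declared as
// nullable ("T?") in the schema, and every nullable schema type maps to
// a c10::optional<...> parameter.
bool nullability_matches(const std::vector<std::string>& cpp_types,
                         const std::string& schema) {
  std::vector<std::string> types = schema_arg_types(schema);
  if (types.size() != cpp_types.size()) return false;
  for (std::size_t i = 0; i < types.size(); ++i) {
    bool cpp_optional = cpp_types[i].rfind("c10::optional<", 0) == 0;
    bool schema_optional = !types[i].empty() && types[i].back() == '?';
    if (cpp_optional != schema_optional) return false;
  }
  return true;
}
```

Applied to the generated `norm` code above, the parameter types `const Tensor &`, `c10::optional<Scalar>`, `ScalarType` fail the check against `Scalar p` and pass it against `Scalar? p`.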

This issue came up while updating the msnpu_extension tests in #21088.

Metadata

Labels: module: cpp-extensions (related to torch.utils.cpp_extension), triaged
