Description
The extensions generated by function_wrapper.py build their schema strings from only the argument types and names, not from the full schema string. schema_args in create_extension_backend ignores is_nullable; it should probably use the full schema string instead. As a result, optional arguments are not properly set or handled.
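A minimal sketch of what honoring is_nullable could look like when assembling the argument list of a schema string. This is written in C++ for illustration only; the real helper lives in function_wrapper.py. The Arg struct, its fields, and the schema_args/schema_string functions below are hypothetical and are not taken from the PyTorch codebase:

```cpp
#include <string>
#include <vector>

// Hypothetical description of one declared argument; field names are
// illustrative and do not mirror function_wrapper.py exactly.
struct Arg {
  std::string type;  // schema base type, e.g. "Scalar"
  std::string name;  // argument name, e.g. "p"
  bool is_nullable;  // true when the C++ type is c10::optional<...>
};

// Build the argument list of a schema string, marking nullable arguments
// with a trailing '?' so that "Scalar p" becomes "Scalar? p".
std::string schema_args(const std::vector<Arg>& args) {
  std::string out;
  for (std::size_t i = 0; i < args.size(); ++i) {
    if (i > 0) out += ", ";
    out += args[i].type;
    if (args[i].is_nullable) out += "?";
    out += " " + args[i].name;
  }
  return out;
}

// Assemble a full schema string such as
// "norm(Tensor self, Scalar? p, ScalarType dtype) -> Tensor".
std::string schema_string(const std::string& op, const std::vector<Arg>& args,
                          const std::string& ret) {
  return op + "(" + schema_args(args) + ") -> " + ret;
}
```

With the norm arguments from the example below, marking only p as nullable yields the string norm(Tensor self, Scalar? p, ScalarType dtype) -> Tensor.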
E.g., MSNPUType.cpp generates a function:
Tensor MSNPUType::norm(const Tensor & self, c10::optional<Scalar> p, ScalarType dtype) const {
    return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, c10::optional<Scalar>, ScalarType)>("norm(Tensor self, Scalar p, ScalarType dtype) -> Tensor")(self, p, dtype);
}
Since p is optional, the generated schema string should read something like Scalar? p or optional<Scalar> p, not simply Scalar p. This mismatch can result in parameters being set to invalid values (e.g., values outside the range of a valid enumeration).
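Once the schema marks p as optional, an extension kernel has to treat an absent value explicitly rather than reading whatever happens to be there. The sketch below is illustrative only: std::optional<double> stands in for c10::optional<Scalar>, a std::vector<double> stands in for the tensor, and norm_sketch is a hypothetical name, not an MSNPU or ATen function:

```cpp
#include <cmath>
#include <optional>
#include <vector>

// Stand-in for an extension kernel such as norm(Tensor self, Scalar? p, ...).
// std::optional<double> plays the role of c10::optional<Scalar> and a plain
// vector plays the role of the tensor; both substitutions are illustrative.
double norm_sketch(const std::vector<double>& self, std::optional<double> p) {
  // When p is absent, fall back to the Euclidean (p = 2) norm instead of
  // using a meaningless value.
  const double order = p.value_or(2.0);
  double acc = 0.0;
  for (double x : self) acc += std::pow(std::fabs(x), order);
  return std::pow(acc, 1.0 / order);
}
```

For example, norm_sketch({3.0, 4.0}, std::nullopt) falls back to the 2-norm, while passing 1.0 computes the 1-norm of the same values.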
To test this, we should be able to register a function with an optional argument in the msnpu_extensions.cpp tests and run them with python run_test.py -i cpp_extensions.
This issue arose while updating the msnpu_extension tests in #21088.