fixes #38775 #38779: complex support for linspace and logspace #38875
Conversation
```cpp
  TORCH_CHECK(steps >= 0, "number of steps must be non-negative");
  Tensor result = at::empty({steps}, options);
  auto result_options = options;
  if (start.isComplex() || end.isComplex()) {
```
This behavior is not stated in the docs, so the doc should be updated as well. Perhaps something along the lines of:

dtype (torch.dtype, optional) – the desired data type of returned tensor. Default: if None, and start and end are real values, uses a global default (see torch.set_default_tensor_type()); otherwise sets it to the complex type corresponding to torch.get_default_tensor_type().

Also, as I mentioned in the issue, we should update the doc to state that start and end can be float or complex.
Sure.
```cpp
    const TensorOptions& options) {
  TORCH_CHECK(steps >= 0, "number of steps must be non-negative");
  Tensor result = at::empty({steps}, options);
  auto result_options = options;
```
Also, we should error out when the input dtype is not complex but start and end are. This is tricky here because the options would by default have a floating-point dtype.
cc. @mruberry
Which input dtype? The function only takes start, end, and steps as the main inputs to the computation.
Understood now. Please check the updated comment below this line and let me know if there is a graceful way to do it.
Is there a way to know if it was defaulted or user specified?
Oh. Thanks for the reference. Will have a look.
Thanks @mruberry, I think we can use a similar approach here as well. In fact, it might be useful to generalize infer_full_options to make it usable in other places like this one.
I am not sure it can be done in a very generic way, as ops differ in how they infer the type (e.g. which inputs they consider and which types they default to). (However, I don't know much about this, so my assumption above might be wrong.) For now I have put the inference code in a function similar to infer_full_options.
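As a rough illustration of the inference rule being discussed (all names below are hypothetical stand-ins; the real ATen code works on Scalar and TensorOptions in C++, not on Python values), the logic amounts to something like:

```python
# Hypothetical sketch of the dtype-inference rule discussed above:
# if the user did not pass a dtype and either endpoint is complex,
# promote the default dtype to its complex counterpart.

DEFAULT_DTYPE = "float32"  # stands in for torch.get_default_dtype()

# assumed mapping from a real default dtype to its complex counterpart
COMPLEX_FOR = {"float32": "complex64", "float64": "complex128"}

def infer_linspace_dtype(start, end, dtype=None):
    """Return the dtype a linspace-like factory would use (sketch)."""
    if dtype is not None:
        return dtype  # a user-specified dtype always wins
    if isinstance(start, complex) or isinstance(end, complex):
        return COMPLEX_FOR[DEFAULT_DTYPE]
    return DEFAULT_DTYPE

print(infer_linspace_dtype(0, 1))                      # real endpoints -> default
print(infer_linspace_dtype(1j, 2))                     # complex start -> complex default
print(infer_linspace_dtype(1j, 2j, dtype="float32"))   # explicit dtype wins
```

The tricky part discussed in this thread is the last line: because codegen substitutes the default dtype before the function is called, the C++ side cannot distinguish "dtype not passed" from "dtype explicitly set to the default".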
test/test_torch.py
Outdated
```python
start = torch.randn(1, dtype=dtype).item()
end = (start + torch.randn(1, dtype=dtype) + random.randint(5, 15)).item()

# Crashing for the step values: [2, 3, 5, 11, 256, 257, 2**22]
```
I have a suspicion about this line (as the mismatch was past half of the elements):

```cpp
data_ptr[i] = std::pow(scalar_base, scalar_end - (step * static_cast<scalar_t>(steps - i - 1)));
```
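To illustrate why that expression is a plausible suspect (a self-contained sketch, not the actual kernel): the closed-form exponent `end - step * (steps - i - 1)` and a running accumulation of `step` are mathematically equivalent but can disagree in the last bits, and any such disagreement shows up in later elements:

```python
# Sketch: two equivalent ways to generate logspace exponents,
# closed form (as in the kernel line quoted above) vs. running accumulation.
base, start, end, steps = 10.0, 0.0, 1.0, 7
step = (end - start) / (steps - 1)

closed = [base ** (end - step * (steps - i - 1)) for i in range(steps)]

accum = []
x = start
for _ in range(steps):
    accum.append(base ** x)
    x += step  # accumulated rounding error can grow with i

# maximum relative difference between the two schemes
max_rel = max(abs(a - c) / c for a, c in zip(accum, closed))
print(max_rel)  # typically a few ulps; exact-match comparisons can still fail on it
```

A few ulps of drift is harmless numerically, but a test that compares element-for-element against NumPy with zero tolerance would report mismatches concentrated in the later half of the tensor, which matches the symptom described above.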
I am not sure how and where the part below is generated. There looks to be a discrepancy in how linspace and full handle a missing dtype. From what I see in the generated bindings:

```cpp
// linspace
static PyObject * THPVariable_linspace(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "linspace(Scalar start, Scalar end, int64_t steps=100, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
  }, /*traceable=*/true);
  ParsedArgs<9> parsed_args;
  auto _r = parser.parse(args, kwargs, parsed_args);
  if(_r.has_torch_function()) {
    return handle_torch_function(_r, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  if (_r.isNone(3)) {
    // aten::linspace(Scalar start, Scalar end, int steps=100, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
    const auto options = TensorOptions()
        .dtype(_r.scalartype(4))
        .device(_r.device(6))
        .layout(_r.layoutOptional(5))
        .requires_grad(_r.toBool(8))
        .pinned_memory(_r.toBool(7));
    torch::utils::maybe_initialize_cuda(options);
    auto dispatch_linspace = [](Scalar start, Scalar end, int64_t steps, const TensorOptions & options) -> Tensor {
      pybind11::gil_scoped_release no_gil;
      return torch::linspace(start, end, steps, options);
    };
    return wrap(dispatch_linspace(_r.scalar(0), _r.scalar(1), _r.toInt64(2), options));
  } else {
    // aten::linspace.out(Scalar start, Scalar end, int steps=100, *, Tensor(a!) out) -> Tensor(a!)
    check_out_type_matches(_r.tensor(3), _r.scalartype(4),
                           _r.isNone(4), _r.layoutOptional(5),
                           _r.device(6), _r.isNone(6));
    auto dispatch_linspace_out = [](Tensor out, Scalar start, Scalar end, int64_t steps) -> Tensor {
      pybind11::gil_scoped_release no_gil;
      return at::linspace_out(out, start, end, steps);
    };
    return wrap(dispatch_linspace_out(_r.tensor(3), _r.scalar(0), _r.scalar(1), _r.toInt64(2)).set_requires_grad(_r.toBool(8)));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
```

However, for full there is a custom binding:

```cpp
static PyObject * THPVariable_full(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "full(IntArrayRef size, Scalar fill_value, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
    "full(IntArrayRef size, Scalar fill_value, *, DimnameList names=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
  }, /*traceable=*/true);
  // Acquires (common) arguments
  ParsedArgs<8> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if(r.has_torch_function()) {
    return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  auto size = r.intlist(0);
  auto fill_val = r.scalar(1);
  const auto options = TensorOptions{}
      .dtype(r.scalartypeOptional(3))
      .layout(r.layout(4))
      .device(r.device(5))
      .pinned_memory(r.toBool(6))
      .requires_grad(r.toBool(7));
  if (r.idx == 0) {
    // full
    if (r.isNone(2)) {
      return wrap(dispatch_full(size, fill_val, options).set_requires_grad(r.toBool(7)));
    }
    // full.out
    // Validates out tensor and other kwargs
    auto result = r.tensor(2);
    TORCH_CHECK(!r.toBool(6), " `pin_memory` and `out` parameters are incompatible");
    check_out_type_matches(result, r.scalartype(3), r.isNone(3), r.layout(4),
                           r.device(5), r.isNone(5));
    return wrap(dispatch_full(size, fill_val, result).set_requires_grad(r.toBool(7)));
  } else if (r.idx == 1) {
    // full.names
    if (r.isNone(2)) {
      return wrap(dispatch_full(size, fill_val, c10::nullopt, options).set_requires_grad(r.toBool(7)));
    }
    // Converts from c10::optional<std::vector...> to c10::optional<ArrayRef...>
    auto raw_names = r.toDimnameListOptional(2);
    c10::optional<DimnameList> names(*raw_names);
    return wrap(dispatch_full(size, fill_val, names, options).set_requires_grad(r.toBool(7)));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
```

Note the difference: linspace uses `_r.scalartype(4)`, which falls back to the default scalar type when dtype is not passed, while full uses `r.scalartypeOptional(3)`, which preserves "not set".
@kshitij12345 let's limit the scope of this PR to supporting complex only when the input dtype is complex. We can come back to this later when the default-dtype issue is fixed more generally for the other ops as well.
Also, would you like to work on that?
Agreed.
Sure, will try to give it a shot over the weekend.
@anjali411 @mruberry Note: the lint error is irrelevant. Will rebase if no more changes are necessary.
facebook-github-bot
left a comment
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
torch/_torch_docs.py
Outdated
```
{dtype}
dtype (torch.dtype, optional) : the desired data type of returned tensor. Default: if None,
    and start and end are real values, uses a global default (see torch.get_default_tensor_type())
    else sets it to the complex default type.
```
For clarity, I think we could re-write "the complex default type" as "the complex type corresponding to the global default type".
torch/_torch_docs.py
Outdated
```
{dtype}
dtype (torch.dtype, optional) : the desired data type of returned tensor. Default: if None,
    and start and end are real values, uses a global default (see torch.get_default_tensor_type())
    else sets it to the complex default type.
```
same as below
facebook-github-bot
left a comment
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
test/test_tensor_creation_ops.py
Outdated
```python
self.assertTrue(t[steps - 1].item() == a[steps - 1])

def test_linspace_vs_numpy_complex(self, device):
    dtype = torch.complex64
```
Accept a dtype argument and use the @dtypes decorator with just torch.complex64.
test/test_tensor_creation_ops.py
Outdated
```python
self.assertEqual(t[steps - 1], a[steps - 1])

def test_logspace_vs_numpy_complex(self, device):
    dtype = torch.complex64
```
Analogous comments here as in previous test - seems like the tests could be combined?
```cpp
    Scalar end,
    const TensorOptions& options) {
  auto result_options = options;
  if (start.isComplex() || end.isComplex()) {
```
This seems odd. Would you elaborate on what happens in the following cases:
complex start, float end
float start, complex end
complex start, complex end
float start, float end, complex dtype or out
complex start, complex end, float dtype or out
dtype and out tensor type mismatch
```python
import torch

def linspace_check(start, end, dtype=None, device='cpu', out=None):
    if out is None:
        return torch.linspace(start, end, dtype=dtype, device=device)
    return torch.linspace(start, end, out=out)

# Start: Imag
# End: Real
# Returns: torch.complex64
t = linspace_check(1j, 2)
print(t.dtype, t.device)
t = linspace_check(1j, 2, device='cuda')
print(t.dtype, t.device)
print('*' * 8)

# Start: Real
# End: Imag
# Returns: torch.complex64
t = linspace_check(2, 2j)
print(t.dtype, t.device)
t = linspace_check(2, 2j, device='cuda')
print(t.dtype, t.device)
print('*' * 8)

# Start: Imag
# End: Imag
# Returns: torch.complex64
t = linspace_check(1j, 2j)
print(t.dtype, t.device)
t = linspace_check(1j, 2j, device='cuda')
print(t.dtype, t.device)
print('*' * 8)

# Start: Real
# End: Real
# Dtype: torch.complex64
# Returns: torch.complex64
t = linspace_check(1, 2, dtype=torch.complex64)
print(t.dtype, t.device)
t = linspace_check(1, 2, dtype=torch.complex64, device='cuda')
print(t.dtype, t.device)
print('*' * 8)

# Start: Imag
# End: Imag
# Dtype: torch.float
# Returns: torch.complex64
t = linspace_check(1j, 2j, dtype=torch.float)
print(t.dtype, t.device)
t = linspace_check(1j, 2j, dtype=torch.float, device='cuda')
print(t.dtype, t.device)
print('*' * 8)

# Start: Imag
# End: Imag
# out: torch.float
# Returns: Error
# out = torch.zeros(1, dtype=torch.float)
# t = linspace_check(1j, 2j, out=out)
# print(t.dtype, t.device)
# out = out.cuda()
# t = linspace_check(1j, 2j, out=out)
# print(t.dtype, t.device)
# print('*' * 8)

# Start: Real
# End: Real
# out: torch.complex64
# Returns: torch.complex64
out = torch.zeros(1, dtype=torch.complex64)
t = linspace_check(1, 2, out=out)
print(t.dtype, t.device)
out = out.cuda()
t = linspace_check(1, 2, out=out)
print(t.dtype, t.device)
print('*' * 8)
```
Thank you; that's very helpful. Which case(s) throw the warning, imag x imag with dtype=float?
Yup. Warning will be thrown if either start or end is complex and dtype is not complex.
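The condition described here can be sketched in isolation (`warn_if_downcast` and the dtype strings are stand-ins, not the real ATen code path):

```python
import warnings

def warn_if_downcast(start, end, dtype):
    """Warn when a complex endpoint would be cast to a non-complex dtype (sketch)."""
    endpoint_complex = isinstance(start, complex) or isinstance(end, complex)
    if endpoint_complex and not dtype.startswith("complex"):
        warnings.warn("casting complex values to real discards the imaginary part")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_if_downcast(1j, 2j, "float32")    # complex endpoints, real dtype -> warns
    warn_if_downcast(1.0, 2.0, "float32")  # all real -> silent
print(len(caught))
```

As the next comments point out, the awkward part is that "dtype not passed" and "dtype explicitly float" reach this check looking identical, so the warning fires even for a bare `linspace(1j, 2j)` call.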
Can we distinguish between dtype=float and a call without dtype set? That is, if start and end are complex and dtype is not set, does it still warn?
Unfortunately not. I think codegen assumes that if dtype is not passed for a factory function, it should be set to the default dtype, so the optional dtype this function receives is always set.
Reference: #38875 (comment)
pytorch/torch/csrc/utils/python_arg_parser.h, lines 446 to 452 in e29d847:

```cpp
inline at::ScalarType PythonArgs::scalartype(int i) {
  if (!args[i]) {
    auto scalartype = signature.params[i].default_scalartype;
    return (scalartype == at::ScalarType::Undefined) ?
      torch::tensors::get_default_scalar_type() : scalartype;
  }
  PyObject *obj = args[i];
```
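In Python terms, the difference between `scalartype` (which always resolves a missing argument to a concrete type) and an Optional variant that preserves "not set" can be sketched as follows (function names and the dtype strings are illustrative stand-ins):

```python
DEFAULT = "float32"  # stands in for torch::tensors::get_default_scalar_type()

def scalartype(arg):
    """Mimics PythonArgs::scalartype: a missing arg silently becomes the default."""
    return DEFAULT if arg is None else arg

def scalartype_optional(arg):
    """Mimics an Optional variant: 'not passed' stays observable as None."""
    return arg

# With scalartype, the callee cannot tell these two calls apart:
print(scalartype(None), scalartype("float32"))
# With the optional variant it can, which is exactly what a
# defaulted-vs-user-specified check needs:
print(scalartype_optional(None), scalartype_optional("float32"))
```

This is why the custom binding for full uses the optional accessor: it keeps the "dtype was never passed" case distinguishable until the dtype-inference decision is made.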
That's what I thought. I had to implement the custom binding for full in PR #34709 to detect the distinction, and you've correctly identified it above.
Throwing a warning when start and end are complex seems odd. I think we should pursue updating the linspace and logspace bindings to be like full and then add a test that a warning isn't thrown and that a complex start or end errors out when given a float or out dtype.
What do you think, @kshitij12345?
I agree, this behaviour is very odd. However, would it be OK to pursue it in a follow-up PR?
OK, but let's add a test that the warning IS thrown, file an issue, and add a comment linking that issue above the test and in the code.
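A test that the warning IS thrown could follow the usual assertWarns pattern. This is a self-contained sketch: `fake_linspace` is a stand-in that reproduces the described behaviour, not the real torch.linspace:

```python
import unittest
import warnings

def fake_linspace(start, end, dtype="float32"):
    """Stand-in reproducing the current warn-on-complex-endpoints behaviour."""
    if (isinstance(start, complex) or isinstance(end, complex)) \
            and not dtype.startswith("complex"):
        warnings.warn("complex endpoints with non-complex dtype")
    return [start, end]

class TestComplexWarning(unittest.TestCase):
    def test_warns_on_complex_endpoints(self):
        # the warning is expected for now; a comment in the real test would
        # link the follow-up issue discussed above
        with self.assertWarns(UserWarning):
            fake_linspace(1j, 2j, dtype="float32")

suite = unittest.TestLoader().loadTestsFromTestCase(TestComplexWarning)
result = unittest.TextTestRunner(verbosity=0).run(suite)
print(result.wasSuccessful())
```

Once the bindings are updated as proposed, the same test would flip to asserting that no warning is raised for a bare call and that a real dtype with complex endpoints errors out.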
torch/_torch_docs.py
Outdated
```
{out}
{dtype}
dtype (torch.dtype, optional) : the desired data type of returned tensor. Default: if None,
    and start and end are real values, uses a global default (see torch.get_default_tensor_type())
```
This should reference set_default_dtype. However, I would just ignore that expert option and replace

dtype (torch.dtype, optional): the desired data type of returned tensor.

with:

dtype (torch.dtype, optional): the data type to perform the computation in. Defaults to float when both :attr:`start` and :attr:`end` are real, and complex float when either is complex.
From offline discussion with @anjali411: referencing get_default_dtype() would be better than set_default_dtype(), and it's OK to reference get_default_dtype() like torch.get_default_tensor_type() is referenced above.
torch/_torch_docs.py
Outdated
```
Keyword arguments:
{out}
{dtype}
dtype (torch.dtype, optional) : the desired data type of returned tensor. Default: if None,
```
See above comment
mruberry
left a comment
Hey @kshitij12345! Overall this looks good. I have a few questions/suggestions for your review. It'll be nice to have linspace and logspace work properly with complex inputs.
@kshitij12345 thanks for filing the follow-up issue! This PR looks great to me but I'll wait for @mruberry to OK it and ensure his comments are addressed! Also, feel free to add me as a reviewer for your follow-up PR!
mruberry
left a comment
Nice work as usual, @kshitij12345.
This looks OK to go in. It'd be nice to get the link to the issue in TensorFactories.cpp for updating the warning behavior, but we don't need to block on that.
facebook-github-bot
left a comment
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@anjali411 merged this pull request in e9d7137.
…ogspace (pytorch#38875)

Summary: Closes pytorch#38775, Closes pytorch#38779

TO-DO:
* [x] Add Tests

Quansight Tracking: q-38775, q-38779

Pull Request resolved: pytorch#38875
Reviewed By: malfet
Differential Revision: D26628530
Pulled By: anjali411
fbshipit-source-id: ca4259b9f6725c4a4350f944465327169d12122e