
Conversation

@kshitij12345 (Collaborator) commented May 21, 2020

Closes #38775, Closes #38779

TO-DO:

  • Add Tests

Quansight Tracking : q-38775, q-38779

@anjali411 added the "module: complex" label (Related to complex number support in PyTorch) May 21, 2020
@anjali411 self-requested a review May 21, 2020 16:17
TORCH_CHECK(steps >= 0, "number of steps must be non-negative");
Tensor result = at::empty({steps}, options);
auto result_options = options;
if (start.isComplex() || end.isComplex()){
Contributor:
This behavior is not stated in the doc, so the doc should be updated as well. Perhaps something along the lines of:

dtype (torch.dtype, optional) – the desired data type of returned tensor. Default: if None, and start and end are real values, uses a global default (see torch.set_default_tensor_type()) else sets it to the complex type corresponding to the torch.get_default_tensor_type().

Also, as I mentioned in the issue, we should update the doc to state that start and end can be float or complex.
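
(For illustration, a minimal sketch of the behavior this doc change would describe, assuming this PR's semantics and a float32 global default:)

import torch

# Real bounds use the global default dtype; a complex bound switches the
# result to the corresponding complex dtype.
print(torch.linspace(0, 1, 5).dtype)       # torch.float32
print(torch.linspace(0, 1 + 1j, 5).dtype)  # torch.complex64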

Collaborator Author:
Sure.

@ailzhang added the "triaged" label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) May 21, 2020
const TensorOptions& options) {
TORCH_CHECK(steps >= 0, "number of steps must be non-negative");
Tensor result = at::empty({steps}, options);
auto result_options = options;
Contributor:

Also, we should error out when the input dtype is not complex but start and end are. This is tricky here because the options would by default have a floating-point dtype.
cc @mruberry
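
(To make the trickiness concrete, a minimal Python sketch, assuming this PR's semantics where complex bounds are accepted, and the generated-binding behavior discussed later in this thread:)

import torch

# By the time the C++ operator runs, `options` always carries a concrete
# dtype, because the generated binding substitutes the global default for
# dtype=None. These two calls therefore look identical to the operator:
a = torch.linspace(1j, 2j, 5)                       # user passed no dtype
b = torch.linspace(1j, 2j, 5, dtype=torch.float32)  # user explicitly passed float
# So the operator cannot error out only for the explicit-float case.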

Collaborator Author:

Which input dtype? The function only takes start, end, and steps as its main inputs to the computation.

Collaborator Author:

Understood now. Please check the updated comment below this line and let me know if there is a graceful way to do it.

Collaborator Author:

Is there a way to know if it was defaulted or user specified?

Collaborator:

Unfortunately I don't think there is, @bhosmer to verify. In a similar circumstance I had to implement a custom dispatch path (see #34709).

Collaborator Author:

Oh. Thanks for the reference. Will have a look.

Contributor:

Thanks @mruberry, I think we can use a similar approach here as well. In fact, it might be useful to generalize infer_full_options to make it usable in other places like this one.

Collaborator Author:

I am not sure it can be done in a very generic way, as ops differ in how they infer the type: the inputs they consider and the types they default to. (However, I don't know much about this, so my assumption above might be wrong.)

For now I have put the inference code in a function similar to infer_full_options.
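
(For intuition, a hypothetical Python rendering of the kind of rule such a helper encodes; the real helper is C++ over TensorOptions, and all names here are made up:)

import torch

def infer_linspace_dtype(start, end, default_dtype=torch.float32):
    # Hypothetical sketch: if either bound is complex, promote to the
    # complex dtype corresponding to the default; otherwise keep the default.
    to_complex = {torch.float32: torch.complex64, torch.float64: torch.complex128}
    if isinstance(start, complex) or isinstance(end, complex):
        return to_complex[default_dtype]
    return default_dtype

print(infer_linspace_dtype(1j, 2.0))   # torch.complex64
print(infer_linspace_dtype(1.0, 2.0))  # torch.float32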

@dr-ci bot commented May 21, 2020

💊 CI failures summary and remediations

As of commit ea80fc6 (more details on the Dr. CI page):


  • 1/2 failures possibly introduced in this PR
    • 1/1 non-scanned failure(s)
  • 1/2 broken upstream at merge base f448c59 from Mar 01 until Mar 02

🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch.

If your commit is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

start = torch.randn(1, dtype=dtype).item()
end = (start + torch.randn(1, dtype=dtype) + random.randint(5, 15)).item()

# Crashing for the step values: [2, 3, 5, 11, 256, 257, 2**22]
Collaborator Author:

I have a suspicion about this line (as the mismatch was past half of the elements):

data_ptr[i] = std::pow(scalar_base, scalar_end - (step * static_cast<scalar_t>(steps - i - 1)));
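
(For context, a small NumPy illustration of why computing the exponent "backwards" from `end` can disagree with a forward accumulation; this is not the kernel itself, just the rounding asymmetry:)

import numpy as np

# The exponent computed backwards from `end` rounds differently than the
# same exponent computed forwards from `start`, so mismatches tend to
# appear in the later half of the elements.
start, end, steps = 0.0, 10.0, 100
step = (end - start) / (steps - 1)
i = np.arange(steps)
forward = start + step * i
backward = end - step * (steps - 1 - i)
print(np.abs(forward - backward).max())  # small but typically nonzero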

@kshitij12345 (Collaborator Author) commented May 28, 2020

I am not sure how and where the below part is generated.

There seems to be a discrepancy in how dtype=None is handled in full vs. {lin/log}space: for full, options.has_dtype() in the C++ operator in TensorFactories correctly returns true or false depending on whether the user passed a dtype. However, for {lin/log}space, options.has_dtype() always returns true, because the default type is set in the generated binding.

From what I see in torch/csrc/autograd/generated/python_torch_functions.cpp, in the snippet below, .dtype(_r.scalartype(4)) is set to the default type, and hence options.has_dtype() always returns true:

// linspace
static PyObject * THPVariable_linspace(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "linspace(Scalar start, Scalar end, int64_t steps=100, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
  }, /*traceable=*/true);

  ParsedArgs<9> parsed_args;
  auto _r = parser.parse(args, kwargs, parsed_args);
  if(_r.has_torch_function()) {
    return handle_torch_function(_r, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  if (_r.isNone(3)) {
    // aten::linspace(Scalar start, Scalar end, int steps=100, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
    const auto options = TensorOptions()
        .dtype(_r.scalartype(4))
        .device(_r.device(6))
        .layout(_r.layoutOptional(5))
        .requires_grad(_r.toBool(8))
        .pinned_memory(_r.toBool(7));
    torch::utils::maybe_initialize_cuda(options);
    auto dispatch_linspace = [](Scalar start, Scalar end, int64_t steps, const TensorOptions & options) -> Tensor {
      pybind11::gil_scoped_release no_gil;
      return torch::linspace(start, end, steps, options);
    };
    return wrap(dispatch_linspace(_r.scalar(0), _r.scalar(1), _r.toInt64(2), options));
  } else {
    // aten::linspace.out(Scalar start, Scalar end, int steps=100, *, Tensor(a!) out) -> Tensor(a!)
    check_out_type_matches(_r.tensor(3), _r.scalartype(4),
                           _r.isNone(4), _r.layoutOptional(5),
                           _r.device(6), _r.isNone(6));
    auto dispatch_linspace_out = [](Tensor out, Scalar start, Scalar end, int64_t steps) -> Tensor {
      pybind11::gil_scoped_release no_gil;
      return at::linspace_out(out, start, end, steps);
    };
    return wrap(dispatch_linspace_out(_r.tensor(3), _r.scalar(0), _r.scalar(1), _r.toInt64(2)).set_requires_grad(_r.toBool(8)));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

However, for full:

static PyObject * THPVariable_full(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_TH_ERRORS

  static PythonArgParser parser({
    "full(IntArrayRef size, Scalar fill_value, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
    "full(IntArrayRef size, Scalar fill_value, *, DimnameList names=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)",
  }, /*traceable=*/true);

  // Acquires (common) arguments
  ParsedArgs<8> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if(r.has_torch_function()) {
    return handle_torch_function(r, args, kwargs, THPVariableFunctionsModule, "torch");
  }

  auto size = r.intlist(0);
  auto fill_val = r.scalar(1);
  const auto options = TensorOptions{}
      .dtype(r.scalartypeOptional(3))
      .layout(r.layout(4))
      .device(r.device(5))
      .pinned_memory(r.toBool(6))
      .requires_grad(r.toBool(7));

  if (r.idx == 0) {
    // full
    if (r.isNone(2)) {
      return wrap(dispatch_full(size, fill_val, options).set_requires_grad(r.toBool(7)));
    }

    // full.out
    // Validates out tensor and other kwargs
    auto result = r.tensor(2);
    TORCH_CHECK(!r.toBool(6), " `pin_memory` and `out` parameters are incompatible");
    check_out_type_matches(result, r.scalartype(3), r.isNone(3), r.layout(4),
                          r.device(5), r.isNone(5));

    return wrap(dispatch_full(size, fill_val, result).set_requires_grad(r.toBool(7)));
  } else if (r.idx == 1) {
    // full.names
    if (r.isNone(2)) {
      return wrap(dispatch_full(size, fill_val, c10::nullopt, options).set_requires_grad(r.toBool(7)));
    }

    // Converts from c10::optional<std:vector...> to c10::optional<ArrayRef...>
    auto raw_names = r.toDimnameListOptional(2);
    c10::optional<DimnameList> names(*raw_names);
    return wrap(dispatch_full(size, fill_val, names, options).set_requires_grad(r.toBool(7)));
  }

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

full uses .dtype(r.scalartypeOptional(3)), which allows the C++ operator to correctly query options.has_dtype().
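
(A small Python check of the observable consequence, assuming the bindings shown above:)

import torch

# For linspace the two calls below are indistinguishable in C++: the
# generated binding fills in the default dtype when dtype=None, so
# options.has_dtype() is always true. full, by contrast, preserves
# "dtype not set".
a = torch.linspace(0, 1, 5)
b = torch.linspace(0, 1, 5, dtype=torch.get_default_dtype())
assert a.dtype == b.dtype == torch.get_default_dtype()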

@anjali411 (Contributor):

@kshitij12345 let's limit the scope of the PR to support complex only when the input dtype is complex. We can come back to this later when the default dtype issue is fixed more generally for all the other ops as well.

@anjali411 (Contributor):

Also, would you like to work on the index_select port from TH to ATen (#24578)? It's high priority because it's used in prod_backward as well as in torchaudio code.

@kshitij12345 (Collaborator Author):

> @kshitij12345 let's limit the scope of the PR to support complex only when the input dtype is complex. We can come back to this later when the default dtype issue is fixed more generally for all the other ops as well.

Agreed.

> Also, would you like to work on the index_select port from TH to ATen (#24578)? It's high priority because it's used in prod_backward as well as in torchaudio code.

Sure, I will try to give it a shot over the weekend.

@kshitij12345 (Collaborator Author) commented Feb 23, 2021

@anjali411 @mruberry
Apologies that this fell off the radar.
PTAL :)

Note: The lint error is irrelevant. I will rebase if no more changes are necessary.

@facebook-github-bot (Contributor) left a comment:

@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

{dtype}
dtype (torch.dtype, optional) : the desired data type of returned tensor. Default: if None,
and start and end are real values, uses a global default (see torch.get_default_tensor_type())
else sets it to the complex default type.
Contributor:

For clarity, I think we could re-write "the complex default type" as "the complex type corresponding to the global default type".

{dtype}
dtype (torch.dtype, optional) : the desired data type of returned tensor. Default: if None,
and start and end are real values, uses a global default (see torch.get_default_tensor_type())
else sets it to the complex default type.
Contributor:

same as below

@facebook-github-bot (Contributor) left a comment:

@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

self.assertTrue(t[steps - 1].item() == a[steps - 1])

def test_linspace_vs_numpy_complex(self, device):
dtype = torch.complex64
Collaborator:

Accept a dtype argument and use the @dtypes decorator with just torch.complex64.

self.assertEqual(t[steps - 1], a[steps - 1])

def test_logspace_vs_numpy_complex(self, device):
dtype = torch.complex64
Collaborator:

Analogous comments apply here as in the previous test. It seems like the tests could be combined?
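
(A rough standalone sketch of a combined comparison; the function name and tolerances are made up, and in the actual suite this would be a device-generic test receiving dtype via @dtypes(torch.complex64):)

import random
import numpy as np
import torch

def check_lin_log_space_vs_numpy_complex(device="cpu", dtype=torch.complex64, steps=100):
    # Combined check in the spirit of the suggestion above: exercise both
    # linspace and logspace against NumPy with complex bounds.
    start = torch.randn(1, dtype=dtype).item()
    end = (start + torch.randn(1, dtype=dtype) + random.randint(5, 15)).item()
    for torch_fn, np_fn in ((torch.linspace, np.linspace),
                            (torch.logspace, np.logspace)):
        t = torch_fn(start, end, steps, dtype=dtype, device=device).cpu().numpy()
        a = np_fn(start, end, steps, dtype=np.complex64)
        np.testing.assert_allclose(t, a, rtol=1e-4, atol=1e-5)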

Scalar end,
const TensorOptions& options) {
auto result_options = options;
if (start.isComplex() || end.isComplex()) {
Collaborator:

This seems odd. Would you elaborate on what happens in the following cases:

  • complex start, float end
  • float start, complex end
  • complex start, complex end
  • float start, float end, complex dtype or out
  • complex start, complex end, float dtype or out
  • dtype and out tensor type mismatch

Collaborator Author:

import torch

def linspace_check(start, end, dtype=None, device='cpu', out=None):
    if out is None:
        return torch.linspace(start, end, dtype=dtype, device=device)
    return torch.linspace(start, end, out=out)

# Start: Imag
# End: Real
# Returns: torch.complex64
t = linspace_check(1j, 2)
print(t.dtype, t.device)
t = linspace_check(1j, 2, device='cuda')
print(t.dtype, t.device)
print('*'*8)

# Start: Real
# End: Imag
# Returns: torch.complex64
t = linspace_check(2, 2j)
print(t.dtype, t.device)
t = linspace_check(2, 2j, device='cuda')
print(t.dtype, t.device)
print('*'*8)

# Start: Imag
# End: Imag
# Returns: torch.complex64
t = linspace_check(1j, 2j)
print(t.dtype, t.device)
t = linspace_check(1j, 2j, device='cuda')
print(t.dtype, t.device)
print('*'*8)

# Start: Real
# End: Real
# Dtype: torch.complex64
# Returns: torch.complex64
t = linspace_check(1, 2, dtype=torch.complex64)
print(t.dtype, t.device)
t = linspace_check(1, 2, dtype=torch.complex64, device='cuda')
print(t.dtype, t.device)
print('*'*8)

# Start: Imag
# End: Imag
# Dtype: torch.float
# Returns: torch.complex64
t = linspace_check(1j, 2j, dtype=torch.float)
print(t.dtype, t.device)
t = linspace_check(1j, 2j, dtype=torch.float, device='cuda')
print(t.dtype, t.device)
print('*'*8)

# Start: Imag
# End: Imag
# out: torch.float
# Returns: Error
# out = torch.zeros(1, dtype=torch.float)
# t = linspace_check(1j, 2j, out=out)
# print(t.dtype, t.device)
# out = out.cuda()
# t = linspace_check(1j, 2j, out=out)
# print(t.dtype, t.device)
# print('*'*8)

# Start: Real
# End: Real
# out: torch.complex64
# Returns: torch.complex64
out = torch.zeros(1, dtype=torch.complex64)
t = linspace_check(1, 2, out=out)
print(t.dtype, t.device)
out = out.cuda()
t = linspace_check(1, 2, out=out)
print(t.dtype, t.device)
print('*'*8)

@mruberry (Collaborator) commented Mar 2, 2021:

Thank you; that's very helpful. Which case(s) throw the warning, imag x imag with dtype=float?

Collaborator Author:

Yup. A warning will be thrown if either start or end is complex and the dtype is not complex.
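
(A minimal repro of the behavior described above, assuming the merged semantics; the warning category and count are checked via the standard warnings machinery:)

import warnings
import torch

# Complex bounds with a non-complex dtype in `options` warn, but the
# result is still complex.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    t = torch.linspace(1j, 2j, 5, dtype=torch.float)
print(t.dtype)      # torch.complex64
print(len(caught))  # >= 1: the complex-vs-float-dtype warning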

Collaborator:

Can we distinguish between dtype=float and a call without dtype set? That is, if start and end are complex and dtype is not set, does it still warn?

@kshitij12345 (Collaborator Author) commented Mar 3, 2021:

Unfortunately not. I think codegen assumes that if the dtype is not passed to a factory function, it should be set to the default dtype. So the optional dtype this function receives is always set.

Reference: #38875 (comment)

inline at::ScalarType PythonArgs::scalartype(int i) {
  if (!args[i]) {
    auto scalartype = signature.params[i].default_scalartype;
    return (scalartype == at::ScalarType::Undefined) ?
        torch::tensors::get_default_scalar_type() : scalartype;
  }
  PyObject *obj = args[i];

Collaborator:

That's what I thought. I had to implement the custom binding for full in PR #34709 to detect the distinction, and you've correctly identified it above.

Throwing a warning when start and end are complex seems odd. I think we should pursue updating the linspace and logspace bindings to be like full and then add a test that a warning isn't thrown and that a complex start or end errors out when given a float or out dtype.

What do you think, @kshitij12345?

Collaborator Author:

I agree. This behaviour is very odd. However, would it be OK to pursue it in a follow-up PR?

Collaborator:

OK, but let's add a test that the warning IS thrown, file an issue, and add a comment linking that issue above the test and in the code.
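
(A hypothetical sketch of the requested test; harness details and the warning category are assumptions, and the real test would reference the filed follow-up issue in a comment:)

import unittest
import torch

class TestComplexLinLogspaceWarning(unittest.TestCase):
    def test_warns_on_complex_bounds_without_complex_dtype(self):
        # The generated binding cannot distinguish dtype=None from the
        # default dtype, so complex start/end currently warns even when no
        # dtype is passed (see the follow-up issue).
        with self.assertWarns(UserWarning):
            torch.linspace(1j, 2j, 5)
        with self.assertWarns(UserWarning):
            torch.logspace(1j, 2j, 5)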

{out}
{dtype}
dtype (torch.dtype, optional) : the desired data type of returned tensor. Default: if None,
and start and end are real values, uses a global default (see torch.get_default_tensor_type())
Collaborator:

This should reference set_default_dtype:

https://pytorch.org/docs/stable/generated/torch.set_default_dtype.html#torch-set-default-dtype

However, I would just ignore that expert option and document this as follows:

dtype (torch.dtype, optional) : the desired data type of returned tensor.

dtype (torch.dtype, optional): the data type to perform the computation in. Defaults to float when both :attr:`start` and :attr:`end` are real, and complex float when either is complex.

Collaborator:

From offline discussion with @anjali411: referencing get_default_dtype() would be better than set_default_dtype(), and it's OK to reference get_default_dtype() like torch.get_default_tensor_type() is referenced above.

Keyword arguments:
{out}
{dtype}
dtype (torch.dtype, optional) : the desired data type of returned tensor. Default: if None,
Collaborator:

See above comment

@mruberry (Collaborator) left a comment:

Hey @kshitij12345! Overall this looks good. I have a few questions/suggestions for your review. It'll be nice to have linspace and logspace work properly with complex inputs.

@anjali411 (Contributor):

@kshitij12345 thanks for filing the follow-up issue! This PR looks great to me but I'll wait for @mruberry to ok it and ensure his comments are addressed!

Also, feel free to add me as a reviewer for your follow-up PR!

@mruberry (Collaborator) left a comment:

Nice work as usual, @kshitij12345.

This looks OK to go in. It'd be nice to get the link to the issue in TensorFactories.cpp for updating the warning behavior, but we don't need to block on that.

@facebook-github-bot (Contributor) left a comment:

@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor):

@anjali411 merged this pull request in e9d7137.

@kshitij12345 deleted the fix/linspace-logspace branch March 5, 2021 16:51
nikithamalgifb added a commit that referenced this pull request Mar 5, 2021
Summary:
Closes #38775, Closes #38779

TO-DO:
* [x] Add Tests

Quansight Tracking : q-38775, q-38779

Reviewed By: malfet

Differential Revision: D26628530

Pulled By: anjali411

fbshipit-source-id: ca4259b9f6725c4a4350f944465327169d12122e

[ghstack-poisoned]
xsacha pushed a commit to xsacha/pytorch that referenced this pull request Mar 31, 2021
…ogspace (pytorch#38875)

Summary:
Closes pytorch#38775, Closes pytorch#38779

TO-DO:
* [x] Add Tests

Quansight Tracking : q-38775, q-38779

Pull Request resolved: pytorch#38875

Reviewed By: malfet

Differential Revision: D26628530

Pulled By: anjali411

fbshipit-source-id: ca4259b9f6725c4a4350f944465327169d12122e

Labels

  • cla signed
  • Merged
  • module: complex (Related to complex number support in PyTorch)
  • open source
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development

Successfully merging this pull request may close these issues.

  • Logspace broken for complex dtypes
  • Linspace broken for complex dtypes

6 participants