
Conversation

@aocsa (Contributor) commented Jun 26, 2020

This PR fixes a bug that occurs when the user does not specify a device for a sparse tensor created with the sparse_coo_tensor function, but the values and indices tensors live on a CUDA device other than the default.

Example: `sparse_tensor = torch.sparse_coo_tensor(sp_ind, sp_val, sp_shape)`

This is solved by constructing the indices tensor with the same device index as the values tensor in the sparse_coo_tensor_ctor function.

Resolves bug #28500.

Note that the torch.sparse.FloatTensor constructor already worked as expected because it uses legacy_sparse_tensor_ctor, which takes the device information from the original Tensor Python object. This fix replicates that behavior in the new sparse_coo_tensor_ctor constructor.

So now this works!

```
import torch

sp_val = torch.rand(10).to('cuda:1')
print(sp_val, sp_val.device)

sp_ind = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], device='cuda:1')
print(sp_ind)

x = torch.sparse_coo_tensor(sp_ind, sp_val, (10, 10))
```
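
With this fix, the constructed sparse tensor is expected to follow its inputs rather than the default device. A quick check (assuming a second GPU is present so that `cuda:1` exists):

```
print(x.device)  # expected: cuda:1
```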

@aocsa aocsa requested a review from pearu June 26, 2020 23:44
@dr-ci (bot) commented Jun 26, 2020

💊 CI failures summary and remediations

As of commit 2c84beb (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚

@aocsa aocsa self-assigned this Jun 30, 2020
@pearu (Collaborator) left a comment:

The bug fix looks good to me.

@aocsa, could you implement also the corresponding unittest?

@ngimel ngimel self-requested a review July 7, 2020 00:37
@ngimel ngimel added the triaged label Jul 7, 2020
@aocsa (Contributor Author) commented Jul 7, 2020

> The bug fix looks good to me.
>
> @aocsa, could you implement also the corresponding unittest?

Thanks @pearu, I updated this PR, adding the corresponding unit test for the case when the number of available devices is greater than one.
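
Not the actual test added to the PR, but a minimal sketch of what such a guarded test could look like (class and method names are illustrative):

```
import unittest

import torch


class TestSparseCooTensorDevice(unittest.TestCase):
    @unittest.skipIf(torch.cuda.device_count() < 2, "requires at least two CUDA devices")
    def test_device_inferred_from_values(self):
        # No device kwarg: the sparse tensor should land on the device of its inputs.
        indices = torch.tensor([[0, 1, 2], [0, 1, 2]], device='cuda:1')
        values = torch.rand(3, device='cuda:1')
        t = torch.sparse_coo_tensor(indices, values, (10, 10))
        self.assertEqual(t.device, torch.device('cuda:1'))
```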

@aocsa aocsa force-pushed the aocsa/28500_bugfix_sparse_coo_tensor branch from bf8bc9f to 530f72a on July 7, 2020 23:45
@pearu (Collaborator) left a comment:

LGTM.
The CI failures seem to be unrelated to the changes in this PR.

@aocsa aocsa force-pushed the aocsa/28500_bugfix_sparse_coo_tensor branch from 530f72a to 839b070 on July 8, 2020 19:50
@ezyang ezyang requested a review from mruberry July 8, 2020 20:12
@ngimel (Collaborator) commented Jul 8, 2020

So the behavior introduced by this PR (unless I'm misreading) is:

  1. if device is specified, the sparse tensor will be created on this device.
  2. if device is unspecified and dtype is also unspecified, the sparse tensor will be created with the dtype and device of vals.
  3. if device is unspecified but dtype is specified, the sparse tensor will be created on the default device, and the device of vals will be ignored. Why the difference from the previous case? Is it documented somewhere?

@aocsa aocsa closed this Jul 9, 2020
@mruberry (Collaborator) commented Jul 9, 2020

Sorry for arriving late. @aocsa, did you mean to close this PR? It would be nice to fix this bug. Is there something we can help with?

@aocsa aocsa reopened this Jul 9, 2020
@aocsa (Contributor Author) commented Jul 9, 2020

> Sorry for arriving late. @aocsa, did you mean to close this PR? It would be nice to fix this bug. Is there something we can help with?

Ohh, sorry @mruberry, I accidentally pressed the wrong button. I am still working on this issue.

@mruberry (Collaborator) commented Jul 9, 2020

> Sorry for arriving late. @aocsa, did you mean to close this PR? It would be nice to fix this bug. Is there something we can help with?
>
> Ohh, sorry @mruberry, I accidentally pressed the wrong button. I am still working on this issue.

No worries. I'll take a look now, then.

@mruberry (Collaborator) commented:

> I have my concerns here. If I add this exception, the code that ensures both values and indices end up on the same device through the function returning var.to(device, inferred_scalar_type, ...) will become unnecessary, and we will lose this functionality. Moreover, some unit tests rely on this functionality.
>
> So maybe it is better to clarify the expected behavior for this case in the torch.sparse_coo_tensor documentation.

OK.

@mruberry (Collaborator) commented:

Logic looks good but looks like you need to merge with master or rebase (to help the CI merge the branch) and fix the flake8 and clang-tidy lint issues. Once that's done we should be good to go!

@aocsa aocsa force-pushed the aocsa/28500_bugfix_sparse_coo_tensor branch from 664459e to 5534084 on July 20, 2020 20:34
@aocsa aocsa force-pushed the aocsa/28500_bugfix_sparse_coo_tensor branch from 5534084 to b892cce on July 20, 2020 21:13
@aocsa (Contributor Author) commented Jul 20, 2020

> Logic looks good but looks like you need to merge with master or rebase (to help the CI merge the branch) and fix the flake8 and clang-tidy lint issues. Once that's done we should be good to go!

I rebased this PR and fixed the flake8 and clang-tidy lint issues.
Thanks @mruberry

@ngimel (Collaborator) commented Jul 20, 2020

This still results in unexpected behavior if dtype is specified, but device is not (because in this case type_inference is false):

```
In [18]: i=torch.tensor(([0], [2]), device="cuda:1", dtype=torch.long)
In [19]: v=torch.tensor([1.], device="cuda:1")
In [20]: t=torch.sparse_coo_tensor(i, v, dtype=torch.float64)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-20-09518dc143ad> in <module>
----> 1 t=torch.sparse_coo_tensor(i, v, dtype=torch.float64)
RuntimeError: Expected indices and values to be on the same device, but got: values cpu, indices: cuda:1
```

Docstring excerpt under review (the `device` argument of `torch.sparse_coo_tensor`):

```
Default: if None, uses the current device for the default tensor type
(see :func:`torch.set_default_tensor_type`). :attr:`device` will be the CPU
for CPU tensor types and the current CUDA device for CUDA tensor types.
Default: if None, the device of indices must match device of values, otherwise an exception is raised.
```
Collaborator:

You should mention that, if None, the device is inferred from :attr:`values`.

Collaborator:

More correctly it would be "and the sparse tensor is constructed on the same device?"

Collaborator:

Both statements are technically correct.

@aocsa (Contributor Author) commented Jul 20, 2020

> This still results in unexpected behavior if dtype is specified, but device is not (because in this case type_inference is false):
>
> ```
> In [18]: i=torch.tensor(([0], [2]), device="cuda:1", dtype=torch.long)
> In [19]: v=torch.tensor([1.], device="cuda:1")
> In [20]: t=torch.sparse_coo_tensor(i, v, dtype=torch.float64)
> ---------------------------------------------------------------------------
> RuntimeError                              Traceback (most recent call last)
> <ipython-input-20-09518dc143ad> in <module>
> ----> 1 t=torch.sparse_coo_tensor(i, v, dtype=torch.float64)
> RuntimeError: Expected indices and values to be on the same device, but got: values cpu, indices: cuda:1
> ```

Thanks for letting me know about this issue, @ngimel. For this case (dtype specified, but device not), on the current master branch both the indices and the values are moved to the default device (CPU), no matter which device they are on. See the example below.

```
import torch
i = torch.tensor(([0], [2]), device="cuda:0", dtype=torch.long)
v = torch.tensor([1.], device="cuda:1")
t = torch.sparse_coo_tensor(i, v, dtype=torch.float64)
print(t.device)  # CPU
```

Moreover, if I understood correctly, we should avoid internal device transformations of the source tensors. So I am not sure what the expected behavior should be for this case.

@ngimel (Collaborator) commented Jul 21, 2020

The existing behavior is very strange. The expected behavior that we should document is: when device is not specified (regardless of whether dtype is specified or not), indices and values should be on the same device, and the sparse tensor will be constructed on that device.
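
For illustration, a minimal sketch of that documented behavior (assuming a machine with at least two CUDA devices; the exact exception text may differ):

```
import torch

i = torch.tensor([[0, 1], [2, 0]], dtype=torch.long, device='cuda:1')
v = torch.tensor([1., 2.], device='cuda:1')

# device not specified: indices and values agree, so the sparse tensor
# is constructed on that device.
t = torch.sparse_coo_tensor(i, v, (3, 3))
print(t.device)  # cuda:1

# device not specified and the inputs disagree: this is expected to raise.
try:
    torch.sparse_coo_tensor(i, v.to('cuda:0'), (3, 3))
except RuntimeError as e:
    print(e)
```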

```
auto tensor = tensor_from_cuda_array_interface(data);
const auto& inferred_scalar_type = type_inference ? tensor.scalar_type() : scalar_type;
auto device = device_opt.has_value() ? *device_opt : at::Device(computeDeviceType(dispatch_key));
auto device = device_opt.has_value() ? *device_opt : at::Device(computeDeviceType(dispatch_key.first));
```
Collaborator:

Why in this case and the following case does the call not also include the index?

Contributor Author:

Because this is the case where the tensor comes from NumPy, so the index is not needed.

```
Tensor indices = internal_new_from_data(legacyExtractDispatchKey(values.key_set()), kLong, r.deviceOptional(3), r.pyobject(0), false, true, false);
return at::sparse_coo_tensor(indices, values, values.options().layout(at::kSparse)).set_requires_grad(r.toBool(4));
Tensor values = internal_new_from_data(inferred_dispatch_key, inferred_scalar_type, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_VALUES), false, true, type_inference);
check_expected_devices(device_guard.original_device(), values, r.pyobject(ARG_INDICES));
```
Collaborator:

What about moving this check to after the construction of the indices tensor and rewriting it as `TORCH_CHECK(values.device() == indices.device(), ...)`?

Contributor Author:

All right!

```
const auto inferred_scalar_type = r.scalartypeWithDefault(ARG_TYPE, scalar_type);
at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE));
Tensor values = internal_new_from_data(inferred_dispatch_key, inferred_scalar_type, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_VALUES), false, true, type_inference);
check_expected_devices(device_guard.original_device(), values, r.pyobject(ARG_INDICES));
```
Collaborator:

Same question as above about the position and style of this check

```
at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE));
Tensor values = internal_new_from_data(inferred_dispatch_key, inferred_scalar_type, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_VALUES), false, true, type_inference);
Tensor indices = internal_new_from_data(legacyExtractDispatchKey(values.key_set()), kLong, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_INDICES), false, true, false);
check_expected_devices(device_guard.original_device(), values, r.pyobject(0));
```
Collaborator:

Same question as above

```
// are defined per-layout-type (e.g. tensor vs sparse_coo_tensor).
const auto& inferred_scalar_type = type_inference ? var.scalar_type() : scalar_type;
auto device = device_opt.has_value() ? *device_opt : (type_inference ? var.device() : at::Device(computeDeviceType(dispatch_key)));
auto device = device_opt.has_value() ? *device_opt : (type_inference ? var.device() : at::Device(computeDeviceType(dispatch_key.first), dispatch_key.second));
```
Collaborator:

This is the existing line that needs to change so that dtype no longer has an effect on device, right? We really don't want there to be any connection between dtype and device.

Contributor Author:

Sure, but it can be solved by using the right dispatch_key. For now, I will not touch this internal_new_from_data function.

@mruberry (Collaborator) commented:

@aocsa, I had a chance to review/investigate the dtype issue more with @ngimel, and we actually found that it affects other functions, like torch.as_tensor, too, in addition to being a source of confusion for how best to address this sparse ctor issue. Sorry for that existing issue getting conflated with what you're trying to fix.

To simplify this fix and address that underlying issue, we thought we'd fix internal_new_from_data's behavior. That should separate these issues and make this much easier to address.

What are your thoughts with that approach?

@aocsa (Contributor Author) commented Jul 21, 2020

> @aocsa, I had a chance to review/investigate the dtype issue more with @ngimel, and we actually found that it affects other functions, like torch.as_tensor, too, in addition to being a source of confusion for how best to address this sparse ctor issue. Sorry for that existing issue getting conflated with what you're trying to fix.
>
> To simplify this fix and address that underlying issue, we thought we'd fix internal_new_from_data's behavior. That should separate these issues and make this much easier to address.
>
> What are your thoughts with that approach?

Sure, I think it's a good idea.
Actually, that was my initial approach: don't touch the internals of the internal_new_from_data function, since it is used in other places.

I am already testing new changes that only touch the sparse_coo_tensor_ctor function, taking the dispatch_key/device information from the values tensor by default. I can update my PR to cover the latest feedback and then wait for the internal_new_from_data fix.

cc @mruberry

@mruberry (Collaborator) commented:

Great; I'll let you know when the fix is available.

@mruberry (Collaborator) commented:

Fyi #41984, which is currently being tested to see if it breaks any existing PyTorch programs we know of.

facebook-github-bot pushed a commit that referenced this pull request Jul 25, 2020
…vice of inputs tensors they're given, by default (#41984)

Summary:
**BC-Breaking Note**

This PR changes the behavior of the torch.tensor, torch.as_tensor, and sparse constructors. When given a tensor as input and a device is not explicitly specified, these constructors now always infer their device from the tensor. Historically, if the optional dtype kwarg was provided then these constructors would not infer their device from tensor inputs. Additionally, for the sparse ctor a runtime error is now thrown if the indices and values tensors are on different devices and the device kwarg is not specified.

**PR Summary**
This PR's functional change is a single line:

```
auto device = device_opt.has_value() ? *device_opt : (type_inference ? var.device() : at::Device(computeDeviceType(dispatch_key)));
```
=>
```
auto device = device_opt.has_value() ? *device_opt : var.device();
```

in `internal_new_from_data`. This line entangled whether the function was performing type inference with whether it inferred its device from an input tensor, and in practice meant that

```
t = torch.tensor((1, 2, 3), device='cuda')
torch.tensor(t, dtype=torch.float64)
```

would return a tensor on the CPU, not the default CUDA device, while

```
t = torch.tensor((1, 2, 3), device='cuda')
torch.tensor(t)
```

would return a tensor on the device of `t`!

This behavior is niche and odd, but came up while aocsa was fixing #40648.

An additional side effect of this change is that the indices and values tensors given to a sparse constructor must be on the same device, or the sparse ctor must specify the device kwarg. The tests in test_sparse.py have been updated to reflect this behavior.

Pull Request resolved: #41984

Reviewed By: ngimel

Differential Revision: D22721426

Pulled By: mruberry

fbshipit-source-id: 909645124837fcdf3d339d7db539367209eccd48
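
A quick sketch of the behavior change described in that commit message (assuming a CUDA-enabled build; not taken from the PR's tests):

```
import torch

t = torch.tensor((1, 2, 3), device='cuda')

# Before #41984: specifying dtype disabled device inference, so the copy landed on the CPU.
# After #41984: the device is inferred from `t` even when dtype is given.
u = torch.tensor(t, dtype=torch.float64)
print(u.device)  # expected: t's CUDA device
```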
@mruberry (Collaborator) commented:

#41984 is in!

@aocsa, an unintended side effect of that change is that it may also have fixed this sparse issue. Would you please review? It now appears that the sparse tensor's indices and values do have to be on the same device if the device kwarg isn't filled in.

@aocsa (Contributor Author) commented Jul 27, 2020

> #41984 is in!
>
> @aocsa, an unintended side effect of that change is that it may also have fixed this sparse issue. Would you please review? It now appears that the sparse tensor's indices and values do have to be on the same device if the device kwarg isn't filled in.

Great! I have already reviewed and tested the changes in PR #41984 with the identified test cases, and it solves this issue too. So this PR should be closed, right?
cc @mruberry

@mruberry (Collaborator) commented:

> #41984 is in!
>
> @aocsa, an unintended side effect of that change is that it may also have fixed this sparse issue. Would you please review? It now appears that the sparse tensor's indices and values do have to be on the same device if the device kwarg isn't filled in.
>
> Great! I have already reviewed and tested the changes in PR #41984 with the identified test cases, and it solves this issue too. So this PR should be closed, right?
> cc @mruberry

If you've verified the behavior then I think we're OK closing this, yes. I was surprised the change we identified as necessary here also solved this issue. Thank you for helping with it.

@mruberry mruberry closed this Jul 27, 2020