@vishwakftw
Contributor

@vishwakftw vishwakftw commented Jul 8, 2018

If this is good, I could write some tests to ensure collision doesn't occur within a given range.

Closes #7228

}
case at::Device::Type::CUDA: {
  if (self->device.index() > 254) {
    AT_WARN("Device indices of > 254 might result in non-deterministic hashes");


@goldsborough
Contributor

How about implementing this at the level of ATen, with a specialization of std::hash for at::Device? Then the python-specific function could just return std::hash<at::Device>{}(device). It is a tiny bit more effort, but I could imagine we'd want to hash at::Device in C++ at one point or another, and then we'd have to move this code into ATen anyway. If you're comfortable with building/modifying ATen I would suggest this route, otherwise we can change this in another PR at a later point in time.
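A minimal sketch of that suggestion, not the code from this PR. `Device` and `DeviceType` are hypothetical stand-ins for `at::Device` so the example compiles without ATen, `device_hash` is a hypothetical name for the binding-side helper, and the bit-packing scheme assumes a 64-bit `size_t`:

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>

// Hypothetical stand-in for at::Device; the real class lives in ATen.
enum class DeviceType : std::int8_t { CPU = 0, CUDA = 1 };

struct Device {
  DeviceType type;
  std::int32_t index;
};

namespace std {
// Specializing std::hash for a user-defined type is permitted by the standard.
template <>
struct hash<Device> {
  // std::hash specializations conventionally return std::size_t.
  std::size_t operator()(const Device& d) const noexcept {
    // Assumption: 64-bit size_t. The type goes in the high half and the
    // index in the low half, so distinct (type, index) pairs stay distinct.
    const auto type_bits =
        static_cast<std::uint64_t>(static_cast<std::uint8_t>(d.type));
    const auto index_bits =
        static_cast<std::uint64_t>(static_cast<std::uint32_t>(d.index));
    return static_cast<std::size_t>((type_bits << 32) | index_bits);
  }
};
}  // namespace std

// Hypothetical binding-side helper: it only delegates to the specialization.
inline std::size_t device_hash(const Device& d) {
  return std::hash<Device>{}(d);
}
```

With the specialization in place, `Device` can also be used directly as a key in `std::unordered_map` or `std::unordered_set` on the C++ side.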

int64_t operator()(const at::Device& device) const noexcept {
  int64_t hash_val = static_cast<int64_t>(device.index());
  if (device.type() == at::Device::Type::CUDA) {
    hash_val += 2;


@ezyang
Contributor

ezyang commented Jul 9, 2018

@pytorchbot retest this please

namespace std {
template<> struct hash<at::Device>
{
  int64_t operator()(const at::Device& device) const noexcept {


{
  int64_t operator()(const at::Device& device) const noexcept {
    int64_t hash_val = static_cast<int64_t>(device.index());
    if (device.type() == at::Device::Type::CUDA) {


@weiyangfb
Contributor

@pytorchbot retest this please

@vishwakftw
Contributor Author

@pytorchbot retest this please

}

PyObject *THPDevice_hash(THPDevice *self)
static size_t THPDevice_hash(THPDevice *self)


END_HANDLE_TH_ERRORS
}

static size_t THPDevice_hash(THPDevice *self)


@apaszke
Contributor

apaszke commented Jul 16, 2018

Actually, one final improvement would be to range-check the std::hash output before we static_cast it. This prevents overflow errors, which can cause UB and make the hash function violate its invariants (e.g. the hash of a single object could be different every time). You can use std::numeric_limits to find the max value for Py_ssize_t (the min is not a concern: Py_ssize_t is a signed type, so its min is surely not larger than 0).

Sorry, I didn't think of that before 😕

@vishwakftw
Contributor Author

I think this should be protected by the way the hash is designed. It is merely adding 3 to the device index, which is upper bounded by the limit of int32_t. I don't think there can be a situation where the hash could overflow, but I could add it, if you deem it necessary.

@apaszke
Contributor

apaszke commented Jul 16, 2018

@vishwakftw yes, the current code is correct for sure, but what if for some reason we end up using hash_combine of multiple fields and std::hash<int64_t>{}(index) in the future (e.g. because there will be more backends)? Who will remember to go and update the Python bindings, which silently depend on assumptions about code that lives so far away? The bug won't even be very apparent and will be very hard to debug.
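To make that scenario concrete, here is a sketch of what such a future hash might look like. It is purely illustrative: `hash_combine` is a hypothetical helper written in the style of `boost::hash_combine`, and `combined_device_hash` with its `type` and `index` parameters is an assumed stand-in for hashing the fields of `at::Device`. Neither is part of this PR.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>

// Hypothetical helper in the style of boost::hash_combine.
inline std::size_t hash_combine(std::size_t seed, std::size_t value) noexcept {
  return seed ^ (value + 0x9e3779b9u + (seed << 6) + (seed >> 2));
}

// Hypothetical multi-field hash: mixes the device type and the device index.
inline std::size_t combined_device_hash(std::int8_t type,
                                        std::int64_t index) noexcept {
  const std::size_t seed = std::hash<std::int8_t>{}(type);
  return hash_combine(seed, std::hash<std::int64_t>{}(index));
}
```

A mixed value like this can land anywhere in the range of `std::size_t`, including above the maximum of `Py_ssize_t`, so a binding that casts it directly would no longer be covered by the "index plus a small constant" argument.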

@vishwakftw
Contributor Author

Thank you for the detailed explanation. Would something like this do?

if (hash_val > std::numeric_limits<Py_ssize_t>::max()) {  // hash_val is size_t
  throw std::runtime_error("Hash value limit exceeded, can overflow");
}

@apaszke
Contributor

apaszke commented Jul 16, 2018

No, that's kind of hard to fix for the user. Let's just use modulo arithmetic to limit the range.

@vishwakftw
Contributor Author

Are you suggesting setting an upper bound for the device index, like in the issue?

  return static_cast<Py_ssize_t>(std::hash<at::Device>{}(self->device) % MAX_HASH);

If this is how you want it done, where should I add MAX_HASH?

Sorry about too many questions.

@apaszke
Contributor

apaszke commented Jul 16, 2018

Yeah, just use the numeric_limits to find out what's the maximal value. No worries, I'm happy to help.
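One way the modulo approach could look, as a sketch rather than the merged code. `ssize_type` is a hypothetical stand-in for `Py_ssize_t` (so the example compiles without `Python.h`), and `to_signed_hash` and `kMaxHash` are illustrative names:

```cpp
#include <cstddef>
#include <limits>

// Hypothetical stand-in for Py_ssize_t; assumed to be as wide as std::size_t.
using ssize_type = std::ptrdiff_t;

// Reduce an unsigned hash into [0, max(ssize_type)) so that the cast to the
// signed type can never overflow.
inline ssize_type to_signed_hash(std::size_t h) noexcept {
  constexpr std::size_t kMaxHash =
      static_cast<std::size_t>(std::numeric_limits<ssize_type>::max());
  return static_cast<ssize_type>(h % kMaxHash);
}
```

Because the result is always non-negative, it also never equals -1, the value CPython's `tp_hash` slot reserves for signalling an error. The reduction maps some distinct inputs to the same output, which is acceptable for a hash but worth keeping in mind when writing collision tests.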

Contributor

@facebook-github-bot facebook-github-bot left a comment

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@vishwakftw vishwakftw deleted the device-hash branch July 18, 2018 00:19
zdevito pushed a commit to zdevito/ATen that referenced this pull request Jul 18, 2018
Summary:
If this is good, I could write some tests to ensure collision doesn't occur within a given range.

Closes #7228
Pull Request resolved: pytorch/pytorch#9246

Differential Revision: D8872608

Pulled By: ezyang

fbshipit-source-id: 0ed29a73188f4167b42756f59a5c9a3d5cb37326
jramseyer pushed a commit to jramseyer/pytorch that referenced this pull request Jul 30, 2018
goodlux pushed a commit to goodlux/pytorch that referenced this pull request Aug 15, 2018
Development

Successfully merging this pull request may close these issues.

[feature request] hash of torch.Device

7 participants