Conversation

@zhaoyanpeng

tensor.exponential_() on a CUDA device may generate 0.0; below is code to reproduce the error.

import torch
torch.manual_seed(0) 
torch.cuda.set_device(0)

cnt = 0 
while True:
    randns = torch.empty((10000000,), device=torch.cuda.current_device()).exponential_() 
    #randns = torch.empty((10000000,)).exponential_() 
    gumbel = -randns.log() 

    cnt += 1
    idxes = torch.isinf(gumbel)
    if idxes.any():
        _, idx = torch.max(idxes, 0)
        print('{} is sampled in the {}-th entry in the {}-th sampling'.format(randns[idx], idx, cnt))
        break
    else:
        print('{}'.format(cnt))

output: -0.0 is sampled in the 376731-th entry in the 1-th sampling
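
An equivalent, slightly more direct probe (a sketch under the same assumptions as the script above, i.e. a CUDA device is available): it looks for exact zeros in the exponential samples themselves rather than for inf after the log.

import torch

torch.manual_seed(0)
for i in range(1, 101):  # bounded number of rounds; the original loops until a hit
    x = torch.empty((10_000_000,), device='cuda').exponential_()
    zeros = (x == 0).nonzero()
    if zeros.numel() > 0:
        print('zero sampled at entry {} in the {}-th round'.format(zeros[0].item(), i))
        break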

@pytorchbot pytorchbot added the module: nn Related to torch.nn label May 6, 2019
@ezyang ezyang requested review from ezyang and syed-ahmed May 17, 2019 15:34
@ezyang
Contributor

ezyang commented May 17, 2019

@syed-ahmed Do you think exponential is just wrong for doing this?

@zhaoyanpeng Do you think you could add your sample code here to the test suite as a slowTest? Grep for examples of using this decorator.

@zhaoyanpeng
Author

@ezyang a test function has been added. I was not able to reproduce the error with the test function, but it should be noted that it is possible, in theory, to sample a zero from the exponential distribution.

@zhaoyanpeng
Author

0 can be sampled from

return -(-u).log1p() / self.rate

where u ∈ [0, 1).
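
A quick check of the boundary case (a tiny illustrative snippet, not part of the PR): with u = 0 the formula returns exactly 0.

import torch

u = torch.tensor(0.0)          # the lower boundary of [0, 1)
rate = 1.0
sample = -(-u).log1p() / rate  # log1p(-0.0) == 0, so the exponential sample is 0.0
print(sample)                  # tensor(0.)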

@D-X-Y

D-X-Y commented May 18, 2019

I have the same question and can reproduce this problem. Below is my own script; in some cases the gumbels variable becomes inf, which makes probs nan, and then multinomial raises an error.

import torch
import torch.nn as nn

class TEST(nn.Module):
  def __init__(self):
    super(TEST, self).__init__()
    self.register_parameter('weight', nn.Parameter(torch.Tensor(100, 9)))
    nn.init.normal_(self.weight, 0, 0.01)

  def forward(self, inputs):
    logits  = self.weight
    nn.init.normal_(logits, 0, 0.01)
    gumbels = -torch.empty_like(logits).exponential_().log()
    new_logits = (logits + gumbels) / 0.5
    probs = nn.functional.softmax(new_logits, dim=1).cpu()
    selected_index = torch.multinomial(probs + 1e-7, 2, False).to(logits.device)
    
test = TEST()
test = nn.DataParallel(test).cuda()
inputs = torch.Tensor(4, 5)

for i in range(100000):
  test(inputs)

@syed-ahmed
Collaborator

@syed-ahmed Do you think exponential is just wrong for doing this?

For CUDA, currently we are doing exponential like this:

GENERATE_KERNEL1(generate_exponential, float, double lambda, float, curand_uniform, (float)(-1. / lambda * log(x)))
GENERATE_KERNEL1(generate_exponential, double, double lambda, double, curand_uniform_double, (double)(-1. / lambda * log(x)))
GENERATE_KERNEL1(generate_exponential, at::Half, double lambda, float, curand_uniform, (ScalarConvert<float, at::Half>::to((float)(-1. / lambda * log(x)))))

curand_uniform's range is (0, 1]. So it's entirely possible that exponential gives 0.0 when log gets 1. On the CPU side we don't see it because the uniform's range there is [0, 1) and we use log(1 - u) in the exponential formula, so it doesn't produce zero. As a result, the effective uniform on the CPU side has bounds of (0, 1), which also matches the interval stated on Wikipedia. One workaround I could add for the CUDA side in this PR is to check for 1 and replace it with std::nextafter(1.0, 0.0).
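
A small numerical illustration of why the (0, 1] range matters and what the proposed workaround does (plain Python/NumPy, with np.nextafter standing in for std::nextafter; this is only a sketch of the idea, not the CUDA change itself):

import numpy as np

lam = 1.0
u = 1.0                           # worst case from a (0, 1] generator such as curand_uniform
print(-np.log(u) / lam)           # -0.0: the exponential sample collapses to zero

u_fixed = np.nextafter(1.0, 0.0)  # largest double below 1.0, mirroring std::nextafter(1.0, 0.0)
print(-np.log(u_fixed) / lam)     # ~1.1e-16: tiny but strictly positive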

@syed-ahmed
Collaborator

@ezyang a test function has been added. I was not able to reproduce the error with the test function, but it should be noted that it is possible, in theory, to sample a zero from the exponential distribution.

@zhaoyanpeng Maybe the global manual_seed in the test suite is not letting you reproduce the 0s. Maybe add the torch.manual_seed(0) from your reproducer script to your test function and see if that reproduces it?

@zhaoyanpeng
Author

zhaoyanpeng commented May 22, 2019

@ezyang a test function has been added. I was not able to reproduce the error with the test function, but it should be noted that it is possible, in theory, to sample a zero from the exponential distribution.

@zhaoyanpeng Maybe the global manual_seed in the test suite is not letting you reproduce the 0s. Maybe add the torch.manual_seed(0) from your reproducer script to your test function and see if that reproduces it?

@syed-ahmed Actually I copied the sample code above into the test function at first, but the error did not occur. I just tested it again and still could not reproduce the error.

@ezyang
Contributor

ezyang commented May 22, 2019

@pytorchbot rebase this please

@ezyang
Contributor

ezyang commented May 24, 2019

We should be reseeding on every test case, but if we need a specific seed for this test, it seems fine to explicitly reseed.

_test_gumbel_isinf is dead code, FYI.

I think ideally what we'd want is something like:

  1. A test to check that within N samples we never see log(0.0)
  2. A test to check that, for our seed, in N samples, you do get a random value that hits log(0.0) (so that if we ever change the RNG algorithm, we can make sure the test is still working)
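
For concreteness, a rough sketch of what test (1) could look like (standalone unittest style; names are illustrative, and the real version would live in the PyTorch test suite, likely decorated with slowTest as suggested earlier):

import unittest
import torch

class TestExponentialNonZero(unittest.TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
    def test_exponential_never_zero_cuda(self):
        torch.manual_seed(0)  # explicit reseed, as discussed above
        for _ in range(10):
            x = torch.empty((10_000_000,), device='cuda').exponential_()
            # With the fix in place, exponential_() should never return 0,
            # so log(x) must remain finite everywhere.
            self.assertTrue(torch.isfinite(x.log()).all().item())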

@zhaoyanpeng
Author

We should be reseeding on every test case, but if we need a specific seed for this test, it seems fine to explicitly reseed.

_test_gumbel_isinf is dead code, FYI.

I think ideally what we'd want is something like:

  1. A test to check that within N samples we never see log(0.0)
  2. A test to check that, for our seed, in N samples, you do get a random value that hits log(0.0) (so that if we ever change the RNG algorithm, we can make sure the test is still working)

I did that on purpose since I do not think _test_gumbel_isinf is necessary. As @syed-ahmed said, exponential on the GPU side may produce a zero, which, however, should not be allowed when using it to generate gumbels. Therefore, I think the correct logic is to avoid log(0.) in the implementation, rather than dedicating a test to checking for an error that can occur.

I explored the possible zero exponential a bit more. I prepared two pytorch environments: the first is pytorch installed locally with Conda; the second is the pytorch source code. Both use version 1.1.0. I ran the sample code above in the first environment and got a zero for the 376731-th exponential. With the same seed and device as in the first environment, I modified test_nn.py in the second environment and ran the sample code from test_nn.TestNN, which gave 2.782326230370438e-08 for the 376731-th exponential. I compared the sequences of exponentials generated in the two environments; they differ by about 1e-7. Any idea why the difference exists?

By the way, would it be better if we had a function gumbel_ like exponential_?
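
For illustration, a sketch of what such a helper could look like at the Python level (gumbel_like is a hypothetical name, not an existing PyTorch function; it mirrors the clamp-based approach this PR takes in functional.py):

import torch

def gumbel_like(logits):
    # Hypothetical helper: draw standard Gumbel(0, 1) noise with the same
    # shape, dtype and device as `logits`, clamping the exponential samples
    # away from zero so the log can never produce inf.
    tiny = torch.finfo(logits.dtype).tiny
    return -torch.empty_like(logits).exponential_().clamp_(min=tiny).log()

# Example: temperature-scaled Gumbel perturbation of some logits.
logits = torch.randn(4, 9)
noisy_logits = (logits + gumbel_like(logits)) / 0.5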

@ezyang
Contributor

ezyang commented Jun 5, 2019

@pytorchbot rebase this please

@ezyang
Contributor

ezyang commented Jun 7, 2019

This PR has run afoul of some TorchScript changes on master.

Jun 05 16:42:59 ======================================================================
Jun 05 16:42:59 ERROR: test_nn_gumbel_softmax (__main__.TestJitGeneratedFunctional)
Jun 05 16:42:59 ----------------------------------------------------------------------
Jun 05 16:42:59 Traceback (most recent call last):
Jun 05 16:42:59   File "test_jit.py", line 14962, in wrapper
Jun 05 16:42:59     return fn(*args, **kwargs)
Jun 05 16:42:59   File "test_jit.py", line 15012, in do_test
Jun 05 16:42:59     run_test()
Jun 05 16:42:59   File "test_jit.py", line 15004, in run_test
Jun 05 16:42:59     check_against_reference(self, script_fn, fn, f_args_variable, kwargs_variable, no_grad=no_grad)
Jun 05 16:42:59   File "test_jit.py", line 14242, in check_against_reference
Jun 05 16:42:59     outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
Jun 05 16:42:59   File "test_jit.py", line 682, in runAndSaveRNG
Jun 05 16:42:59     results = func(*inputs, **kwargs)
Jun 05 16:42:59   File "test_jit.py", line 14192, in script_fn
Jun 05 16:42:59     CU = torch.jit.CompilationUnit(script)
Jun 05 16:42:59   File "/opt/conda/lib/python3.6/site-packages/torch/jit/__init__.py", line 869, in __init__
Jun 05 16:42:59     self.define(lang, _frames_up=_frames_up + 1)
Jun 05 16:42:59   File "/opt/conda/lib/python3.6/site-packages/torch/jit/__init__.py", line 874, in define
Jun 05 16:42:59     self._c.define(lang, rcb)
Jun 05 16:42:59   File "/opt/conda/lib/python3.6/site-packages/torch/jit/__init__.py", line 904, in _try_compile_weak_script
Jun 05 16:42:59     compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"])
Jun 05 16:42:59   File "/opt/conda/lib/python3.6/site-packages/torch/jit/__init__.py", line 1050, in script
Jun 05 16:42:59     fn = torch._C._jit_script_compile(ast, _rcb, get_default_args(obj))
Jun 05 16:42:59 RuntimeError: 
Jun 05 16:42:59 unknown builtin op: aten::finfo
Jun 05 16:42:59 Could not find any similar ops to aten::finfo. This op may not exist or may not be currently supported in TorchScript
Jun 05 16:42:59 :
Jun 05 16:42:59 at /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1323:12
Jun 05 16:42:59         >>> F.gumbel_softmax(logits, tau=1, hard=True)
Jun 05 16:42:59 
Jun 05 16:42:59     .. _Gumbel-Softmax distribution:
Jun 05 16:42:59         https://arxiv.org/abs/1611.00712
Jun 05 16:42:59         https://arxiv.org/abs/1611.01144
Jun 05 16:42:59     """
Jun 05 16:42:59     if eps != 1e-10:
Jun 05 16:42:59         warnings.warn("`eps` parameter is deprecated and has no effect.")
Jun 05 16:42:59 
Jun 05 16:42:59     tiny = torch.finfo(logits.dtype).tiny
Jun 05 16:42:59            ~~~~~~~~~~~ <--- HERE
Jun 05 16:42:59     gumbels = -torch.empty_like(logits).exponential_().clamp_(min=tiny).log()  # ~Gumbel(0,1)
Jun 05 16:42:59     gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
Jun 05 16:42:59     y_soft = gumbels.softmax(dim)
Jun 05 16:42:59 
Jun 05 16:42:59     if hard:
Jun 05 16:42:59         # Straight through.
Jun 05 16:42:59         index = y_soft.max(dim, keepdim=True)[1]
Jun 05 16:42:59         y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
Jun 05 16:42:59         ret = y_hard - y_soft.detach() + y_soft
Jun 05 16:42:59 
Jun 05 16:42:59 ======================================================================
Jun 05 16:42:59 ERROR: test_nn_gumbel_softmax_hard (__main__.TestJitGeneratedFunctional)
Jun 05 16:42:59 ----------------------------------------------------------------------
(same traceback as above: unknown builtin op: aten::finfo at torch/nn/functional.py:1323)

cc @wanchaol

@ezyang ezyang added the oncall: jit Add this issue/PR to JIT oncall triage queue label Jun 7, 2019
@zou3519 zou3519 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jun 7, 2019
@wanchaol
Collaborator

wanchaol commented Jun 7, 2019

I think it's because torch.finfo is not a JIT-recognized type and the JIT thought it was an ATen op instead. Sorry, I don't have the full context for why we need to do clamp(min_representable_number) here; it seems to be handling edge cases. Can we do it in other ways instead?
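
To make the motivation for the clamp concrete, here is a small eager-mode illustration (not part of the PR diff):

import torch

# exponential_() on CUDA can return exactly 0, and -log(0) is inf. Clamping to
# the smallest positive normal value of the dtype keeps the Gumbel noise
# finite while leaving other samples essentially untouched.
x = torch.tensor([0.0, 0.5])
tiny = torch.finfo(x.dtype).tiny     # ~1.18e-38 for float32
print(-x.log())                      # tensor([inf, 0.6931])
print(-x.clamp(min=tiny).log())      # tensor([87.3365, 0.6931])

If torch.finfo itself cannot be scripted, one possible alternative (an assumption on my part, not something this PR does) would be to clamp with a small per-dtype literal constant instead, at the cost of hardcoding those values.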

@ezyang
Contributor

ezyang commented Jun 11, 2019 via email

@facebook-github-bot
Contributor

Hi @zhaoyanpeng!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g., your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@github-actions
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label May 28, 2022
@github-actions github-actions bot closed this Jun 27, 2022