
Conversation

@ragulpr (Contributor) commented Oct 30, 2018

Fixes #12643, amends #3341.

  • Allow multidimensional input via a dim argument (softmax is applied over dim=-1 by default); see the usage sketch after this list.
  • Cleaner: fewer lines of code.
  • Faster (1.32x speedup vs. the original, 2x speedup vs. using torch.distributions).
  • Small fixes in the docstring.
  • Remove some references in the docstring. Was the linked (excellent) ipynb the first to do the straight-through trick? Instead, I propose referencing the two papers most known for it.
  • Add a DeprecationWarning for eps; it's not needed anymore.
  • The initial commit keeps some code alternatives commented out to exploit CI.
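
As a quick usage sketch of the dim behaviour described above (hypothetical shapes; assuming the signature keeps the existing tau/hard arguments and gains dim as proposed):

    import torch
    import torch.nn.functional as F

    # Hypothetical 3D input: (batch, sequence, num_categories)
    logits = torch.randn(16, 10, 5, requires_grad=True)

    # Soft samples: differentiable, slices along dim=-1 sum to 1
    y_soft = F.gumbel_softmax(logits, tau=1.0, hard=False, dim=-1)

    # Hard samples: one-hot in the forward pass, soft gradients in the backward pass
    y_hard = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
    assert y_hard.shape == logits.shape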

Controversies:

  • As discussed when gumbel_softmax was added (add gumbel_softmax, based on Eric Jang's implementation #3341), it was merged into torch.nn.functional before all the work on Distributions and Pyro, and there will probably be multiple other best practices for this in the future.
    I've tested building on the Distributions API, but it was too slow; see below.

I therefore propose not using Distributions, to keep it fast and simple, but adding a note in the docstring that gumbel_softmax may be deprecated in the future.

Build on torch.distributions.RelaxedOneHotCategorical?

    dist = torch.distributions.RelaxedOneHotCategorical(
        temperature=tau, logits=logits, validate_args=False)
    y_soft = dist.rsample()

Pros:

  • Built using numerically stable tricks like logsumexp.
  • Explicitly uses torch.distributions.utils._finfo to avoid overflow (the old implementation had an eps flag).
  • Maintained for this exact purpose.

Cons:

  • Very slow: constructing the distribution adds overhead, see the timings below. This may be solved in the future by speedups to TransformedDistribution and Distribution.
  • Assumes which dim to apply softmax over (see the workaround sketch after this list).
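
To illustrate the last con: applying RelaxedOneHotCategorical over an arbitrary dim would need something like a transpose round-trip. A hypothetical workaround sketch (rsample_over_dim is not an existing API, just an illustration):

    import torch

    def rsample_over_dim(logits, tau, dim=-1):
        # RelaxedOneHotCategorical treats the last dimension as the event
        # dimension, so move `dim` there and back again.
        logits = logits.transpose(dim, -1)
        dist = torch.distributions.RelaxedOneHotCategorical(
            temperature=tau, logits=logits, validate_args=False)
        return dist.rsample().transpose(dim, -1)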

Build on Tensor.exponential_() (as proposed)?

    y_soft = logits.new(logits.shape)  # uninitialized tensor, same dtype/device as logits
    y_soft = (logits - y_soft.exponential_().log()) / tau  # Gumbel noise: -log(Exp(1)) ~ Gumbel(0,1)
    y_soft = y_soft.softmax(dim)  # Gumbel softmax noise
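
For reference, the identity used above is that if E ~ Exp(1) then -log(E) ~ Gumbel(0, 1). A quick empirical sanity check (not part of the PR; the reference values are the standard Gumbel mean and variance):

    import torch

    e = torch.empty(1000000).exponential_()  # E ~ Exp(1)
    g = -e.log()                             # should be ~Gumbel(0, 1)
    print(g.mean())  # ~0.577 (Euler-Mascheroni constant)
    print(g.var())   # ~1.645 (pi^2 / 6)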

Pros:

  • Faster

Timings:

    import time

    num_draws = 1000000
    logits = torch.randn(1, 3)
    counts = torch.zeros_like(logits)

    start = time.time()
    for draw in range(num_draws):
        y_draw = gumbel_softmax(logits, hard=True)
        counts = counts + y_draw
    end = time.time()
    print(end - start)

## torch.nn.functional.gumbel_softmax()
>> 12.995795965194702

## Using exponential_() as in this commit
>> 7.658372640609741

## Using RelaxedOneHotCategorical
>> 20.3382670879364

TODO

Decide on which path to choose. I'll commit changes to the unit tests in a while to show that it passes both the old and the new tests. I'll also remove the commented code about RelaxedOneHotCategorical.

- Cleanup: faster and far fewer lines of code.
- Remove some references in the docstring. Was the linked ipynb the first to do the hard (straight-through) trick?
- Add references to the two papers published in parallel.
- Add a DeprecationWarning for `eps`. It probably shouldn't be there.
- Keep some code alternatives commented out; I want to exploit the PyTorch CI testers.
@ragulpr (Contributor, Author) commented Oct 30, 2018

TODO:

  • Flake8
  • Add new tests
  • Decide on whether to use the torch.distributions API or .exponential_(), and then remove the commented code:
    # dist = torch.distributions.RelaxedOneHotCategorical(

- Tests the same things as before, except:
- Removed the test for the variance of the gradient; it was unclear how/why it worked.
- Added tests for 3-dimensional input (see the sketch below).
- Renamed `st` -> `straight_through`.
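
A rough sketch of what the 3-dimensional test checks (simplified; names and exact assertions here are illustrative, not copied from test_nn.py):

    import torch
    import torch.nn.functional as F

    def test_gumbel_softmax_3d():
        logits = torch.randn(4, 7, 5)
        y = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
        assert y.shape == logits.shape
        # hard=True should give (numerically) one-hot slices along `dim`
        ones = torch.ones(4, 7)
        assert torch.allclose(y.sum(dim=-1), ones)
        assert torch.allclose(y.max(dim=-1)[0], ones)
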
@ragulpr (Contributor, Author) commented Oct 30, 2018

The only error with the original tests was caused by flake8 (which is currently blocking other pull requests from merging):

    $ flake8
    ./torch/nn/functional.py:1173:121: E501 line too long (130 > 120 characters)
    ./torch/nn/functional.py:1247:121: E501 line too long (130 > 120 characters)
    ./torch/nn/functional.py:1248:121: E501 line too long (123 > 120 characters)

@ragulpr (Contributor, Author) commented Nov 1, 2018

I think this is ready to be merged:

  • It passes the original tests,
  • the new tests are more rigorous,
  • and the API has not changed, so there are no breaking changes.

Here is a comparison between implementations, with a full training example: https://gist.github.com/ragulpr/1f88aee9fcc585600280ba2d6f323368

Or on nbviewer if the above link isn't working:
https://nbviewer.jupyter.org/urls/gist.githubusercontent.com/ragulpr/1f88aee9fcc585600280ba2d6f323368/raw/ed5ae3e85064768886efd656b50261ed2f590633/gumbel_softmax.ipynb

@ragulpr (Contributor, Author) commented Nov 1, 2018

Anyone who could review? I know @soumith reviewed the original PR 👍

@alicanb (Collaborator) commented Nov 1, 2018

It would be nice to keep the implementations of RelaxedOneHotCategorical.rsample and F.gumbel_softmax consistent.

@ragulpr (Contributor, Author) commented Nov 2, 2018

Thanks @alicanb, I've been looking into how that would be possible. I think there are two kinds of consistency: consistent in values or consistent in style. Long story short, the line-by-line translation of the reference implementation is:

Alt a)

    from torch.distributions.utils import clamp_probs

    uniforms = clamp_probs(torch.rand(logits.shape, dtype=logits.dtype, device=logits.device))
    gumbels = -((-(uniforms.log())).log())  # -log(~Exp(1)) = ~Gumbel(0,1)
    scores = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = scores - scores.logsumexp(dim=dim, keepdim=True)
    # need to apply transform(x) as it's a TransformedDistribution
    # https://github.com/pytorch/pytorch/blob/99ce499bfecd4ed0e87cf60d87ee483c6383d95c/torch/distributions/transformed_distribution.py#L99
    y_soft = y_soft.exp()

14.4 s ± 86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

But this is going to be slow, and the logsumexp is unnecessary. Speeding it up and leaving a smaller memory footprint by reusing the gumbels variable, we get the version below. This is still consistent in values (it generates the same thing):

Alt b)

    # merge logsumexp and exp into softmax, reuse variable names
    uniforms = clamp_probs(torch.rand(logits.shape, dtype=logits.dtype, device=logits.device))
    gumbels = -((-(uniforms.log())).log()) # -log(~Exp(1)) = ~Gumbel(0,1)
    gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)

11.5 s ± 130 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

A third alternative is to impose the more legible 'style' of the RelaxedOneHotCategorical rsample implementation but use exponential_() and .new(). This is slightly different in values, as it doesn't clamp probabilities as aggressively.

Alt c)

    gumbels = -logits.new(logits.shape).exponential_().log() # ~Gumbel(0,1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)  # Gumbel softmax noise

11.1 s ± 58.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

For reference, this is what I committed:

Alt d)

    y_soft = logits.new(logits.shape)
    y_soft = (logits - y_soft.exponential_().log()) / tau  # Gumbel noise
    y_soft = y_soft.softmax(dim)  # Gumbel softmax noise

10.6 s ± 59 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

And finally, using RelaxedOneHotCategorical without inlining:

    dist = torch.distributions.RelaxedOneHotCategorical(
        temperature=tau, logits=logits, validate_args=False)
    y_soft = dist.rsample()

21.2 s ± 326 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In the end, if we want to mimic RelaxedOneHotCategorical in values, I propose going with Alt b); otherwise Alt c).
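
For context, here is how Alt b) could slot into the full function, including the straight-through hard path (a minimal sketch only; gumbel_softmax_sketch is a placeholder name, and the PR's actual hard branch may differ in details):

    import torch
    from torch.distributions.utils import clamp_probs

    def gumbel_softmax_sketch(logits, tau=1.0, hard=False, dim=-1):
        # Alt b): noise identical in values to RelaxedOneHotCategorical.rsample
        uniforms = clamp_probs(torch.rand_like(logits))
        gumbels = -((-(uniforms.log())).log())  # ~Gumbel(0,1)
        gumbels = (logits + gumbels) / tau      # ~Gumbel(logits,tau)
        y_soft = gumbels.softmax(dim)
        if not hard:
            return y_soft
        # Straight-through: one-hot forward value, gradients flow through y_soft
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
        return y_hard - y_soft.detach() + y_soft

The last line is the usual straight-through estimator: the forward value is the one-hot y_hard, while the backward pass sees only y_soft.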

@ezyang (Contributor) commented Nov 16, 2018

@alicanb Do you have an opinion about the implementation strategies?

@facebook-github-bot (Contributor) left a comment:

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


@ezyang (Contributor) commented Dec 6, 2018

I'm not sure what's going on with the failures, because the mentioned code doesn't show up in the PR at all. Pushed a merge to master and trying again.

ragulpr and others added 2 commits on December 7, 2018 14:03:

Fixes the JIT error `return statements can appear only at the end of the function body`. Unfortunately this uses the less nice `ret` pattern (but it seems to be standard).
@ezyang (Contributor) commented Dec 20, 2018

Unfortunately, changes in JIT mean that this patch no longer works, because clamp_probs is not scriptable.

@wanchaol could you take a look at this?

@wanchaol (Collaborator) commented:

Yes, this will break the JIT standard lib. @ragulpr, why did we end up choosing Alt b) instead of Alt c), given that Alt c) looks simpler and faster?

@ragulpr (Contributor, Author) commented Dec 21, 2018

Sorry for the slow activity here; I got a little taken aback by all the JIT errors.

The reason for choosing b)

    uniforms = torch.rand_like(logits)
    uniforms = torch.distributions.utils.clamp_probs(uniforms)
    gumbels = -((-(uniforms.log())).log())  # -log(~Exp(1)) = ~Gumbel(0,1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)

instead of c)

    gumbels = -logits.new(logits.shape).exponential_().log() # ~Gumbel(0,1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)  # Gumbel softmax noise

was that @alicanb had the good point that the implementation should be similar to RelaxedOneHotCategorical, and b) generates noise identical to RelaxedOneHotCategorical's. Now that we know it won't work with JIT, maybe b) is not worth the hassle. I just need to verify first that .exponential_() works with JIT too; in that case it is preferable.
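
If .exponential_() turns out to be scriptable, an Alt c)-style candidate would look roughly like this (a sketch only, using torch.empty_like instead of .new(); whether torch.jit.script actually accepts it is exactly what still needs to be verified):

    import torch

    def gumbel_noise_softmax(logits, tau, dim):
        # Alt c)-style sampling: no clamp_probs, so the only open question
        # for the JIT is exponential_() itself.
        gumbels = -torch.empty_like(logits).exponential_().log()  # ~Gumbel(0,1)
        gumbels = (logits + gumbels) / tau                        # ~Gumbel(logits,tau)
        return gumbels.softmax(dim)

Wrapping this with torch.jit.script would be the quick way to check.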

@facebook-github-bot (Contributor) left a comment:

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) left a comment:

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


Successfully merging this pull request may close this issue: [feature request] Gumbel Softmax for 3D tensor
