
[Contributor Welcome] Implement C++ API version of torch.nn.functional.gumbel_softmax #27078

@yf225


Context

We would like to add torch::nn::functional::gumbel_softmax to the C++ API, so that C++ users can easily find the equivalent of the Python API torch.nn.functional.gumbel_softmax.

Steps

  • Add torch::nn::GumbelSoftmaxOptions to torch/csrc/api/include/torch/nn/options/activation.h (add this file if it doesn’t exist), which should include the following parameters (based on https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.gumbel_softmax)
    • TORCH_ARG(double, tau) = 1.0
    • TORCH_ARG(bool, hard) = false
    • TORCH_ARG(double, eps) = 1e-10
    • TORCH_ARG(int64_t, dim) = -1
    • NOTE: please add the same comments for the parameters as in the Python version. For example, the comment for tau should read "non-negative scalar temperature".
  • Add torch::nn::functional::gumbel_softmax(...) in torch/csrc/api/include/torch/nn/functional/activation.h (add this file if it doesn’t exist). The function should have the following signature:
namespace torch {
namespace nn {
namespace functional {

inline Tensor gumbel_softmax(
    const Tensor& logits,
    const GumbelSoftmaxOptions& options = {}) {
  // ... (implementation to be added)
}

} // namespace functional
} // namespace nn
} // namespace torch
  • Add a test for torch::nn::functional::gumbel_softmax(...) in test/cpp/api/functional.cpp. It can simply check that the output matches expectations (for example, soft samples are non-negative and sum to 1 along dim, and hard samples are one-hot) for the following cases, expressed in Python:
>>> import torch
>>> import torch.nn.functional as F
>>> logits = torch.randn(20, 32)
>>> # Sample soft categorical using reparametrization trick:
>>> F.gumbel_softmax(logits, tau=1, hard=False)
>>> # Sample hard categorical using "Straight-through" trick:
>>> F.gumbel_softmax(logits, tau=1, hard=True)
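For reference, the sampling math that the C++ functional has to reproduce can be sketched without libtorch. This is a minimal standalone sketch under stated assumptions, not the libtorch implementation: GumbelSoftmaxOptionsSketch and gumbel_softmax_row are hypothetical names, and the struct only approximates the fluent setter/getter shape that TORCH_ARG generates. eps and dim are omitted because the sketch works on a single 1-D row of logits (the dim = -1 case). Gumbel(0, 1) noise is drawn as -log(E) with E ~ Exp(1). With hard = true, only the forward value of the straight-through estimator (a one-hot vector) is returned, since there is no autograd here.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical stand-in for torch::nn::GumbelSoftmaxOptions: each field gets a
// fluent setter and a getter, roughly the shape that TORCH_ARG produces.
struct GumbelSoftmaxOptionsSketch {
  // non-negative scalar temperature
  GumbelSoftmaxOptionsSketch& tau(double v) { tau_ = v; return *this; }
  double tau() const { return tau_; }

  // if true, the returned samples are discretized as one-hot vectors
  GumbelSoftmaxOptionsSketch& hard(bool v) { hard_ = v; return *this; }
  bool hard() const { return hard_; }

 private:
  double tau_ = 1.0;
  bool hard_ = false;
};

// Hypothetical helper: draws one Gumbel-Softmax sample over a single row of
// logits (equivalent to dim = -1 for a 1-D input).
std::vector<double> gumbel_softmax_row(const std::vector<double>& logits,
                                       const GumbelSoftmaxOptionsSketch& opts,
                                       std::mt19937& gen) {
  assert(!logits.empty());
  // The sketch requires tau > 0 because it divides by tau.
  assert(opts.tau() > 0.0 && "tau must be positive");

  std::exponential_distribution<double> exp1(1.0);

  // y_i = (logits_i + g_i) / tau, where g_i = -log(E_i), E_i ~ Exp(1),
  // is Gumbel(0, 1) noise.
  std::vector<double> y(logits.size());
  for (std::size_t i = 0; i < logits.size(); ++i) {
    y[i] = (logits[i] - std::log(exp1(gen))) / opts.tau();
  }

  // Numerically stable softmax: subtract the max before exponentiating.
  const double m = *std::max_element(y.begin(), y.end());
  double sum = 0.0;
  for (double& v : y) {
    v = std::exp(v - m);
    sum += v;
  }
  for (double& v : y) v /= sum;

  if (opts.hard()) {
    // Forward value of the straight-through estimator: one-hot at the argmax
    // of the soft sample. No gradients are tracked in this sketch.
    const auto idx = std::max_element(y.begin(), y.end()) - y.begin();
    std::vector<double> one_hot(y.size(), 0.0);
    one_hot[static_cast<std::size_t>(idx)] = 1.0;
    return one_hot;
  }
  return y;
}
```

In the actual functional, the same steps would run as tensor ops along options.dim(). The Python implementation returns y_hard - y_soft.detach() + y_soft in the hard case, so the forward value is one-hot while gradients flow through the soft sample.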

Helpful Resources

Many PRs have already added new functionals and modules to the C++ API (they are listed in #25883), and they can serve as good references. Please ping @yf225 on this issue if you run into any problems.

How do I claim this feature request?

Please comment on this issue if you are interested in working on it.

cc @yf225


Labels: good first issue, module: cpp, triaged
