Conversation

@anjali411 commented Sep 17, 2019

Added more variables to EmbeddingOptions and updated the EmbeddingImpl reset and forward functions. Also added EmbeddingBag.


This PR is BC-breaking in the following way:

Previously, `EmbeddingOptions` supported `count` and `dimension` as options arguments. After this PR, they are renamed to `num_embeddings` and `embedding_dim` respectively.

@pytorchbot added the module: cpp (Related to C++ API) label Sep 17, 2019

@yf225 left a comment

Thanks a lot @anjali411! The PR looks awesome as a start. I left some comments regarding the module options and how we implement the constructor and forward.

struct TORCH_API EmbeddingOptions {
EmbeddingOptions(int64_t count, int64_t dimension);
/// The number of embeddings (number of rows in the table).
// The number of embeddings (number of rows in the table).

Also I think we should change the comment to match the Python side:

num_embeddings (int): size of the dictionary of embeddings

// The number of embeddings (number of rows in the table).
TORCH_ARG(int64_t, count);
/// The size of each embedding vector (number of columns in the table).
// The size of each embedding vector (number of columns in the table).

ditto for matching Python side:

embedding_dim (int): the size of each embedding vector

TORCH_ARG(int64_t, num_embeddings);
// The size of each embedding vector (number of columns in the table).
TORCH_ARG(int64_t, embedding_dim);
// If given, pads the output with the embedding vector at :attr:`padding_idx (initialized to zeros) whenever it encounters the index.

:attr: probably doesn't work in C++ API docs now, and we can change it to:

Suggested change
// If given, pads the output with the embedding vector at :attr:`padding_idx (initialized to zeros) whenever it encounters the index.
// If given, pads the output with the embedding vector at `padding_idx` (initialized to zeros) whenever it encounters the index.

// The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``
TORCH_ARG(float, norm_type)=2.;
// If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. Default ``False``.
TORCH_ARG(bool, scale_grad_by_freq)=false;

nit:

Suggested change
TORCH_ARG(bool, scale_grad_by_freq)=false;
TORCH_ARG(bool, scale_grad_by_freq) = false;

// If given, pads the output with the embedding vector at :attr:`padding_idx (initialized to zeros) whenever it encounters the index.
TORCH_ARG(c10::optional<int64_t>, padding_idx)=c10::nullopt;
// If given, each embedding vector with norm larger than :attr:`max_norm` is renormalized to have norm :attr:`max_norm`.
TORCH_ARG(c10::optional<float>, max_norm)=c10::nullopt;

We might not need to specify c10::nullopt for padding_idx and max_norm, because the default value of c10::optional should already be c10::nullopt.
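
For reference, a minimal sketch of what that would look like (option names as added in this PR; c10::optional<T> default-constructs to c10::nullopt, so no explicit initializer is needed):

TORCH_ARG(c10::optional<int64_t>, padding_idx);
TORCH_ARG(c10::optional<float>, max_norm);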

else{
assert((padding_idx >= -num_embeddings) && "Padding_idx must be within num_embedding");
*padding_idx = *padding_idx+num_embeddings;
options.padding_idx_ = padding_idx;

Suggested change
options.padding_idx_ = padding_idx;
options.padding_idx(padding_idx);

}
else{
assert((*options.padding_idx() >= -(*options.weight()).size(0)) && "Padding_idx must be within num_embedding");
options.padding_idx(*options.padding_idx() + (*options.weight_).size(0));

minor nit: we should likely swap the order here to match the Python implementation:

Suggested change
options.padding_idx(*options.padding_idx() + (*options.weight_).size(0));
options.padding_idx((*options.weight_).size(0) + *options.padding_idx());


also options.weight_ can be changed to options._weight(), assuming we rename weight to _weight.

}

if(options.max_norm() != c10::nullopt){
input.contiguous();

.contiguous() is not an in-place function, and we should likely do:

Suggested change
input.contiguous();
input = input.contiguous();

This also matches the Python implementation.


I also think we need to implement and call _no_grad_embedding_renorm_ here.


for _no_grad_embedding_renorm_, we can do this:

#include <torch/utils.h>
{
  torch::NoGradGuard no_grad;
  torch::embedding_renorm(...);
}
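
One possible shape for such a helper (a hedged sketch, not a final signature; it uses the in-place torch::embedding_renorm_, as suggested in a later comment):

#include <torch/utils.h>

// Hypothetical helper mirroring Python's _no_grad_embedding_renorm_.
void _no_grad_embedding_renorm_(torch::Tensor weight, const torch::Tensor& input, float max_norm, float norm_type) {
  torch::NoGradGuard no_grad;
  torch::embedding_renorm_(weight, input, max_norm, norm_type);
}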

if(options.max_norm() != c10::nullopt){
input.contiguous();
}
return torch::embedding(*options.weight(), /*indices=*/input, *options.padding_idx(), options.scale_grad_by_freq(), options.sparse());

This should pass weight instead of *options.weight(), once we put weight back as the module's attribute.

"weight", torch::empty({options.count_, options.dimension_}));
NoGradGuard guard;
weight.normal_(0, 1);
(*(options.weight_)).Tensor::normal_(0, 1);

We can use torch::nn::init::normal_(weight), if we include the torch/nn/init.h header.
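
A hedged sketch of how reset() could then look (assuming the renamed options fields from this PR):

#include <torch/nn/init.h>

void EmbeddingImpl::reset() {
  weight = register_parameter(
      "weight", torch::empty({options.num_embeddings(), options.embedding_dim()}));
  torch::nn::init::normal_(weight);  // defaults to mean 0, std 1
}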

@yf225 left a comment

Thanks @anjali411! I left some comments.

ASSERT_EQ(
c10::str(Embedding(10, 2)),
"torch::nn::Embedding(count=10, dimension=2)");
c10::str(Embedding(num_embeddings=10, embedding_dim=2)),

keyword arguments don't work here because we are dealing with C++ :/ The expectation is that the user knows the first two arguments are num_embeddings and embedding_dim, and they can just call:

Suggested change
c10::str(Embedding(num_embeddings=10, embedding_dim=2)),
c10::str(Embedding(10, 2)),

c10::str(Embedding(num_embeddings=10, embedding_dim=2)),
"torch::nn::Embedding(num_embeddings=10, embedding_dim=2)");
ASSERT_EQ(
c10::str(Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2)),

Since this one involves optional arguments, we would write:

Suggested change
c10::str(Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2)),
c10::str(Embedding(EmbeddingOptions(10, 2).padding_idx(3).max_norm(2))),

c10::str(Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2)),
"torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2)");
ASSERT_EQ(
c10::str(Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)),

ditto here, we would write:

Suggested change
c10::str(Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)),
c10::str(Embedding(EmbeddingOptions(10, 2).padding_idx(3).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true))),

fc(register_module("fc", torch::nn::Linear(4, 5))),
table(register_module("table", torch::nn::Embedding(10, 2))),
table(register_module("table", torch::nn::Embedding(10, 2, padding_idx=3, max_norm=2))),
table(register_module("table", torch::nn::Embedding(10, 2, padding_idx=3, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true))),

ditto here, we would need to use EmbeddingOptions to express both the required arguments and optional arguments.
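
For example, the last registration above would become something along these lines (a sketch, assuming the EmbeddingOptions setters added in this PR):

table(register_module("table", torch::nn::Embedding(
    EmbeddingOptions(10, 2).padding_idx(3).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true)))),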

" (table): torch::nn::Embedding(count=10, dimension=2)\n"
" (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2)\n"
" (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2)"
" (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)"

ditto here, we would need to use EmbeddingOptions to express both the required arguments and optional arguments.

return torch::embedding(weight, /*indices=*/input);
if(options.padding_idx() != c10::nullopt){
if(*options.padding_idx() > 0){
assert((*options.padding_idx() < (weight.size(0)) && "Padding_idx must be within num_embeddings");

ditto for TORCH_CHECK(condition, message)
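
Concretely, something like this (sketch):

TORCH_CHECK(*options.padding_idx() < weight.size(0), "Padding_idx must be within num_embeddings");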

assert((*options.padding_idx() < (weight.size(0)) && "Padding_idx must be within num_embeddings");
}
else{
assert((*options.padding_idx() >= -(weight.size(0)) && "Padding_idx must be within num_embedding");

ditto for TORCH_CHECK(condition, message)

if(*options.padding_idx() > 0){
assert((*options.padding_idx() < (weight.size(0)) && "Padding_idx must be within num_embeddings");
}
else{

same situation here, we should check } else if (*options.padding_idx() < 0) { instead, to match the Python implementation
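
Putting the last few points together, the padding_idx validation would roughly mirror the Python side (a hedged sketch):

if (options.padding_idx() != c10::nullopt) {
  if (*options.padding_idx() > 0) {
    TORCH_CHECK(*options.padding_idx() < weight.size(0), "Padding_idx must be within num_embeddings");
  } else if (*options.padding_idx() < 0) {
    TORCH_CHECK(*options.padding_idx() >= -weight.size(0), "Padding_idx must be within num_embeddings");
    options.padding_idx(weight.size(0) + *options.padding_idx());
  }
}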

if(options.max_norm() != c10::nullopt){
input = input.contiguous();
torch::NoGradGuard no_grad;
torch::embedding_renorm(weight, input, *options.max_norm(), options.norm_type());

We likely want torch::embedding_renorm_ instead of torch::embedding_renorm

if(options.scale_grad_by_freq()){
stream << ",scale_grad_by_freq=" << options.scale_grad_by_freq();
}
if(options.sparse()){

A note on formatting: in general we would want to write

if (options.sparse()) {

instead of

if(options.sparse()){

, and

} else {

instead of

}
else {

to be consistent with the formatting of the other parts of C++ API.

yf225 commented Sep 19, 2019

@pytorchbot rebase this please

@yf225 left a comment

I did an initial pass, and will do another one tomorrow.

Tensor forward(const Tensor& indices);

static EmbeddingImpl& from_pretrained(Tensor embeddings, bool freeze = true, c10::optional<int64_t> padding_idx = c10::nullopt,
c10::optional<float> max_norm = c10::nullopt, float norm_type = 2., bool scale_grad_by_freq = false, bool sparse = false);

@yf225 Sep 20, 2019

Sorry for giving the wrong idea about how to implement from_pretrained: I think we will need to move from_pretrained to the Embedding class, not EmbeddingImpl, because we want people to be able to use it with torch::nn::Embedding::from_pretrained(...). I wrote a gist to illustrate the idea: https://gist.github.com/yf225/8eee0ef3f6afd2317092900927a43994.


We should also use EmbeddingOptions instead of writing out the parameters explicitly, as shown in the gist.


Also the return type shouldn't be a reference, and should just be Embedding (after we move this to the Embedding class).
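
Putting these three comments together, a minimal sketch of the idea (illustrative only, not the exact code from the gist; the _weight option on EmbeddingOptions is an assumption based on this review):

class Embedding : public torch::nn::ModuleHolder<EmbeddingImpl> {
 public:
  using torch::nn::ModuleHolder<EmbeddingImpl>::ModuleHolder;

  static Embedding from_pretrained(const torch::Tensor& embeddings, bool freeze = true) {
    TORCH_CHECK(embeddings.dim() == 2, "Embeddings parameter is expected to be 2-dimensional");
    Embedding embedding(EmbeddingOptions(embeddings.size(0), embeddings.size(1))._weight(embeddings));
    embedding->weight.set_requires_grad(!freeze);
    return embedding;
  }
};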

c10::optional<torch::Tensor> per_sample_weights = c10::nullopt);

static EmbeddingBagImpl& EmbeddingBagImpl::from_pretrained(Tensor embeddings, bool freeze = true, c10::optional<float> mex_norm = c10::nullopt,
float norm_type = 2., bool scale_grad_by_freq = false, string mode = "sum", bool sparse = false);

ditto here: we would need to move from_pretrained to EmbeddingBag class, so that people can do torch::nn::EmbeddingBag::from_pretrained(...).


ditto for using EmbeddingBagOptions as well.

torch::nn::init.normal_(weight);
}
else {
TORCH_CHECK((weight.size(0) == options.num_embeddings()) && (weight.size(1) == options.embedding_dim()), "Shape of _weight does not match num_embeddings and embedding_dim");

I think there are two issues here:

  1. We should be checking the sizes of _weight instead of weight, to match the Python implementation.
  2. In the error message, it should say Shape of weight instead of Shape of _weight, to match the Python implementation.

We can also improve the size checking with the following, to better match Python side:

Suggested change
TORCH_CHECK((weight.size(0) == options.num_embeddings()) && (weight.size(1) == options.embedding_dim()), "Shape of _weight does not match num_embeddings and embedding_dim");
TORCH_CHECK((*options._weight()).sizes() == torch::IntArrayRef({options.num_embeddings(), options.embedding_dim()}), "Shape of weight does not match num_embeddings and embedding_dim"

std::tuple<Tensor, Tensor, Tensor, Tensor> EmbeddingBagImpl::forward(const Tensor& input, c10::optional<torch::Tensor> offsets,
c10::optional<torch::Tensor> per_sample_weights) {

TORCH_CHECK(per_sample_weights == c10::nullopt || ((input.size(0) == per_sample_weights.size(0)) && input.size(1) == per_sample_weights.size(1)),

We can simplify this as:

Suggested change
TORCH_CHECK(per_sample_weights == c10::nullopt || ((input.size(0) == per_sample_weights.size(0)) && input.size(1) == per_sample_weights.size(1)),
TORCH_CHECK(per_sample_weights == c10::nullopt || input.sizes() == per_sample_weights.sizes())


TORCH_CHECK(per_sample_weights == c10::nullopt || ((input.size(0) == per_sample_weights.size(0)) && input.size(1) == per_sample_weights.size(1)),
"embedding_bag: If per_sample_weights ({", per_sample_weights.size(0), ", ", per_sample_weights.size(1), "}) is not null,
then it must have the same shape as the input ({", input.size(0), ", ", input.size(1), "})\n");

I think we might want to write the error message as follows:

     "embedding_bag: If per_sample_weights (", per_sample_weights.sizes(), ") is not null, ",
     "then it must have the same shape as the input (", input.sizes(), ")");

Specifically there are a few issues with the original error message:

  1. The spaces before then it must will be printed to the screen when the error message shows, which is likely not what we want.
  2. We don't need \n at the end of the message because TORCH_CHECK will handle it automatically.
  3. We can use .sizes() to simplify the size printing.

if(input.dim() == 2) {
TORCH_CHECK(offsets == c10::nullopt,
"if input is 2D, then offsets has to be null, as input is treated is a mini-batch of
fixed length sequences. However, found an offsets Tensor"); //check about adding type

ditto here: we likely shouldn't put spaces before fixed length sequences, and should do

        "if input is 2D, then offsets has to be null, as input is treated is a mini-batch of ",
        "fixed length sequences ...");


We can probably just write However, found offsets of type Tensor, since this is enough information for the user to fix the issue.

}
}

if (!options._weight().has_value()) {

We might want to do if (options._weight() == c10::nullopt) { to look more similar to Python implementation.

stream << ",norm_type=" << options.norm_type();
}
if(options.scale_grad_by_freq()) {
stream << ",scale_grad_by_freq=" << options.scale_grad_by_freq();

Suggested change
stream << ",scale_grad_by_freq=" << options.scale_grad_by_freq();
stream << ", scale_grad_by_freq=" << options.scale_grad_by_freq();

if(options.scale_grad_by_freq()) {
stream << ",scale_grad_by_freq=" << options.scale_grad_by_freq();
}
stream << ",mode="<<mode<<")";

Suggested change
stream << ",mode="<<mode<<")";
stream << ", mode=" << mode << ")";


EmbeddingBagImpl& EmbeddingBagImpl::from_pretrained(Tensor embeddings, bool freeze = true, c10::optional<float> mex_norm = c10::nullopt,
float norm_type = 2., bool scale_grad_by_freq = false, string mode = "sum", bool sparse = false) {
TORCH_CHECK(embeddings.dim() == 2, "Embeddings parameter is expected to be 2-embedding_dimal");

Suggested change
TORCH_CHECK(embeddings.dim() == 2, "Embeddings parameter is expected to be 2-embedding_dimal");
TORCH_CHECK(embeddings.dim() == 2, "Embeddings parameter is expected to be 2-dimensional");

norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
mode=mode,
sparse=sparse);

We need to pass the EmbeddingBag options to the constructor, because keyword argument is not supported in C++.
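
For example, inside from_pretrained the construction would go through the options object instead (a sketch, assuming EmbeddingBagOptions gains the corresponding setters):

EmbeddingBag embeddingbag(
    EmbeddingBagOptions(embeddings.size(0), embeddings.size(1))
        .max_norm(max_norm)
        .norm_type(norm_type)
        .scale_grad_by_freq(scale_grad_by_freq)
        .mode(mode)
        .sparse(sparse));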

Tensor forward(const Tensor& indices);

static EmbeddingImpl& from_pretrained(Tensor embeddings, bool freeze = true, c10::optional<int64_t> padding_idx = c10::nullopt,
c10::optional<float> max_norm = c10::nullopt, float norm_type = 2., bool scale_grad_by_freq = false, bool sparse = false);

Also the return type shouldn't be a reference, and should just be Embedding (after we move this to the Embedding class).

facebook-github-bot pushed a commit that referenced this pull request Sep 20, 2019
Summary:
With this PR, we establish the following conventions:
1. Options in C++ module / optimizer constructors should always be `const SomeOptions&` type, not `SomeOptions` type.
2. The options constructor arg should always be named `options_`, not `options`, to not be confused with the module / optimizer's internal field `options`.
3. We never use `std::move` to assign `options_` to the module / optimizer's internal field `options` in the constructor definition. Instead, we simply use `options(options_)`.

Here is the reasoning:
We might be tempted to declare the constructor as `SomeModule(SomeOptions options_)` and have `options(std::move(options_))` in the member initialization list. However, this can be a dangerous design because the constructor might use `options_` to set values for other member fields in the member initialization list (e.g. https://github.com/pytorch/pytorch/blob/8317f75b79fb78ceeeb928aa23a901d57274b9e1/torch/csrc/api/include/torch/optim/lbfgs.h#L30-L34), and use-after-move can cause hard-to-debug problems.
Instead, we choose to explicitly use `const SomeOptions&` type for `options_`, and never use `std::move` to assign it to the internal `options` field. This way we have stronger guarantee on the validity of `options_` at any point in the constructor.
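
A toy illustration of the convention (SomeModule/SomeOptions are the same placeholder names used above, not real classes):

struct SomeModuleImpl : torch::nn::Module {
  explicit SomeModuleImpl(const SomeOptions& options_)
      : options(options_) {  // plain copy; options_ stays valid for any further use in the constructor
  }
  SomeOptions options;
};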

Notable exceptions to the above conventions:
1. C++ Embedding module doesn't adhere to the conventions now, which will be fixed after #26358 is landed.
2. C++ dataloader and dataset classes likely need similar changes. We will do it when we start to work on dataloader/dataset parity.

Thanks ShahriarSS for discovering the options usage inconsistency! 🚀
Pull Request resolved: #26483

Differential Revision: D17500451

Pulled By: yf225

fbshipit-source-id: 49361a3519e4ede933789db75731d40144f0b617

yf225 commented Sep 23, 2019

@anjali411 There seems to be a conflict with master - please feel free to just use your version :)

" (table): torch::nn::Embedding(num_embeddings=10, embedding_dim=2, padding_idx=3, max_norm=2, norm_type=2.5, scale_grad_by_freq=true, sparse=true)"
" (inner): InnerTestModule(\n"
" (fc): torch::nn::Linear(in=3, out=4, with_bias=true)\n"
" (table): torch::nn::Embedding(count=10, dimension=2)\n"

@yf225 Sep 23, 2019

We probably need to add tests for EmbeddingBag and Embedding::from_pretrained to the test suite. Some examples are as follows, copied from Python docs (https://pytorch.org/docs/stable/nn.html#embedding):

>>> # an Embedding module containing 10 tensors of size 3
>>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([1,2,4,5,4,3,2,9])
>>> offsets = torch.LongTensor([0,4])
>>> embedding_sum(input, offsets)
tensor([[-0.8861, -5.4350, -0.0523],
        [ 1.1306, -2.5798, -1.0044]])
>>> # FloatTensor containing pretrained weights
>>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
>>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight)
>>> # Get embeddings for index 1
>>> input = torch.LongTensor([[1, 0]])
>>> embeddingbag(input)
tensor([[ 2.5000,  3.7000,  4.6500]])
>>> # FloatTensor containing pretrained weights
>>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
>>> embedding = nn.Embedding.from_pretrained(weight)
>>> # Get embeddings for index 1
>>> input = torch.LongTensor([1])
>>> embedding(input)
tensor([[ 4.0000,  5.1000,  6.3000]])
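
A rough C++ counterpart for the test suite could look like this (a sketch; the exact fixture and assertion style in test/cpp/api may differ):

TEST_F(ModulesTest, EmbeddingFromPretrained) {
  auto weight = torch::tensor({1.0, 2.3, 3.0, 4.0, 5.1, 6.3}).reshape({2, 3});
  Embedding embedding = Embedding::from_pretrained(weight);
  auto input = torch::tensor({1}, torch::kLong);
  ASSERT_TRUE(torch::allclose(embedding(input), torch::tensor({4.0, 5.1, 6.3})));
}

An analogous test for EmbeddingBag::from_pretrained would exercise the (input, offsets) forward path from the example above.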

mingbowan pushed a commit to mingbowan/pytorch that referenced this pull request Sep 23, 2019

yf225 commented Oct 1, 2019

@pytorchbot rebase this please

@pytorchbot

There's nothing to do! This branch is already up to date with master (46539ee).

(To learn more about this bot, see Bot commands.)

@yf225 added the module: bc-breaking (Related to a BC-breaking change) label Oct 1, 2019

@facebook-github-bot left a comment

@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


yf225 commented Oct 8, 2019

@pytorchbot rebase this please

@yf225 left a comment

Thanks so much for the awesome work @anjali411 !

Note to self: write BC-breaking notes before landing this PR.

@facebook-github-bot left a comment

@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


@yf225 merged this pull request in a37be20.

thiagocrepaldi pushed a commit to thiagocrepaldi/pytorch that referenced this pull request Feb 4, 2020
…ch#26358)

Summary:
added more variables to EmbeddingOptions and updated EmbeddingImpl reset, forward functions. Also added EmbeddingBag.

-----

This PR is BC-breaking in the following way:

Previously, `EmbeddingOptions` supports `count` and `dimension` as options arguments. After this PR, they are renamed to `num_embeddings` and `embedding_dim` respectively.
Pull Request resolved: pytorch#26358

Differential Revision: D17714337

Pulled By: yf225

fbshipit-source-id: f9f969c68e4bece106b92f8e2e02ac39c8455fb7