
@yf225 (Contributor) commented on Mar 5, 2020

This PR refactors RNN / GRU / LSTM layers in C++ API to exactly match the implementation in Python API.

BC-breaking changes:

  • Instead of returning RNNOutput, RNN / GRU forward method now returns std::tuple<Tensor, Tensor>, and LSTM forward method now returns std::tuple<Tensor, std::tuple<Tensor, Tensor>>, matching Python API.
  • RNN / LSTM / GRU forward method now accepts the same inputs (input tensor and optionally hidden state), matching Python API.
  • RNN / LSTM / GRU layers now have forward_with_packed_input method which accepts PackedSequence as input and optionally hidden state, matching the forward(PackedSequence, ...) variant in Python API.
  • RNN / LSTM / GRU layers no longer have these fields: w_ih / w_hh / b_ih / b_hh. Instead, to access the weights and biases of the gates, users should do e.g. rnn->named_parameters()["weight_ih_l0"], which mirrors the Python API rnn.weight_ih_l0.
  • In RNNOptions
  • tanh() / relu() / activation are removed. Instead, a nonlinearity option is added, which takes either torch::kTanh or torch::kReLU.
    • layers -> num_layers
    • with_bias -> bias
  • In LSTMOptions
    • layers -> num_layers
    • with_bias -> bias
  • In GRUOptions
    • layers -> num_layers
    • with_bias -> bias

The majority of the changes in this PR focus on refactoring the implementation in torch/csrc/api/src/nn/modules/rnn.cpp to match the Python API. The RNN tests are updated accordingly to reflect the revised API design.
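
For illustration, here is a minimal sketch (not from the PR itself) of how the refactored modules are used; the sizes and variable names are arbitrary:

```cpp
#include <torch/torch.h>

#include <tuple>

int main() {
  const int64_t input_size = 10;
  const int64_t hidden_size = 20;

  // Renamed options: num_layers (was layers), bias (was with_bias), and
  // nonlinearity (replaces tanh() / relu() / activation).
  torch::nn::RNN rnn(torch::nn::RNNOptions(input_size, hidden_size)
                         .num_layers(2)
                         .bias(true)
                         .nonlinearity(torch::kReLU));
  torch::nn::LSTM lstm(
      torch::nn::LSTMOptions(input_size, hidden_size).num_layers(2));

  // Gate weights/biases are reached via named_parameters(), mirroring
  // Python's rnn.weight_ih_l0.
  torch::Tensor weight_ih_l0 = rnn->named_parameters()["weight_ih_l0"];

  torch::Tensor input = torch::randn({/*seq_len=*/5, /*batch=*/3, input_size});

  // RNN / GRU: forward now returns std::tuple<Tensor, Tensor>
  // (output, h_n) instead of RNNOutput.
  std::tuple<torch::Tensor, torch::Tensor> rnn_result = rnn->forward(input);
  torch::Tensor rnn_output = std::get<0>(rnn_result);
  torch::Tensor rnn_h_n = std::get<1>(rnn_result);

  // LSTM: forward returns std::tuple<Tensor, std::tuple<Tensor, Tensor>>,
  // with the (h_n, c_n) state nested in the inner tuple.
  std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>
      lstm_result = lstm->forward(input);
  torch::Tensor lstm_output = std::get<0>(lstm_result);
  torch::Tensor h_n = std::get<0>(std::get<1>(lstm_result));
  torch::Tensor c_n = std::get<1>(std::get<1>(lstm_result));

  // For a PackedSequence input, the forward_with_packed_input method
  // described above is available instead of forward.
}
```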

dr-ci bot commented on Mar 5, 2020

💊 CircleCI build failures summary and remediations

As of commit 358fc46 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following build failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_backward_compatibility_check_test (1/1)

Step: "Test" (full log | pattern match details)

Mar 15 01:12:26 The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not.
Mar 15 01:12:26 processing existing schema:  aten::t(Tensor(a) self) -> (Tensor(a)) 
Mar 15 01:12:26 processing existing schema:  aten::to.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor) 
Mar 15 01:12:26 processing existing schema:  aten::to.dtype(Tensor self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor) 
Mar 15 01:12:26 processing existing schema:  aten::to.device(Tensor self, Device device, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor) 
Mar 15 01:12:26 processing existing schema:  aten::to.dtype_layout(Tensor self, *, int dtype, int layout, Device device, bool pin_memory=False, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor) 
Mar 15 01:12:26 processing existing schema:  aten::to(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> (Tensor(b|a)) 
Mar 15 01:12:26 processing existing schema:  aten::to(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> (Tensor(b|a)) 
Mar 15 01:12:26 processing existing schema:  aten::to(Tensor(a) self, bool non_blocking=False, bool copy=False) -> (Tensor(b|a)) 
Mar 15 01:12:26 processing existing schema:  aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) 
Mar 15 01:12:26 processing existing schema:  aten::trace(Tensor self) -> (Tensor) 
Mar 15 01:12:26 The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not.  
Mar 15 01:12:26  
Mar 15 01:12:26 Broken ops: [ 
Mar 15 01:12:26 	aten::_linear_packed(Tensor packed_weight, Tensor input) -> (Tensor) 
Mar 15 01:12:26 	aten::_linear_prepack(Tensor weight, Tensor? bias=None, float? output_min=None, float? output_max=None) -> (Tensor) 
Mar 15 01:12:26 	aten::_conv2d_packed(Tensor packed_weight, Tensor input) -> (Tensor) 
Mar 15 01:12:26 	aten::_conv2d_prepack(Tensor weight, Tensor? bias=None, int[2] stride=[1, 1], int[2] padding=[0, 0], int[2] dilation=[1, 1], int groups=1, float? output_min=None, float? output_max=None) -> (Tensor) 
Mar 15 01:12:26 ] 
Mar 15 01:12:26 + cleanup 
Mar 15 01:12:26 + retcode=1 
Mar 15 01:12:26 + set +x 


@yf225 force-pushed the cpp_rnn_layers branch 23 times, most recently from 6e6cec2 to e71781b, on March 6, 2020 at 17:04
@yf225 changed the title from "[WIP] C++ RNN layers refactoring" to "C++ RNN / GRU / LSTM layer refactoring" on Mar 6, 2020
@facebook-github-bot commented

@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


@facebook-github-bot commented

@yf225 merged this pull request in e23a9dc.

@facebook-github-bot commented

@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


@facebook-github-bot commented

@yf225 merged this pull request in bdd7dbf.


Labels

Merged · module: bc-breaking (Related to a BC-breaking change) · module: cpp (Related to C++ API)


5 participants