[C++ API] RNN / GRU / LSTM layer refactoring #34322
facebook-github-bot
left a comment
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Co-Authored-By: Pavel Belevich <[email protected]>
This PR refactors the RNN / GRU / LSTM layers in the C++ API to exactly match their implementation in the Python API.

BC-breaking changes:

- Instead of returning `RNNOutput`, the RNN / GRU forward method now returns `std::tuple<Tensor, Tensor>`, and the LSTM forward method now returns `std::tuple<Tensor, std::tuple<Tensor, Tensor>>`, matching the Python API.
- The layers now have a `forward_with_packed_input` method, which accepts a `PackedSequence` as input and optionally a hidden state, matching the `forward(PackedSequence, ...)` variant in the Python API.
- The layers no longer have the fields `w_ih` / `w_hh` / `b_ih` / `b_hh`. Instead, to access the weights and biases of the gates, users should do e.g. `rnn->named_parameters()["weight_ih_l0"]`, which mirrors the Python API `rnn.weight_ih_l0`.
- `RNNOptions`:
  - `tanh()` / `relu()` / `activation` are removed. Instead, `nonlinearity` is added, which takes either `torch::kTanh` or `torch::kReLU`.
  - `layers` -> `num_layers`
  - `with_bias` -> `bias`
- `LSTMOptions`:
  - `layers` -> `num_layers`
  - `with_bias` -> `bias`
- `GRUOptions`:
  - `layers` -> `num_layers`
  - `with_bias` -> `bias`

The majority of the changes in this PR focus on refactoring the implementations in `torch/csrc/api/src/nn/modules/rnn.cpp` to match the Python API. The RNN tests are then changed to reflect the revised API design.
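
For callers migrating from `RNNOutput`, the main change is that the result now has to be unpacked from a nested tuple. The sketch below illustrates that unpacking only. To stay compilable without LibTorch it uses a hypothetical stand-in `Tensor` struct and a hypothetical `mock_lstm_forward` function; neither is part of this PR, and the shapes are an assumption modeled on a unidirectional LSTM with `(seq_len, batch, feature)` input.

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <vector>

// Hypothetical stand-in for torch::Tensor: it records only a shape.
struct Tensor {
  std::vector<std::size_t> shape;
};

// The return type this PR describes for the LSTM forward method:
// (output, (h_n, c_n)).
using LSTMOutput = std::tuple<Tensor, std::tuple<Tensor, Tensor>>;

// Hypothetical mock of a forward pass. It only builds tensors with the
// shapes assumed above; no computation is performed.
LSTMOutput mock_lstm_forward(std::size_t seq_len, std::size_t batch,
                             std::size_t hidden_size,
                             std::size_t num_layers) {
  Tensor output{{seq_len, batch, hidden_size}};
  Tensor h_n{{num_layers, batch, hidden_size}};
  Tensor c_n{{num_layers, batch, hidden_size}};
  return std::make_tuple(output, std::make_tuple(h_n, c_n));
}

// Unpack the nested state tuple with C++17 structured bindings and
// return the shape of the final hidden state.
std::vector<std::size_t> final_hidden_shape(const LSTMOutput& result) {
  const auto& state = std::get<1>(result);
  const auto& [h_n, c_n] = state;
  // In this mock the hidden and cell states always share a shape.
  assert(h_n.shape == c_n.shape);
  return h_n.shape;
}
```

With the real API, the same pattern would apply to the value returned by the LSTM forward method, while RNN / GRU callers unpack a flat `std::tuple<Tensor, Tensor>` instead.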