
[feature request] Support bidirectional RNNs in C++ API #17998


Description

@mthrok

🐛 Bug

The C++ API torch::nn::GRU fails when bidirectional=true; it works fine when bidirectional=false.

The error is raised from pair_vec at aten/src/ATen/native/RNN.cpp:135 (see the stack trace below).

I tried different options, such as changing the number of layers from 1 to 2 and toggling the batch_first parameter, but the outcome was the same.

To Reproduce

C++ Source

#include <torch/torch.h>

int BATCH_SIZE = 32;
int SEQ_LEN = 50;
int FEAT_DIM = 64;
bool bidirectional = true;  // it works when this is `false`

struct Net : torch::nn::Module {
  torch::nn::GRU gru{nullptr};
  Net() {
    auto opt = torch::nn::GRUOptions(FEAT_DIM, 128);
    opt.bidirectional(bidirectional);
    opt.batch_first(true);
    gru = register_module("gru", torch::nn::GRU(opt));
  }
  torch::Tensor forward(torch::Tensor x) {
      return gru->forward(x).output;
  }
};

int main() {
  auto net = std::make_shared<Net>();
  net->forward(torch::rand({BATCH_SIZE, SEQ_LEN, FEAT_DIM}));
}

CMakeLists.txt

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(example-app)

find_package(Torch REQUIRED)

add_executable(test-app test.cpp)
target_link_libraries(test-app "${TORCH_LIBRARIES}")
set_property(TARGET test-app PROPERTY CXX_STANDARD 11)

Error Message

terminate called after throwing an instance of 'c10::Error'
  what():  Odd number of params or hiddens given to a bidirectional RNN (pair_vec at /pytorch/aten/src/ATen/native/RNN.cpp:135)
frame #0: std::function<std::string ()>::operator()() const + 0x11 (0x7f185aad22b1 in /pytorch-test/libtorch/lib/libc10.so)
frame #1: c10::Error::Error(c10::SourceLocation, std::string const&) + 0x2a (0x7f185aad1bea in /pytorch-test/libtorch/lib/libc10.so)
frame #2: <unknown function> + 0xa4ed11 (0x7f185b739d11 in /pytorch-test/libtorch/lib/libcaffe2.so)
frame #3: at::native::gru(at::Tensor const&, at::Tensor const&, c10::ArrayRef<at::Tensor>, bool, long, double, bool, bool, bool) + 0x272 (0x7f185b7417b2 in /pytorch-test/libtorch/lib/libcaffe2.so)
frame #4: at::TypeDefault::gru(at::Tensor const&, at::Tensor const&, c10::ArrayRef<at::Tensor>, bool, long, double, bool, bool, bool) const + 0xc9 (0x7f185ba1e879 in /pytorch-test/libtorch/lib/libcaffe2.so)
frame #5: torch::autograd::VariableType::gru(at::Tensor const&, at::Tensor const&, c10::ArrayRef<at::Tensor>, bool, long, double, bool, bool, bool) const + 0x344 (0x7f18656c28d4 in /pytorch-test/libtorch/lib/libtorch.so.1)
frame #6: <unknown function> + 0xb03182 (0x7f18659f5182 in /pytorch-test/libtorch/lib/libtorch.so.1)
frame #7: std::_Function_handler<std::tuple<at::Tensor, at::Tensor> (at::Tensor const&, at::Tensor const&, c10::ArrayRef<at::Tensor>, bool, long, double, bool, bool, bool), std::tuple<at::Tensor, at::Tensor> (*)(at::Tensor const&, at::Tensor const&, c10::ArrayRef<at::Tensor>, bool, long, double, bool, bool, bool)>::_M_invoke(std::_Any_data const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<at::Tensor>, bool, long, double, bool, bool, bool) + 0x34 (0x7f18659f5e14 in /pytorch-test/libtorch/lib/libtorch.so.1)
frame #8: torch::nn::detail::RNNImplBase<torch::nn::GRUImpl>::generic_forward(std::function<std::tuple<at::Tensor, at::Tensor> (at::Tensor const&, at::Tensor const&, c10::ArrayRef<at::Tensor>, bool, long, double, bool, bool, bool)>, at::Tensor const&, at::Tensor) + 0xd9 (0x7f18659f6869 in /pytorch-test/libtorch/lib/libtorch.so.1)
frame #9: torch::nn::GRUImpl::forward(at::Tensor const&, at::Tensor) + 0x56 (0x7f18659f59f6 in /pytorch-test/libtorch/lib/libtorch.so.1)
frame #10: Net::forward(at::Tensor) + 0x59 (0x423071 in ./test-app)
frame #11: main + 0xba (0x42033e in ./test-app)
frame #12: __libc_start_main + 0xf0 (0x7f185a166830 in /lib/x86_64-linux-gnu/libc.so.6)
frame #13: _start + 0x29 (0x41f979 in ./test-app)

Expected behavior

torch::nn::GRU should work when bidirectional=true just as it does when bidirectional=false.
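
For reference, here is a minimal self-contained sketch of what I would expect once this works, based on how the Python API behaves (the exact output shape for the C++ frontend is my assumption): a bidirectional GRU concatenates the forward and backward directions, so the output feature dimension should be 2 * hidden_size.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Same configuration as the repro above, but standalone.
  auto opt = torch::nn::GRUOptions(/*input_size=*/64, /*hidden_size=*/128);
  opt.bidirectional(true);
  opt.batch_first(true);
  torch::nn::GRU gru(opt);

  auto out = gru->forward(torch::rand({32, 50, 64})).output;
  // Expected once bidirectional support works: [32, 50, 256],
  // i.e. 2 * hidden_size features (assumption, mirroring the Python API).
  std::cout << out.sizes() << std::endl;
}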

Environment

libtorch version: latest nightly (CPU build), downloaded from https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip on Wed Mar 13 21:47:52 UTC 2019

OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.5.1


Labels

module: cpp (Related to C++ API)
module: rnn (Issues related to RNN support: LSTM, GRU, etc.)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
