Default initial hidden states for recurrent layers (Issue #434) #605
Conversation
torch/nn/modules/rnn.py
Outdated

-    def forward(self, input, hx):
+    def forward(self, input, hx=None):
+        if (hx == None):
+            batch_sz = input.size()[0] if self.batch_first else input.size()[1]
+            hx = torch.autograd.Variable(torch.Tensor(self.num_layers, batch_sz,
+                                         self.input_size).zero_())
+            if (self.mode == 'LSTM'):
+                hx = (torch.autograd.Variable(hx.data),
+                      torch.autograd.Variable(hx.data))

(Several review comments on this hunk were marked as off-topic and are hidden.)
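For orientation, here is a minimal sketch of the same default-initialization idea in current PyTorch. Two details are deliberate corrections on my part, not necessarily what was merged: the hidden state's last dimension should be self.hidden_size rather than self.input_size (the hidden state of a stacked RNN has shape (num_layers, batch, hidden_size)), and `hx is None` is the idiomatic null check. Variable is also unnecessary in modern PyTorch. The helper name default_hidden is hypothetical, not RNNBase API.

import torch

def default_hidden(mode, num_layers, batch_sz, hidden_size):
    # Zero-initialized hidden state.
    # NOTE: hidden_size, not input_size - the state has shape
    # (num_layers, batch, hidden_size).
    hx = torch.zeros(num_layers, batch_sz, hidden_size)
    if mode == 'LSTM':
        # An LSTM carries an (h_0, c_0) pair.
        return (hx, hx.clone())
    return hx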
One last thing. Can you please add a test that uses this change? Just instantiate one of each kind of RNN we have and pass a batch through it - once without passing the hidden state, and once with a manually constructed one. Then use […]
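A test along the lines of that request might look like the sketch below. The sizes and the torch.equal check are my choices (the comment above is cut off before naming the assertion to use), and the code uses the modern tensor API rather than the Variable API of the era:

import torch
import torch.nn as nn

def test_default_hidden_state():
    # One of each recurrent module, small illustrative sizes.
    for cls in (nn.RNN, nn.GRU, nn.LSTM):
        m = cls(input_size=4, hidden_size=6, num_layers=2)
        x = torch.randn(3, 5, 4)        # (seq_len, batch, input_size)
        h0 = torch.zeros(2, 5, 6)       # (num_layers, batch, hidden_size)
        hx = (h0, h0.clone()) if cls is nn.LSTM else h0
        out_default, _ = m(x)           # hidden state omitted
        out_manual, _ = m(x, hx)        # manually constructed zeros
        assert torch.equal(out_default, out_manual)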
I think I did a merge while trying to rebase my branch - that's why it shows the last commit 722c407. Is this okay? Should I try to revert this or just go ahead with adding the test case?
go ahead and add the testcase. we'll squash it down before merging. |
Setting a default initial hidden state of zeros if the hidden state is not provided by the user. Doing this in the RNNBase class, so it works for all three: RNN, GRU, and LSTM.
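This default is now standard behavior in PyTorch: when hx is omitted, the hidden state (and, for LSTM, the cell state) defaults to zeros. A quick usage sketch with arbitrary sizes:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2)
x = torch.randn(10, 4, 8)        # (seq_len, batch, input_size)
out, (h_n, c_n) = lstm(x)        # hx omitted -> zero-initialized h_0 and c_0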