
Conversation

@goelhardik
Contributor

This PR sets a default initial hidden state of zeros when the user does not provide one. The change is made in the RNNBase class, so it works for all three recurrent layers: RNN, GRU, and LSTM.

@goelhardik changed the title from "Default initial hidden states for recurrent layers #434" to "Default initial hidden states for recurrent layers : Issue#434" on Jan 27, 2017

The change under review replaces the signature def forward(self, input, hx): with:

def forward(self, input, hx=None):
    if (hx == None):
        batch_sz = input.size()[0] if self.batch_first else input.size()[1]
        hx = torch.autograd.Variable(torch.Tensor(self.num_layers, batch_sz,
                                                  self.input_size).zero_())
        if (self.mode == 'LSTM'):
            hx = (torch.autograd.Variable(hx.data),
                  torch.autograd.Variable(hx.data))

The inline review comments on this diff were marked as off-topic.
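With this change, any of the recurrent modules can be called without an explicit hidden state. A minimal sketch of the two call styles, using the Variable API of this era; the module choice and sizes are illustrative assumptions, not taken from the PR:

import torch
from torch.autograd import Variable

rnn = torch.nn.GRU(10, 20, 2)             # input_size, hidden_size, num_layers
input = Variable(torch.randn(5, 3, 10))   # (seq_len, batch, input_size)

# Manually constructed zero hidden state: (num_layers, batch, hidden_size)
h0 = Variable(torch.zeros(2, 3, 20))
out_manual, h_manual = rnn(input, h0)

# With this PR, hx can simply be omitted and defaults to zeros
out_default, h_default = rnn(input)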

@apaszke
Contributor

apaszke commented Jan 27, 2017

One last thing. Can you please add a test that uses this change? Just instantiate one of each kind of RNN we have and pass a batch through it - once without passing the hidden state, and once with a manually constructed one. Then use self.assertEqual to compare them and make sure that it works as we want. Thanks!
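A sketch of what such a test could look like; it assumes PyTorch's test helpers in test/common.py (whose TestCase overrides assertEqual to compare tensors elementwise), and the class and method names are illustrative, not the ones actually added by this PR:

import torch
from torch.autograd import Variable
from common import TestCase, run_tests  # PyTorch test helpers (assumed import path)

class TestRNNDefaultHidden(TestCase):
    def test_default_hidden_state(self):
        for module in (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM):
            rnn = module(10, 20, 2)                  # input_size, hidden_size, num_layers
            input = Variable(torch.randn(5, 4, 10))  # (seq_len, batch, input_size)
            hx = Variable(torch.zeros(2, 4, 20))     # (num_layers, batch, hidden_size)
            if module is torch.nn.LSTM:
                hx = (hx, hx)                        # LSTM expects an (h_0, c_0) tuple
            out_manual, _ = rnn(input, hx)
            out_default, _ = rnn(input)              # hx omitted -> defaults to zeros
            self.assertEqual(out_manual, out_default)

if __name__ == '__main__':
    run_tests()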

@goelhardik
Contributor Author

I think I did a merge while trying to rebase my branch - that's why it shows the last commit 722c407. Is this okay? Should I try to revert this or just go ahead with adding the test case?

@soumith
Contributor

soumith commented Jan 29, 2017

go ahead and add the testcase. we'll squash it down before merging.

@apaszke apaszke merged commit 956d946 into pytorch:master Jan 29, 2017
@goelhardik goelhardik deleted the issue-434 branch February 20, 2017 02:24
jeffdaily pushed a commit to jeffdaily/pytorch that referenced this pull request Mar 20, 2020
mrshenli pushed a commit to mrshenli/pytorch that referenced this pull request Apr 11, 2020: "Add pruning tutorial. Will create another PR to add it into the ToC."