
batch_first broken in AutogradRNN #253

@jekbradbury

The last line here fails on CPU, or whenever cuDNN is otherwise unavailable:

import torch
import torch.nn as nn
from torch.autograd import Variable

l, b, t, x, h = 2, 3, 5, 10, 20  # layers, batch, time, input size, hidden size

rnn = nn.LSTM(x, h, l, batch_first=True)
inpt = Variable(torch.randn(b, t, x))
h0 = Variable(torch.randn(l, b, h))
c0 = Variable(torch.randn(l, b, h))
output, hn = rnn(inpt, (h0, c0))

This is because AutogradRNN.forward accidentally assumes Tensor's in-place transpose semantics rather than the functional semantics of Variable, whose transpose returns a new Variable (cudnn.rnn.forward gets it right):

def forward(input, weight, hidden):
    if batch_first:
        input.transpose(0, 1)   # result is discarded: a no-op on a Variable
    nexth, output = func(input, hidden, weight)
    if batch_first:
        output.transpose(0, 1)  # same problem here
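
For reference, a minimal sketch of the semantic difference (the variable names here are just for illustration): transpose on a Variable returns a new Variable and leaves the original untouched, so calling it without binding the result does nothing:

import torch
from torch.autograd import Variable

v = Variable(torch.randn(3, 5))
w = v.transpose(0, 1)  # functional: returns a new, transposed Variable
print(v.size())        # torch.Size([3, 5]) -- v is unchanged
print(w.size())        # torch.Size([5, 3])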

I can push a PR that fixes this, or one of the devs can put it in the next bugfix PR:

def forward(input, weight, hidden):
    if batch_first:
        input = input.transpose(0, 1)    # rebind: transpose is functional on Variable
    nexth, output = func(input, hidden, weight)
    if batch_first:
        output = output.transpose(0, 1)
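
With that change, the repro at the top should run without cuDNN as well; assuming the fix, output comes back batch-first with the expected shape:

output, hn = rnn(inpt, (h0, c0))
print(output.size())  # torch.Size([3, 5, 20]) -- (batch, time, hidden)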
