
torch.stack() gradient errors in 0.4.1 #9977

@fritzo


Issue description

We're seeing breakage in Pyro code caused by an interaction between torch.stack and binary_cross_entropy_with_logits, as used inside torch.distributions.Bernoulli.log_prob().

Possibly related: #4274
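As described above, Bernoulli.log_prob() is built on binary_cross_entropy_with_logits: the log-probability of a 0/1 value y under logit x is the negative of the BCE-with-logits loss. Below is a minimal plain-Python sketch of that identity. The helper names are hypothetical, and the loss uses the standard numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)), assumed here for illustration rather than copied from PyTorch's source.

```python
import math


def bce_with_logits(x, y):
    """Hypothetical helper: numerically stable binary cross-entropy for a
    scalar logit x and target y, max(x, 0) - x*y + log(1 + exp(-|x|))."""
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


def bernoulli_log_prob(x, y):
    """Hypothetical helper: log-probability of y under a Bernoulli with
    logit x, written directly as y*log(p) + (1 - y)*log(1 - p)."""
    p = 1.0 / (1.0 + math.exp(-x))
    return y * math.log(p) + (1.0 - y) * math.log(1.0 - p)


# The two agree up to sign for a range of logits and both target values.
for x in (-3.0, -0.5, 0.0, 0.5, 3.0):
    for y in (0.0, 1.0):
        assert math.isclose(bernoulli_log_prob(x, y), -bce_with_logits(x, y),
                            rel_tol=1e-9, abs_tol=1e-12)
```

The stable form avoids computing log(sigmoid(x)) directly, which would overflow or lose precision for large |x|.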

Code examples

The following examples were distilled from Pyro's tests/contrib/tracking/test_em.py:

With x in the second position of the stack:

```python
import torch
from torch.nn.functional import binary_cross_entropy_with_logits

x = torch.zeros(1, requires_grad=True)
y = torch.zeros(1)
stacked = torch.stack([y, x], -1)
loss = binary_cross_entropy_with_logits(stacked[:, 1], y, reduction='none')
g = torch.autograd.grad(loss, [x], create_graph=True)[0]
H = torch.autograd.grad(g.sum(), [x], create_graph=True)[0]
```

This raises:

RuntimeError: dim() called on undefined Tensor
With x in the first position of the stack:

```python
import torch
from torch.nn.functional import binary_cross_entropy_with_logits

x = torch.zeros(1, requires_grad=True)
y = torch.zeros(1)
stacked = torch.stack([x, y], -1)
loss = binary_cross_entropy_with_logits(stacked[:, 0], y, reduction='none')
g = torch.autograd.grad(loss, [x], create_graph=True)[0]
H = torch.autograd.grad(g.sum(), [x], create_graph=True)[0]
```

This raises a different error:

RuntimeError: Expected a Tensor of type Variable but found an undefined Tensor at position #1 for iterable argument #0 'tensors'

Both are mathematically equivalent to the following unstacked version, which works fine:

```python
import torch
from torch.nn.functional import binary_cross_entropy_with_logits

x = torch.zeros(1, requires_grad=True)
y = torch.zeros(1)
loss = binary_cross_entropy_with_logits(x, y, reduction='none')
g = torch.autograd.grad(loss, [x], create_graph=True)[0]
H = torch.autograd.grad(g.sum(), [x], create_graph=True)[0]
```
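Since the two stacked snippets are equivalent to the unstacked one, all three should yield the same g and H. For this loss they can be worked out by hand: with sigmoid(x) = 1/(1 + exp(-x)), the first derivative of the loss with respect to x is sigmoid(x) - y and the second is sigmoid(x)*(1 - sigmoid(x)), so at x = y = 0 the expected values are g = 0.5 and H = 0.25. The plain-Python sketch below (helper names hypothetical, no PyTorch needed) cross-checks these values with central finite differences.

```python
import math


def bce_with_logits(x, y):
    # Hypothetical helper: stable scalar BCE-with-logits,
    # max(x, 0) - x*y + log(1 + exp(-|x|)).
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def grad_analytic(x, y):
    # First derivative of the loss w.r.t. the logit: sigmoid(x) - y.
    return sigmoid(x) - y


def hessian_analytic(x, y):
    # Second derivative: sigmoid(x) * (1 - sigmoid(x)); it does not depend on y.
    s = sigmoid(x)
    return s * (1.0 - s)


def central_diff(f, x, h=1e-4):
    # Central finite difference of a scalar function (error is O(h**2)).
    return (f(x + h) - f(x - h)) / (2.0 * h)


x, y = 0.0, 0.0
g = grad_analytic(x, y)
H = hessian_analytic(x, y)

# Differentiate the loss numerically to check g, and the gradient to check H.
assert math.isclose(g, central_diff(lambda t: bce_with_logits(t, y), x), abs_tol=1e-6)
assert math.isclose(H, central_diff(lambda t: grad_analytic(t, y), x), abs_tol=1e-6)
print(g, H)  # 0.5 0.25
```

Comparing g.item() and H.item() from the PyTorch snippets against 0.5 and 0.25 gives a quick sanity check once they run.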

System info

PyTorch version: 0.4.1
Is debug build: No
CUDA used to build PyTorch: None

OS: Mac OSX 10.13.3
GCC version: Could not collect
CMake version: version 3.9.6

Python version: 2.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip] Could not collect
[conda] torch                     0.4.1                     <pip>
[conda] torchfile                 0.1.0                     <pip>
[conda] torchvision               0.2.1                     <pip>
