Closed
Labels: module: autograd, module: nn, module: tests, triaged
Description
🐛 Bug
`gradgradcheck` fails for `torch.nn.functional.cross_entropy` in an edge case involving negative class weights.
To Reproduce
```python
import torch
from torch.autograd import gradgradcheck
from torch.nn.functional import cross_entropy

device = "cpu"
torch.manual_seed(0)

input = torch.randn((1, 2), device=device, dtype=torch.float64).requires_grad_(True)
# class-index target (not probabilities)
target = torch.randint(0, 2, (1,), device=device, dtype=torch.int64)
# at least one negative class weight
weight = torch.tensor([1.0, -1.0], device=device, dtype=torch.float64)

gradgradcheck(
    lambda input, target: cross_entropy(input, target, weight=weight),
    (input, target),
)
```

```
GradcheckError: Jacobian mismatch for output 0 with respect to input 1,
numerical:tensor([[0., 0.]], dtype=torch.float64)
analytical:tensor([[ 0.8623, -0.8623]], dtype=torch.float64)
```
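For reference, the expected analytical value can be derived by hand. With a single sample and `reduction='mean'`, the class weight `w[y]` appears in both the numerator and the denominator and cancels, so grad_input = grad_output * (softmax(x) - onehot(y)) regardless of the sign of `weight`. Because grad_input is linear in grad_output, its derivative with respect to grad_output is softmax(x) - onehot(y), which for two classes has the form [a, -a]. That matches the shape of the analytical row above, assuming "input 1" in the message refers to the `grad_output` tensor that `gradgradcheck` supplies (this indexing is an assumption). The plain-Python sketch below checks the closed form against central finite differences. It does not use PyTorch, and the helpers `weighted_mean_ce`, `closed_form_grad`, and `finite_diff_grad`, as well as the logits, are hypothetical and for illustration only.

```python
import math


def log_softmax(x):
    # numerically stable log-softmax over one row of logits
    m = max(x)
    lse = m + math.log(sum(math.exp(v - m) for v in x))
    return [v - lse for v in x]


def weighted_mean_ce(xs, ys, w):
    # reduction='mean' with class-index targets:
    # sum_n w[y_n] * (-log_softmax(x_n)[y_n]) / sum_n w[y_n]
    num = sum(-w[y] * log_softmax(x)[y] for x, y in zip(xs, ys))
    den = sum(w[y] for y in ys)
    return num / den


def closed_form_grad(x, y, grad_output):
    # single sample: w[y] cancels, so d(loss)/dx = softmax(x) - onehot(y)
    p = [math.exp(v) for v in log_softmax(x)]
    return [grad_output * (p[c] - (1.0 if c == y else 0.0)) for c in range(len(x))]


def finite_diff_grad(x, y, w, grad_output, eps=1e-6):
    # central differences of grad_output * loss with respect to each logit
    grads = []
    for c in range(len(x)):
        xp, xm = list(x), list(x)
        xp[c] += eps
        xm[c] -= eps
        diff = weighted_mean_ce([xp], [y], w) - weighted_mean_ce([xm], [y], w)
        grads.append(grad_output * diff / (2 * eps))
    return grads


x, y = [0.3, -1.2], 1  # hypothetical logits and class index
for w in ([1.0, 1.0], [1.0, -1.0]):
    # finite differences match the closed form for both weight vectors
    print(w, finite_diff_grad(x, y, w, 1.0), closed_form_grad(x, y, 1.0))
```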
The failure only shows up if
- the snippet is run on the CPU,
- `target` holds class indices rather than probabilities (it does not occur with `target = torch.rand_like(input)`), and
- `weight` contains at least one negative value.
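One way to see why a negative entry in `weight` is special is to look at how `reduction='mean'` normalizes the loss, as the `CrossEntropyLoss` documentation describes it. With class-index targets, the summed loss is divided by the sum of the selected weights, sum_n w[y_n]. With probability targets, it is divided by the batch size N. The plain-Python sketch below shows that only the class-index denominator can become negative or zero once a weight is negative. This is one plausible angle on the conditions above, not a confirmed root cause, and it does not explain why only the CPU is affected. `ce_class_indices`, `ce_probabilities`, and the logits are hypothetical.

```python
import math


def log_softmax(x):
    # numerically stable log-softmax over one row of logits
    m = max(x)
    lse = m + math.log(sum(math.exp(v - m) for v in x))
    return [v - lse for v in x]


def ce_class_indices(xs, ys, w):
    # class-index targets, reduction='mean': divide by the sum of the selected weights
    num = sum(-w[y] * log_softmax(x)[y] for x, y in zip(xs, ys))
    den = sum(w[y] for y in ys)
    return num / den, den


def ce_probabilities(xs, ps, w):
    # probability targets, reduction='mean': divide by the batch size N
    total = 0.0
    for x, p in zip(xs, ps):
        lp = log_softmax(x)
        total += -sum(w[c] * lp[c] * p[c] for c in range(len(x)))
    return total / len(xs), len(xs)


w = [1.0, -1.0]
xs = [[0.3, -1.2], [0.5, 0.1]]  # hypothetical logits
for ys in ([0], [1]):
    print(ys, ce_class_indices(xs[:1], ys, w)[1])  # denominator: 1.0, then -1.0
try:
    ce_class_indices(xs, [0, 1], w)  # weights 1.0 and -1.0 sum to 0.0
except ZeroDivisionError:
    print("denominator is 0.0")
print(ce_probabilities(xs, [[0.5, 0.5], [0.2, 0.8]], w)[1])  # denominator: 2
```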
Additional Context
This was detected in #63547 while adding the OpInfo for torch.nn.functional.cross_entropy.
cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @mruberry @jbschlosser @walterddr
krshrimali