Closed
Labels: high priority, module: higher order operators, oncall: pt2, triaged
Description
I found this repro after noticing some test failures locally (details here: #136670 (comment)):
```python
import torch
from torch._higher_order_ops.associative_scan import associative_scan

def f(a):
    return associative_scan(lambda x, y: x + y, a, dim=1, reverse=True, combine_mode="generic")

a = torch.arange(18, dtype=torch.float32, device='cuda').reshape(2, 9)

# I got this expected output by running with `_fake_associative_scan` defined here:
# https://github.com/pytorch/pytorch/blob/main/test/functorch/test_control_flow.py#L1579
expected_out = torch.tensor([
    [ 36.,  36.,  35.,  33.,  30.,  26.,  21.,  15.,   8.],
    [117., 108.,  98.,  87.,  75.,  62.,  48.,  33.,  17.],
], dtype=torch.float32, device='cuda')

out = f(a)
print(torch.allclose(out, expected_out))
# prints False; `out` is:
# tensor([[ 36.,  15.,  35.,   0.,  30.,   0.,  21.,   0.,   8.],
#         [117.,  33.,  98.,  36.,  75.,  33.,  48.,  26.,  17.]],
#        device='cuda:0')
```
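As an independent cross-check of `expected_out` that needs neither CUDA nor torch, here is a minimal pure-Python sketch of a sequential reverse inclusive scan along each row. `reverse_inclusive_scan` is a hypothetical helper, not part of torch; it assumes `reverse=True` means `out[i]` combines `row[i], row[i+1], ..., row[-1]`. Comparing its result with the output reported above shows that only the odd column indices are wrong.

```python
from typing import Callable, List, Optional, Sequence


def reverse_inclusive_scan(
    row: Sequence[float], combine: Callable[[float, float], float]
) -> List[float]:
    """Hypothetical reference: out[i] = combine(row[i], out[i + 1]), scanning right to left.

    The argument order for non-commutative operators is an assumption;
    it does not matter for addition, which is what the repro uses.
    """
    out = [0.0] * len(row)
    acc: Optional[float] = None
    for i in range(len(row) - 1, -1, -1):
        acc = row[i] if acc is None else combine(row[i], acc)
        out[i] = acc
    return out


# Same input as the repro, torch.arange(18).reshape(2, 9), as plain lists.
a = [[float(r * 9 + c) for c in range(9)] for r in range(2)]
expected = [reverse_inclusive_scan(row, lambda x, y: x + y) for row in a]

# Output reported in this issue for combine_mode="generic" on CUDA.
observed = [
    [36., 15., 35., 0., 30., 0., 21., 0., 8.],
    [117., 33., 98., 36., 75., 33., 48., 26., 17.],
]

# Column indices where the reported output disagrees with the reference, per row.
mismatches = [
    [c for c, (e, o) in enumerate(zip(exp_row, obs_row)) if e != o]
    for exp_row, obs_row in zip(expected, observed)
]
print(expected)
print(mismatches)  # → [[1, 3, 5, 7], [1, 3, 5, 7]]
```

For the addition case, the same reference can be computed on the tensor with existing torch ops, e.g. `torch.flip(torch.cumsum(torch.flip(a, dims=[1]), dim=1), dims=[1])`, which avoids depending on `_fake_associative_scan`.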
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @ydwu4 @yf225