associative scan is incorrect for certain shapes/kwargs #137943

@bdhirsh


I found this repro while looking into some test failures locally (details here: #136670 (comment)).

import torch
from torch._higher_order_ops.associative_scan import associative_scan

def f(a):
    return associative_scan(lambda x, y: x + y, a, dim=1, reverse=True, combine_mode="generic")

a = torch.arange(18, dtype=torch.float32, device='cuda').reshape(2, 9)

# I got this output from running with `_fake_associative_scan` defined here:
# https://github.com/pytorch/pytorch/blob/main/test/functorch/test_control_flow.py#L1579
expected_out = torch.tensor([
    [ 36.,  36.,  35.,  33.,  30.,  26.,  21.,  15.,   8.],
    [117., 108.,  98.,  87.,  75.,  62.,  48.,  33.,  17.],
], dtype=torch.float32, device='cuda')

out = f(a)
print(torch.allclose(out, expected_out))
# prints False; `out` is:
# tensor([[ 36.,  15.,  35.,   0.,  30.,   0.,  21.,   0.,   8.],
#         [117.,  33.,  98.,  36.,  75.,  33.,  48.,  26.,  17.]],
#        device='cuda:0')
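
For anyone who wants to sanity-check the expected values without a GPU or the test helper, here is a minimal pure-Python sketch of a reverse inclusive scan for `x + y`. The `reverse_scan` helper is hypothetical (it is not part of PyTorch), and the `observed` values are copied from the output above. It only covers the commutative addition case from this repro.

```python
from itertools import accumulate
from operator import add

def reverse_scan(row, combine=add):
    # Hypothetical reference helper (not a PyTorch API):
    # inclusive scan from the right, so out[i] combines row[i:] .
    # Only intended for commutative ops such as addition.
    return list(accumulate(reversed(row), combine))[::-1]

# Same data as torch.arange(18, dtype=torch.float32).reshape(2, 9)
rows = [[float(v) for v in range(r * 9, (r + 1) * 9)] for r in range(2)]
reference = [reverse_scan(row) for row in rows]

# Values copied from the incorrect output reported in this issue
observed = [
    [36.0, 15.0, 35.0, 0.0, 30.0, 0.0, 21.0, 0.0, 8.0],
    [117.0, 33.0, 98.0, 36.0, 75.0, 33.0, 48.0, 26.0, 17.0],
]

# Column indices where the observed output disagrees with the reference
mismatches = [
    [i for i, (o, e) in enumerate(zip(obs, ref)) if o != e]
    for obs, ref in zip(observed, reference)
]

for row in reference:
    print(row)
print(mismatches)
```

In this sketch the reference rows match `expected_out`, and the mismatches in both rows fall on the odd column indices (1, 3, 5, 7), which may help narrow down where the scan goes wrong.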

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @ydwu4 @yf225
