Skip to content

Conversation

@bohnstingl
Copy link
Collaborator

@bohnstingl bohnstingl commented Sep 29, 2024

This is part of a series of PRs to improve the functionality of the associatve_scan functionality. This specific PR implements the Autograd for associative_scan. This PR has been derived from #129307.

@ydwu4

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @rec

@bohnstingl bohnstingl requested a review from zou3519 as a code owner September 29, 2024 22:42
@pytorch-bot
Copy link

pytorch-bot bot commented Sep 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136966

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 15 New Failures

As of commit fc19798 with merge base a16476b (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zou3519 zou3519 requested review from ydwu4 and removed request for zou3519 September 30, 2024 19:48
@zou3519 zou3519 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Sep 30, 2024
@bohnstingl bohnstingl force-pushed the generic_associative_scan_6 branch from a9366e6 to 3ab2330 Compare October 2, 2024 08:37
@bohnstingl
Copy link
Collaborator Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Oct 17, 2024
@bohnstingl bohnstingl force-pushed the generic_associative_scan_6 branch from 15de9c4 to 8d669c1 Compare October 18, 2024 17:47
@bohnstingl bohnstingl force-pushed the generic_associative_scan_6 branch from 8d669c1 to fc19798 Compare October 22, 2024 23:11

results = [
torch.stack([e[leave_ind] for e in op(result_flat)], dim)
torch.concatenate([e[leave_ind] for e in op(result_flat)], dim)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it a concatenate? if we associative_scan over (4, 2, 3) over dim=0, each subgraph should work on a slice of (2, 3), and the end results should be of shape (4, 2, 3). Anything wrong with this interface? After the change, does subgraph takes (1, 2, 3) or result becomes (8, 3).

Copy link
Contributor

@ydwu4 ydwu4 Oct 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, this one should be deleted i feel? _fake_scan is now in _higher_order_ops/scan.py.


def arg_extractor(combine_fn, xs, dim):
return combine_fn, xs, dim
def arg_extractor(combine_fn, xs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i feel we should split this diff into 2, first is the the dim change, then the autograd.

dim = utils.canonicalize_dim(ndim, dim)
# Move scan dim to 0 and always perform scan on dim 0
orig_scan_dim = dim
leaves = [shift_source_dim_to_target_dim(elem, int(dim), 0) for elem in leaves]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we might replace shift_source_dim_to_target_dim with torch.movedim(elem, int(dim), 0).

result_flat = [torch.flip(elem, [0]) for elem in result_flat]

result_flat = [
shift_source_dim_to_target_dim(elem, 0, orig_scan_dim) for elem in result_flat
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shift_source_dim_to_target_dim -> movedim

return pytree.tree_unflatten(result_flat, spec)


# TODO: Provide inductor support for generic scan
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's missing in inductor for generic scan. Is it the test failure we talked about?

return (*outs,)

@staticmethod
def backward(ctx, *flat_grads_unmasked):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I trust you on this.

Didn't look into the details of the backward implementation. Some general thoughts for better testing this: can we use scan to implement a baseline version first and add much more tests to verify the correctness (e.g. nesting cond, scan and associative scan with autograd, more types of ops inside the body of associative scan(e.g. different kinds of view ops, non-continous inputs and outputs.)

pytorchmergebot pushed a commit that referenced this pull request Nov 5, 2024
In this PR, the combine_fn is consistently called with a slice along the scan dim. It implements part of #136966

Pull Request resolved: #138858
Approved by: https://github.com/ydwu4
@bohnstingl
Copy link
Collaborator Author

Closing this PR, as it is split into several smaller PRs: #138858, #139864, #139939

@bohnstingl bohnstingl closed this Nov 6, 2024
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
In this PR, the combine_fn is consistently called with a slice along the scan dim. It implements part of pytorch#136966

Pull Request resolved: pytorch#138858
Approved by: https://github.com/ydwu4
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

module: dynamo module: inductor open source topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants