Conversation

@zhxchen17 (Contributor) commented Jan 10, 2023

Fixes #ISSUE_NUMBER

We want to add support for the control flow map() operator at the dynamo level, to unblock some internal models that will have to use the map() operator in the captured graph. Basically, I replicate the pattern used to implement the cond() op in #90286.
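For context, a minimal sketch of the kind of program this is meant to capture. The import path for the map operator and the compile entry point here are assumptions for illustration, not part of this PR:

```python
import torch
import torch._dynamo as dynamo
from functorch.experimental import control_flow  # assumed import path

def body(x, y):
    # Called on one slice of xs (i.e. xs[i]) at a time.
    return x + y

def f(xs, y):
    # map() applies `body` to xs[0], xs[1], ... and stacks the results.
    return control_flow.map(body, xs, y)

xs = torch.randn(3, 4)
y = torch.randn(4)
out = dynamo.optimize("eager")(f)(xs, y)  # map() should stay in the captured graph
```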

cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

pytorch-bot bot commented Jan 10, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91939

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3cfcad6:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zhxchen17 zhxchen17 changed the title from "[dynamo] Support map() operator." to "[dynamo] Support control flow map() operator." Jan 10, 2023
@ezyang ezyang requested a review from zou3519 January 10, 2023 22:16
tx.output.register_attr_or_module(gm, next_name, source=src)
return next_name

def make_subgraph(f, sub_args, graph_checkpoint, checkpoint):
Review comment (Contributor):

This should have "speculate" in its name as it doesn't actually apply the changes to Dynamo state

body_nn_modules,
body_cmp,
) = make_subgraph(args[0], [
wrap_fx_proxy(tx, args[1].as_proxy()[0], **VariableTracker.propagate(args[1])),
Review comment (Contributor):

This looks suspicious. What's going on here?

"body", torch.fx.GraphModule(body_nn_modules, body_graph)
)

# Apply side effects (guaranteed to be equal)
Review comment (Contributor):

This comment is out of date; you only ran one branch.

*(arg.as_proxy() for arg in args[1:])
)
r = body_r.as_proxy().node.meta["example_value"]
example_value = r.new_empty([get_fake_value(args[1].as_proxy().node, tx).shape[0], *r.shape])
Review comment (Contributor):

This looks wrong.

@ezyang (Contributor) left a comment:

Broadly speaking, the implementation here looks like it's cargo culted from cond's implementation, but this is not really appropriate, because map is quite different from cond in two respects:

  1. There's only one lambda, but it can get run multiple times
  2. The input/output of the lambda diverge from the outer context

(1) means that you need to, for example, assert that there are no side effects in the lambda (or that the side effects are idempotent or something). Suppose that inside the lambda you call nonlocal x; x += 1. The eager mode map semantics will increment this counter every iteration of the loop. But what you have implemented here applies the side effects once, and then is done. You will have the wrong semantics in this case.

(2) appears to be what is going on with the weirdness of making an fx proxy tensor and then the stuff with new_empty to make the example value. I can believe that the implementation as is works in most cases, but it needs to be documented far better. In particular, when you say x[0] you are leaning on the fact that make_subgraph doesn't actually use the passed-in proxies in an interesting way; instead, it converts them into nodes. Otherwise, it would be incorrect to say that you had called the lambda with x[0]. In fact, I think I might prefer that we not pass in a proxy (except maybe to make the names nicer), because the implementation here will not work if you map over a tensor with size (0, *sizes) (since x[0] will fail in this case).

I will expect tests for all of these edge cases.
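To make point (1) concrete, a small sketch of a loop body with a side effect (reusing the illustrative control_flow.map import from the sketch in the PR description above):

```python
def f(xs):
    count = 0

    def body(x):
        nonlocal count
        count += 1          # eager map() runs this once per slice of xs
        return x * count

    out = control_flow.map(body, xs)
    # Eager semantics: count == xs.shape[0] here. A capture that speculates
    # `body` a single time and applies its side effects once would bake in
    # count == 1 and silently compute the wrong thing.
    return out, count
```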

@zhxchen17 (Contributor, Author) commented Jan 12, 2023

@ezyang All of these make a lot of sense, thanks!
So my basic assumption here is that we are not going to support loop bodies that have side effects, because we haven't seen real use cases anyway, but it's true that I need to check for those.
Just want to make sure I understand point (2) correctly:
Instead of passing x[0] and calling the lambda, we could just construct a new sample value which is not related to the current execution context, and then trace the lambda (assuming it works better for the (0, *sizes) shape)?
In this case, what could we do to provide an example output for the whole torch.map()? I guess we still need to construct the example value based on the example value of the inner body graph.
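(For concreteness, the kind of example-value construction being discussed, as a rough sketch with hypothetical variable names: the map() result is conceptually a stack of per-slice body outputs.)

```python
# body_example: example/fake output of a single body invocation
# xs_fake:      fake value of the tensor being mapped over
example_value = body_example.new_empty(
    [xs_fake.shape[0], *body_example.shape]
)
```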

@ezyang (Contributor) commented Jan 12, 2023

It's easy to say "oh, I will just check that there are no side effects in the loop body", but Dynamo does model some operations (like accessing closed-over variables) as side effects. In any case, try asserting no side effects and see if it errors or not on the models you care about.

Re (2), you understand correctly. You actually raise a good point though, which is that zero-size input cannot work anyway, as you MUST run the lambda to actually get the output shape, but the lambda is unrunnable if you have no samples. So I guess map() as defined here cannot work with zero size, and so maybe accessing the first element is fine. Better add a check for that though...
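A rough sketch of that guard inside the map() handler (args and tx as in the diff above; the import paths and exact error message are assumptions):

```python
from torch._dynamo.exc import unimplemented
from torch._dynamo.utils import get_fake_value

sample = get_fake_value(args[1].as_proxy().node, tx)
if sample.ndim == 0 or sample.shape[0] == 0:
    # We cannot trace the body without at least one element to feed it.
    unimplemented("map() over a scalar or a zero-sized dim-0 tensor")
first = sample[0]  # safe now: there is a sample to speculate the body with
```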

@zhxchen17 zhxchen17 force-pushed the zhxchen17/control_flow/2 branch 2 times, most recently from 43707dc to 2c42b59, January 18, 2023 21:45
@zhxchen17 (Contributor, Author) commented:

Sorry for the late update. Addressed comments with a few updates:

  1. Check for scalar / zero-sized tensors for map() during tracing. Added a unit test case.
  2. Check for extra pending side effects from calling the map() body; currently we just throw an unsupported error if the map() body has any side effects (a sketch of the shape of this check follows below). Added a unit test case.
  3. Nits: added/removed comments about sample inputs for map(); renamed make_subgraph to speculate_subgraph.
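The side-effect check in (2) roughly has this shape; snapshot_side_effects is a hypothetical helper and the error message is illustrative, not the actual dynamo API:

```python
before = snapshot_side_effects(tx)  # hypothetical: capture pending side effects
body_r = speculate_subgraph(f, sub_args, graph_checkpoint, checkpoint)
if snapshot_side_effects(tx) != before:
    unimplemented("map() body with side effects is not supported")
```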

cc @ezyang

@zhxchen17 zhxchen17 requested a review from ezyang January 18, 2023 21:49
@zhxchen17 zhxchen17 added the "topic: not user facing" label Jan 18, 2023
Review comment (Contributor):

I hope this access doesn't induce a write to the FX graph

Review comment (Contributor):

The rest of body_cmp is ignored here. Maybe it is safer to factor out the comparable state into its own function; then you can get the comparable state prior to running the lambda and do a full comparison there.
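A generic illustration of that pattern; the fields and helper names here are invented for the sketch and do not correspond to real dynamo internals:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ComparableState:
    # Only the pieces of tracer state that speculating the body must not change.
    guards: frozenset
    module_names: tuple

def comparable_state(tracer):
    return ComparableState(
        guards=frozenset(tracer.guards),
        module_names=tuple(sorted(tracer.modules)),
    )

def checked_speculate(tracer, speculate_body):
    before = comparable_state(tracer)
    result = speculate_body()
    assert comparable_state(tracer) == before, (
        "speculating the body must not change the outer tracing state"
    )
    return result
```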

@zhxchen17 zhxchen17 force-pushed the zhxchen17/control_flow/2 branch from 2c42b59 to 3cfcad6, January 19, 2023 08:10
@zhxchen17 (Contributor, Author) commented:

Updates:

  • Construct a TensorVariable with an inner-graph proxy directly for the sample input xs[0], so that we don't insert an extra getitem node into the parent graph.
  • Factor out the comparable state from the original graph state, and do a full comparison between the original graph state and the loop-body graph state.

@zhxchen17 (Contributor, Author) commented:

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) Jan 19, 2023
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator) commented:

The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.

@zhxchen17 (Contributor, Author) commented:

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: Not merging any PRs at the moment because there is a merge-blocking ci: sev issue (https://github.com/pytorch/pytorch/labels/ci:%20sev) open at #92626.

@zhxchen17 (Contributor, Author) commented:

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

@github-actions github-actions bot deleted the zhxchen17/control_flow/2 branch July 20, 2024 01:53