[aotd] Do not force contiguous() for channels_last #135225
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135225
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 2220725 with merge base ff2360c.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/functorch/test_aotdispatch.py (Outdated)
                memory_format=torch.channels_last
            ),
        )
        [inp.retain_grad() for inp in ret]
why are all of the retain_grad calls necessary throughout these tests?
Yeah, I am not doing grad assert checks here, so we can remove them.
Original Issue: #134644

We assume trace_tangents have the same memory_format as the inputs, outputs, and intermediates seen during the first tracing.

Tracing time:
- Store trace_tangents_memory_formats in metadata
- Coerce tangents to the deduced memory_format

Runtime:
- Coerce tangents to the tracing memory format from metadata

Testing:
```
python test/functorch/test_aotdispatch.py -k test_channels_last_grads_no_force_contiguous
```
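To make the idea concrete, here is a minimal sketch of the record-at-trace-time / coerce-at-runtime scheme. This is not the PR's actual code; the helper names are made up for illustration.

```
import torch
from torch._prims_common import suggest_memory_format

def record_tangent_memory_format(t: torch.Tensor) -> torch.memory_format:
    # Trace time: remember the memory format the backward graph was traced with.
    return suggest_memory_format(t)

def coerce_runtime_tangent(t: torch.Tensor, traced_format: torch.memory_format) -> torch.Tensor:
    # Runtime: only pay for a copy when the incoming tangent doesn't already match
    # the format we traced with (instead of unconditionally calling .contiguous()).
    if suggest_memory_format(t) != traced_format:
        t = t.contiguous(memory_format=traced_format)
    return t
```

At runtime the stored formats from metadata would be zipped with the incoming tangents and each tangent coerced as above.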
test/functorch/test_aotdispatch.py (Outdated)
        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(inp)
        self.assertEqual(ref_out, out)

    def test_channels_last_grads_no_force_contiguous_dense(self):
I would add some more tests with some outputs that are channels last, and some outputs that aren't.
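A hedged sketch of what such a mixed-output test could look like, meant as a method on the same TestCase; the test name, shapes, and assertions are illustrative, not the PR's actual test.

```
def test_mixed_memory_format_outputs(self):
    def fn(x):
        # one channels_last output and one contiguous output
        return x * 2, x.reshape(x.shape[0], -1) + 1

    def run(f):
        inp = (
            torch.randn(2, 3, 5, 5)
            .to(memory_format=torch.channels_last)
            .requires_grad_(True)
        )
        outs = f(inp)
        (outs[0].sum() + outs[1].sum()).backward()
        return outs, inp.grad

    ref_outs, ref_grad = run(fn)
    outs, grad = run(torch.compile(fn, backend="aot_eager", fullgraph=True))
    self.assertEqual(ref_outs, outs)
    self.assertEqual(ref_grad, grad)
```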
        # Today, we force this guess to be correct by additionally calling contiguous()
        # on all tangents at runtime.
        # In the future, you could imagine lifting this restriction, since these contiguous()
        # calls can have noticeable perf overhead depending on the model.
update this comment
        suggest_memory_format = torch._prims_common.suggest_memory_format
        if is_traceable_wrapper_subclass(t):
            return [
                recursive_suggest_memory_format(getattr(t, attr))
I'm a bit torn because:
(1) right now you are trying to track the memory formats of inner tensors separately from their outer subclasses. This technically isn't enough, because you need to handle nested subclasses (what if getattr(t, attr) is another tensor subclass that you need to recursively flatten and get the memory formats of its inner tensors?)
(2) you could fix that by having some recursive nested lists that contain the full hierarchy of suggested memory formats for every subclass tangent. I'm a bit worried about traversing those nested lists being bad for hot-path in the backward.
Alternatively, we could make the simplifying assumption that if a subclass advertises as "channels-last contiguous", it will likely also have channels-last-contiguous inner tensors (and vice versa), and therefore only bother to do this bookkeeping on the top-level subclass and not on the inner tensors.
I guess it is totally possible to construct a subclass that violates that, though, and the recursive tracking shouldn't be that slow. So what do you think of:
(1) do the recursive tracking
(2) run a quick microbenchmark that times the backward of a function like this, and make sure its backward isn't a lot slower
import torch
from torch.testing._internal.two_tensor import TwoTensor

def create_inp(x):
    return TwoTensor(TwoTensor(x.clone(), x.clone()), TwoTensor(x.clone(), x.clone()))

@torch.compile
def f(*args):
    return args

x = torch.randn(1, requires_grad=True)
inps = [create_inp(x) for _ in range(100)]
outs = f(*inps)
# you probably just want to measure the overhead of the `CompiledFunction.backward()` directly,
# so you don't measure overhead from the rest of autograd
sum(outs).sum().backward()
Yes, added recursive handling for subclass attributes.
Will add a test for nested subclasses to verify it.
Optimizing the recursive calls here is interesting. The most expensive part will be the Python calls; potentially we can move this DFS over subclasses into a C++ function, but we need to measure the difference.
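For illustration, a minimal sketch of the kind of recursive traversal being discussed; this is not the PR's implementation, just the shape of the idea.

```
import torch
from torch._prims_common import suggest_memory_format
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def recursive_suggest_memory_format(t):
    # Plain tensor: a single memory format.
    fmt = suggest_memory_format(t)
    if not is_traceable_wrapper_subclass(t):
        return fmt
    # Subclass: [outer_format, inner_0_format_or_nested_list, ...], recursing so
    # that nested subclasses (e.g. TwoTensor inside TwoTensor) are handled too.
    attrs, _ctx = t.__tensor_flatten__()
    return [fmt] + [recursive_suggest_memory_format(getattr(t, attr)) for attr in attrs]
```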
Benchmarked:

import time  # needed for the timing below

# `inps_fn` and `mc` are the subclass input helper and module used for this benchmark (not shown here).
benchmark_inps = [
    inps_fn(torch.randn(2, 3, 5, 5, requires_grad=True).to(memory_format=torch.channels_last))
    for _ in range(100)
]
bwd_total_duration = 0
for inps in benchmark_inps:
    outs = torch.compile(mc, backend="aot_eager", fullgraph=True)(*inps)
    s = outs[0].sum()
    time_start = time.time()
    s.backward()
    bwd_duration = time.time() - time_start
    bwd_total_duration += bwd_duration
avg_bwd_duration = bwd_total_duration / len(benchmark_inps)
print(f"XXX SUBCLASS_GROUP avg_bwd_duration:{avg_bwd_duration*1000} ms")
class M2(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x0, x1, x2, x3):
        return self.conv(x0), self.conv(x1), self.conv(x2), self.conv(x3)

m2 = M2()
m2.to(memory_format=torch.channels_last)
m2.train()

def inps_fn2(x):
    return (
        x.clone(),
        x.clone(),
        x.clone(),
        x.clone(),
    )

benchmark_inps2 = [
    inps_fn2(torch.randn(2, 3, 5, 5, requires_grad=True).to(memory_format=torch.channels_last))
    for _ in range(100)
]
bwd_total_duration = 0
for inps in benchmark_inps2:
    outs = torch.compile(m2, backend="aot_eager", fullgraph=True)(*inps)
    s = outs[0].sum()
    time_start = time.time()
    s.backward()
    bwd_duration = time.time() - time_start
    bwd_total_duration += bwd_duration
avg_bwd_duration = bwd_total_duration / len(benchmark_inps2)  # was len(benchmark_inps); both hold 100 entries
print(f"XXX NO_SUBCLASS_GROUP avg_bwd_duration:{avg_bwd_duration * 1000} ms")
XXX SUBCLASS_GROUP avg_bwd_duration:7.817392349243163 ms
XXX NO_SUBCLASS_GROUP avg_bwd_duration:3.940465450286865 ms
Subclass bwd is roughly 2x slower on average.
oh, that delta between the subclass and no-subclass runs is a good datapoint (it's also probably telling us that other things, like the subclass flatten/unflatten we do in the subclass path, add up to quite a bit of overhead in general?)
My specific question though was more about whether this PR causes any regressions for the existing subclass path. So "SUBCLASS_GROUP" (with your changes) vs. "SUBCLASS_GROUP" (without your changes)
Measured:
Without changes:
channels_last: 11.92ms
With changes:
contiguous: 10.68ms
channels_last: 8ms
No regression with the changes. I used all-contiguous() inputs to remove the difference of having contiguous() calls in the without-changes case, but I still don't fully understand why the backward speeds up in the all-contiguous case.
            torch.randn(2, 3, 5, 5, requires_grad=True).to(
                memory_format=torch.channels_last
            ),
            torch.randn(2, 3, 5, 5, requires_grad=True).to(
even better, let's make the inner two tensors have a mix of channels-first-contiguous and channels-last-contiguous, and ensure that we still don't have any runtime contiguous calls (so we "guessed all of our memory formats" properly)
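A hedged sketch of the suggested input, assuming (as the suggestion implies) that TwoTensor accepts inner tensors whose strides differ:

```
from torch.testing._internal.two_tensor import TwoTensor

# assumption: TwoTensor permits mismatched inner strides, which this suggestion relies on
base = torch.randn(2, 3, 5, 5, requires_grad=True)
mixed = TwoTensor(
    base.clone(),                                        # channels-first-contiguous inner tensor
    base.clone().to(memory_format=torch.channels_last),  # channels-last-contiguous inner tensor
)
# The test would then compile a function over `mixed` and check that no runtime
# contiguous() calls are needed for the tangents (i.e. all memory formats were guessed).
```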
        # Workaround of https://github.com/pytorch/pytorch/issues/62027
        # tensor.contiguous() guarantees to return non-zero sorted-stride,
        # tensor.to(memory_format=torch.contiguous_format) can keep zero strides.
hmm... can you talk more about why you think it's worth trying to handle the zero strides case?
The common case where I imagine zero strides showing up is when you compile a model (not including the loss), and then run .sum().backward(). The autograd engine will do something like torch.ones(1).expand(tangent_shape), and the incoming tangent input to the backward will have a zero stride.
We are still kind of out-of-luck in this case though, because the strides of the tangent will be different than the strides of the forward graph output (the tangent technically has overlapping memory, while the forward output does not).
We will have been forced to trace out a backward graph ahead-of-time that assumed that the tangent was a plain contiguous tensor, so it does kind of feel like we are forced to emit the contiguous call in the backward.
Is the zero-stride-handling here attempting to handle that case? Or something else
Copying from chat:
Inductor generates asserts for sizes and strides, and since the output strides were non-zero at trace time, it generates a non-zero-stride assert.
When tangents come in at runtime with zero strides and .to(contiguous) does not change them, that Inductor assert fails.
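For context, a small illustration of how a zero-stride tangent can show up; this illustrates the scenario described above and is not code from the PR.

```
import torch

# When the loss isn't compiled, autograd feeds the backward something like
# torch.ones(1).expand(shape) as the incoming tangent:
t = torch.ones(1).expand(4)
print(t.stride())               # (0,)  -- overlapping memory, zero stride
print(t.contiguous().stride())  # (1,)  -- contiguous() materializes real strides
# Per the workaround comment above, .to(memory_format=torch.contiguous_format)
# may keep zero strides in some cases, which is why contiguous() is used instead.
```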
    ):
        updated = False
        if not isinstance(x, Tensor):
            return x, None, updated
this fn feels a bit complicated (it returns [Tensor, MemoryFormat, bool]?) - can you add return types to the fn and document why they are needed / how they are used? (in particular it's not obvious to me why we need to return a was_updated bool)
reading this more - I think returning the was_updated bool is kind of confusing, since it is ignored in most places. What if instead, we:
- don't return a was_updated bool
- in the places where the caller needed it, they can just do the check themselves to know if contiguous had an effect:

updated_out, mem_format = coerce_tangent_and_suggest_memory_format(out)
if updated_out is not out:
    setattr(...)
Originally I had updated_out is not out, but this does not work: we call out.detach() first, so the identity check will always fail and we would always do the update.
I introduced was_updated to avoid an additional check on memory_format, which can be painful with symbolic shapes and can add guards.
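A small illustration of the point about the identity check (illustrative only):

```
import torch

out = torch.randn(2, 3, requires_grad=True) + 1
coerced = out.detach()       # detach() always returns a new Tensor object,
assert coerced is not out    # so `coerced is not out` holds even when no
                             # memory-format coercion happened afterwards.
```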
            if keep_arg_mask[m.mutated_inp_runtime_indices[i]]
        ]
        traced_tangents = filtered_inp_traced_tangents + other_traced_tangents
        assert m.traced_tangent_memory_formats is not None
side note (since I see the logic you're changing here is in the remove_dupe_metadata codepath): we are pretty sure that this is dead code in torch.compile, since dynamo has its own logic to remove dupe'd inputs.
@jamesjwu tried to kill this code a while back (but I think ran into some issues?) - we probably want to give it another shot at some point #127306
        assert m.traced_tangent_memory_formats is not None
        traced_tangent_memory_formats = [torch.contiguous_format] * len(
            filtered_inp_traced_tangents
        ) + m.traced_tangent_memory_formats[num_data_mutations:]
meh yeah i guess this is good enough (not 100% accurate, but again this should be dead code)
    def coerce_tangent(x):
    # If runtime-specified tangents do not have the same memory format as the predicted traced tangents,
    # we coerce them at runtime to the traced tangents' memory format.
    def coerce_tangent(x, memory_format=torch.contiguous_format):
hmm i don't think I understand - why do we have both coerce_tangent() and coerce_tangent_and_suggest_memory_format() floating around? It seems like your new util should subsume the old one (and we can replace call sites of the first with the second?)
I introduced a new one to avoid regressing coerce_tangent, which is also called from input_and_mutation_aliases.
coerce_tangent_and_suggest_memory_format does more: it calls suggest_memory_format and has the optional force_memory_format logic.
So I decided to keep the fast version in parallel. (It's called from input_output_analysis.)
Hmm, the only place I see it used in input_output_analysis.py is in create_synthetic_base_metadata, which is called at compile time anyway. So I don't think this fn is actually used anywhere at runtime?
    # Coercing and collecting traced tangents memory format in one recursive traversal
    # mypy: ignore-errors
    def coerce_tangent_and_suggest_memory_format(
        x: Tensor, force_memory_format: Optional[torch.memory_format] = None
hmm... do we actually pass in a non-None value for force_memory_format anywhere? I don't think I see one.
Given that we have a separate coercion function for trace-time vs runtime, the "force_memory_format" seems unnecessary? (the trace time fn never forces, the runtime fn always forces)
Yes, a None value means we need to run suggest_memory_format to deduce it; this happens for all output tangents. For the others (aliases, mutated-input tangents) it will have a value.
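Restating the two modes as a rough sketch; this is not the actual function body, only the parameter semantics being described (the sketch's name is illustrative):

```
from typing import Optional, Tuple
import torch
from torch._prims_common import suggest_memory_format

def _coerce_and_suggest_sketch(
    x: torch.Tensor,
    force_memory_format: Optional[torch.memory_format] = None,
) -> Tuple[torch.Tensor, torch.memory_format]:
    # force_memory_format is None  -> deduce it (output tangents)
    # force_memory_format is given -> coerce to that format (alias / mutated-input tangents)
    memory_format = (
        force_memory_format
        if force_memory_format is not None
        else suggest_memory_format(x)
    )
    return x.contiguous(memory_format=memory_format), memory_format
```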
        ]
        all_args = [
            (
                AOTDispatchAutograd.coerce_runtime_tangent_tracing_memory_format(
hmm why do we need to call coerce_runtime_tangent_tracing_memory_format separately on both sides of the branch above?
For the subclass case we want to call it before flattening. Potentially we can move the coercion before the subclass branch; I will try.
    # if the tangent is a subclass, traced_tangent_memory_formats[i] holds a list of memory formats,
    # containing the expected memory format of the subclass **and** all of its inner tensors
    traced_tangent_memory_formats: Optional[
        List[Union[torch.memory_format, List[torch.memory_format]]]
reading the contents of coerce_tangent_and_suggest_memory_format, I think this type is actually not accurate? It looks like when you have nested layers of tensor subclasses, the inner list is not flattened. So if I have 3 layers of wrapped TwoTensor, I'll get 3 layers of nested lists.
Is that intentional? (it might make implementing the logic cleaner). If it is, then I would just kill the typing here and explain it in a comment
Yes, this typing is wrong; what we actually have here is a recursive type:
TMF = Union[torch.memory_format, List[TMF]]
traced_tangent_memory_formats: Optional[List[TMF]]
I will try to express it with typing, or replace it with Any and put the recursive type in a comment.
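For reference, one way to spell that recursive alias; the name MemoryFormatTree is illustrative, and whether mypy accepts recursive aliases depends on the mypy version/config, so this is a sketch rather than a drop-in.

```
from typing import List, Optional, Union
import torch

# Recursive alias: a leaf memory format, or a nested list mirroring the subclass tree.
MemoryFormatTree = Union[torch.memory_format, List["MemoryFormatTree"]]

traced_tangent_memory_formats: Optional[List["MemoryFormatTree"]] = None
```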
@pytorchbot merge

Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 2 jobs have failed; the first few of them are: inductor / cuda12.1-py3.10-gcc9-sm86 / test (aot_inductor_torchbench, 1, 2, lf.linux.g5.4xlarge.nvidia.gpu), inductor-periodic / cuda12.1-py3.10-gcc9-sm80 / test (inductor_torchbench_smoketest_perf, 1, 1, linux.gcp.a100)
@pytorchbot merge

Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Original Issue: #134644
We assume trace_tangents have the same memory_format as the inputs, outputs, and intermediates seen during the first tracing.

Tracing time:
- Store trace_tangents_memory_formats in metadata
- Coerce tangents to the deduced memory_format

Runtime:
- Coerce tangents to the tracing memory format from metadata

Subclasses logic:
Previously the tangent-coercion logic did not handle the nested subclasses case; this fixes it.
For subclasses we deduce the memory format of the subclass tensor first, then of each element of the subclass:
[subclass_tensor_memory_format, subclass_tensor_elem0_memory_format, ... ]
If a subclass element (one of the __tensor_flatten__()[0] tensors) is itself a subclass, its slot holds a nested list of the same structure.
The recursive traversal of the subclass tree is expensive, so we do the memory format deduction and the coercion at the same time, keeping a single traversal (the coerce_tangent_and_suggest_memory_format method). With this approach there is no regression compared to the previous logic, which also did one traversal.

Other small change: remove a duplicated, unrelated comment.

Testing:
python test/functorch/test_aotdispatch.py -k test_channels_last_grads_no_force_contiguous

Benchmarking:
After change:
PYTORCH_AOTD_DEBUG_PROFILE=1 python test/functorch/test_aotdispatch.py -k test_benchmark_grads_no_force_contiguous
Benchmark SUBCLASS avg_bwd_duration:4.059906005859375 ms
Benchmark NO_SUBCLASS avg_bwd_duration:3.1563830375671387 ms

Before change:
BEFORE_CHANGE SUBCLASS 4.1194

No significant change in processing time.
(We do a single traversal of the subclass tree for collecting memory_formats and coercing during tracing.)