Commit e987d75

Update on "[inductor] Conditionally copy args to cpu to minimize memory overhead of autotuning"
[ghstack-poisoned]
2 parents 7f0be50 + b198f97

4 files changed: +14 -21 lines

test/inductor/test_cuda_repro.py

Lines changed: 1 addition & 2 deletions
@@ -422,8 +422,7 @@ def decorator(fn):
 configs=configs,
 save_cache_hook=False,
 mutated_arg_names=["in_out_ptr0"],
-is_inference=True,
-is_backward=False,
+optimize_mem=True,
 heuristic_type=HeuristicType.POINTWISE,
 )

test/inductor/test_triton_heuristics.py

Lines changed: 1 addition & 2 deletions
@@ -126,8 +126,7 @@ def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
 "configs": configs,
 "save_cache_hook": False,
 "mutated_arg_names": [],
-"is_inference": True,
-"is_backward": False,
+"optimize_mem": True,
 "heuristic_type": HeuristicType.POINTWISE,
 "inductor_meta": inductor_meta,
 }

torch/_inductor/codegen/triton.py

Lines changed: 6 additions & 2 deletions
@@ -2750,12 +2750,16 @@ def codegen_kernel(self, name=None):
 "constants": {},
 }

+# Skip memory optimization for forward of the training loop where we expect
+# every new node will increase the peak memory and our greedy approach would
+# introduce a lot of unnecessary cpu copies.
+optimize_mem = V.graph.is_inference or V.graph.is_backward
+
 inductor_meta = {
 "autotune_hints": set(self.autotune_hints),
 "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
 "mutated_arg_names": mutated_args,
-"is_inference": V.graph.is_inference,
-"is_backward": V.graph.is_backward,
+"optimize_mem": optimize_mem,
 "no_x_dim": self.no_x_dim,
 "num_load": self.num_load,
 "num_reduction": self.num_reduction,

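For context, a minimal sketch of the decision this hunk centralizes: the two graph flags are collapsed into a single optimize_mem value, computed once at codegen time and threaded through inductor_meta. The helper name below is hypothetical; the graph attributes follow the diff above.

# Illustrative sketch only -- not the actual inductor code path.
# `graph` stands in for V.graph; the attribute names match the diff above.
def should_optimize_mem(graph) -> bool:
    # CPU copies pay off for inference and backward graphs; in the forward of a
    # training loop every new node tends to raise peak memory, so the greedy
    # copy heuristic would only add overhead.
    return graph.is_inference or graph.is_backward

As the triton_heuristics.py hunk below shows, cached_autotune later pops the flag from inductor_meta with a default of True, so callers that never set it keep the CPU-copy path enabled.
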
torch/_inductor/runtime/triton_heuristics.py

Lines changed: 6 additions & 15 deletions
@@ -187,8 +187,7 @@ def __init__(
 configs,
 save_cache_hook,
 mutated_arg_names: List[str],  # see [Note: clone mutated buffers]
-is_inference,
-is_backward,
+optimize_mem,
 heuristic_type,
 size_hints=None,
 inductor_meta=None,  # metadata not relevant to triton
@@ -212,8 +211,7 @@ def __init__(
 self.inductor_meta = {} if inductor_meta is None else inductor_meta
 self.save_cache_hook = save_cache_hook
 self.mutated_arg_names = mutated_arg_names
-self.is_inference = is_inference
-self.is_backward = is_backward
+self.optimize_mem = optimize_mem
 self.configs = configs
 self.heuristic_type = heuristic_type
 self.custom_kernel = custom_kernel
@@ -718,12 +716,8 @@ def copy_args_to_cpu_if_needed(self, *args, **kwargs):
 If those clones would increase the peak memory usage, however, we instead
 copy to cpu and restore them after each iteratrion. Figure out the args
 to be copied and do the copying.
-
-Skip this optimization for forward of the training loop where we expect
-every new node will increase the peak memory and our greedy approach
-would introduce a lot of unnecessary cpu copies.
 """
-if not self.is_inference and not self.is_backward:
+if not self.optimize_mem:
 return {}

 copies = {}
@@ -1132,8 +1126,7 @@ def cached_autotune(
 log.debug("autotune caching is disabled by config.force_disable_caches")

 mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())
-is_inference = inductor_meta.pop("is_inference", False)
-is_backward = inductor_meta.pop("is_backward", False)
+optimize_mem = inductor_meta.pop("optimize_mem", True)

 def decorator(fn):
 # Remove XBLOCK from config if it's not a function argument.
@@ -1160,8 +1153,7 @@ def decorator(fn):
 configs=configs,
 save_cache_hook=autotune_cache and autotune_cache.save,
 mutated_arg_names=mutated_arg_names,
-is_inference=is_inference,
-is_backward=is_backward,
+optimize_mem=optimize_mem,
 heuristic_type=heuristic_type,
 size_hints=size_hints,
 custom_kernel=custom_kernel,
@@ -1174,8 +1166,7 @@ def decorator(fn):
 configs=configs,
 save_cache_hook=autotune_cache and autotune_cache.save,
 mutated_arg_names=mutated_arg_names,
-is_inference=is_inference,
-is_backward=is_backward,
+optimize_mem=optimize_mem,
 heuristic_type=heuristic_type,
 size_hints=size_hints,
 custom_kernel=custom_kernel,
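The docstring above explains the idea that optimize_mem gates: mutated buffers are normally cloned on the GPU, but when those clones would raise peak memory they are backed up to CPU and restored after each autotuning iteration. Below is a minimal sketch of that pattern, not the actual CachingAutotuner code; benchmark_one, mutated_idx, and autotune_with_cpu_backup are hypothetical stand-ins.

# Minimal sketch of the copy-to-cpu-and-restore pattern that `optimize_mem`
# gates during autotuning. Not the CachingAutotuner implementation.
from typing import Callable, Dict, List, Sequence

import torch

def autotune_with_cpu_backup(
    benchmark_one: Callable[[object], float],  # times the kernel under one config
    args: List[torch.Tensor],
    mutated_idx: Sequence[int],  # positions of in/out buffers the kernel mutates
    configs: Sequence[object],
    optimize_mem: bool = True,
) -> Dict[object, float]:
    # Back up mutated GPU buffers to CPU instead of cloning them on the GPU,
    # so autotuning does not raise peak device memory.
    backups: Dict[int, torch.Tensor] = (
        {i: args[i].cpu() for i in mutated_idx} if optimize_mem else {}
    )
    timings: Dict[object, float] = {}
    for cfg in configs:
        timings[cfg] = benchmark_one(cfg)
        # Restore mutated buffers so every config benchmarks the same inputs.
        for i, cpu_copy in backups.items():
            args[i].copy_(cpu_copy)
    return timings

Per the docstring, the real tuner only falls back to CPU copies when GPU clones would increase peak memory; the sketch above elides that check for brevity.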
