@@ -187,8 +187,7 @@ def __init__(
187187 configs ,
188188 save_cache_hook ,
189189 mutated_arg_names : List [str ], # see [Note: clone mutated buffers]
190- is_inference ,
191- is_backward ,
190+ optimize_mem ,
192191 heuristic_type ,
193192 size_hints = None ,
194193 inductor_meta = None , # metadata not relevant to triton
@@ -212,8 +211,7 @@ def __init__(
212211 self .inductor_meta = {} if inductor_meta is None else inductor_meta
213212 self .save_cache_hook = save_cache_hook
214213 self .mutated_arg_names = mutated_arg_names
215- self .is_inference = is_inference
216- self .is_backward = is_backward
214+ self .optimize_mem = optimize_mem
217215 self .configs = configs
218216 self .heuristic_type = heuristic_type
219217 self .custom_kernel = custom_kernel
@@ -718,12 +716,8 @@ def copy_args_to_cpu_if_needed(self, *args, **kwargs):
718716 If those clones would increase the peak memory usage, however, we instead
719717 copy to cpu and restore them after each iteration. Figure out the args
720718 to be copied and do the copying.
721-
722- Skip this optimization for forward of the training loop where we expect
723- every new node will increase the peak memory and our greedy approach
724- would introduce a lot of unnecessary cpu copies.
725719 """
726- if not self .is_inference and not self . is_backward :
720+ if not self .optimize_mem :
727721 return {}
728722
729723 copies = {}
@@ -1132,8 +1126,7 @@ def cached_autotune(
11321126 log .debug ("autotune caching is disabled by config.force_disable_caches" )
11331127
11341128 mutated_arg_names = inductor_meta .pop ("mutated_arg_names" , ())
1135- is_inference = inductor_meta .pop ("is_inference" , False )
1136- is_backward = inductor_meta .pop ("is_backward" , False )
1129+ optimize_mem = inductor_meta .pop ("optimize_mem" , True )
11371130
11381131 def decorator (fn ):
11391132 # Remove XBLOCK from config if it's not a function argument.
@@ -1160,8 +1153,7 @@ def decorator(fn):
11601153 configs = configs ,
11611154 save_cache_hook = autotune_cache and autotune_cache .save ,
11621155 mutated_arg_names = mutated_arg_names ,
1163- is_inference = is_inference ,
1164- is_backward = is_backward ,
1156+ optimize_mem = optimize_mem ,
11651157 heuristic_type = heuristic_type ,
11661158 size_hints = size_hints ,
11671159 custom_kernel = custom_kernel ,
@@ -1174,8 +1166,7 @@ def decorator(fn):
11741166 configs = configs ,
11751167 save_cache_hook = autotune_cache and autotune_cache .save ,
11761168 mutated_arg_names = mutated_arg_names ,
1177- is_inference = is_inference ,
1178- is_backward = is_backward ,
1169+ optimize_mem = optimize_mem ,
11791170 heuristic_type = heuristic_type ,
11801171 size_hints = size_hints ,
11811172 custom_kernel = custom_kernel ,
0 commit comments