Skip to content

Commit b20b51c

Browse files
yushangdi
authored and facebook-github-bot committed
Fix too big to optimize in test, actually use O0 when aot_inductor.compile_wrapper_with_O0 is set (#148714)
Summary: 1. Check against the "0" char instead. 2. We got the following error when using anything other than the O0 flag: `error: Function _ZN5torch12aot_inductorL22__check_inputs_outputsEPP16AtenTensorOpaqueS3_ is too big to optimize [-Werror,-Wignored-optimization-argument]`. So we use the O0 flag in wrapper code when `aot_inductor.compile_wrapper_opt_level` is set to `O0`. Test Plan: ``` buck run 'fbcode//mode/opt' fbcode//deeplearning/aot_inductor/cpu/test:ads_second_stage_dsnn_models_aoti_lowering_test -- -r AdsSecondStageDSNNModelsAOTILoweringTest ``` Reviewed By: desertfire Differential Revision: D70670957
1 parent 1e37e5b commit b20b51c

File tree

5 files changed

+8
-45
lines changed

5 files changed

+8
-45
lines changed

test/inductor/test_aot_inductor.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -139,24 +139,6 @@ def forward(self, x, y):
139139
model, example_inputs, "AOTInductorModelRunMinimalArrayrefInterface(", 1
140140
)
141141

142-
def test_compile_wrapper_with_O0(self):
143-
class Model(torch.nn.Module):
144-
def __init__(self) -> None:
145-
super().__init__()
146-
self.linear = torch.nn.Linear(10, 10)
147-
148-
def forward(self, x, y):
149-
return x + self.linear(y)
150-
151-
example_inputs = (
152-
torch.randn(10, 10, device=self.device),
153-
torch.randn(10, 10, device=self.device),
154-
)
155-
model = Model()
156-
with config.patch("aot_inductor.compile_wrapper_with_O0", True):
157-
self.check_model(model, example_inputs)
158-
self.code_check_count(model, example_inputs, "__attribute__((", 2)
159-
160142
def test_small_constant(self):
161143
class Model(torch.nn.Module):
162144
def __init__(self) -> None:

torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def gen_check(handle_kind, idx, name, tensor):
407407
"""
408408
bool _check_aoti_runtime_check_inputs_env() {
409409
const static char* env_var_value = getenv("AOTI_RUNTIME_CHECK_INPUTS");
410-
const static bool result = env_var_value != nullptr && env_var_value[0] != 0;
410+
const static bool result = env_var_value != nullptr && env_var_value[0] != '0';
411411
return result;
412412
}
413413
@@ -461,17 +461,7 @@ def write_wrapper_decl(self):
461461
"""
462462
)
463463

464-
run_impl_proto = ""
465-
if config.aot_inductor.compile_wrapper_with_O0:
466-
run_impl_proto += """
467-
#ifdef __clang__
468-
__attribute__((optnone))
469-
#else
470-
__attribute__((optimize("O0")))
471-
#endif
472-
"""
473-
474-
run_impl_proto += """
464+
run_impl_proto = """
475465
void AOTInductorModel::run_impl(
476466
AtenTensorHandle*
477467
input_handles, // array of input AtenTensorHandle; handles

torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -189,17 +189,7 @@ def write_wrapper_decl(self):
189189
"""
190190
)
191191

192-
run_impl_proto = ""
193-
if config.aot_inductor.compile_wrapper_with_O0:
194-
run_impl_proto += """
195-
#ifdef __clang__
196-
__attribute__((optnone))
197-
#else
198-
__attribute__((optimize("O0")))
199-
#endif
200-
"""
201-
202-
run_impl_proto += """
192+
run_impl_proto = """
203193
void AOTInductorModel::run_impl(
204194
AtenTensorHandle*
205195
input_handles, // array of input AtenTensorHandle; handles

torch/_inductor/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,9 +1168,9 @@ class aot_inductor:
11681168
debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1"
11691169

11701170
# Annotate generated main wrapper function, i.e. AOTInductorModel::run_impl,
1171-
# to skip cpp compiler optimizations for faster compilation.
1172-
compile_wrapper_with_O0 = (
1173-
os.environ.get("AOT_INDUCTOR_COMPILE_WRAPPER_WITH_O0", "0") == "1"
1171+
# to use which cpp compiler optimization level, default to O1
1172+
compile_wrapper_opt_level = os.environ.get(
1173+
"AOT_INDUCTOR_COMPILE_WRAPPER_OPT_LEVEL", "O1"
11741174
)
11751175

11761176
# option for debug printing/saving for intermediate tensor values for aot inductor

torch/_inductor/cpp_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,10 +542,11 @@ def _get_optimization_cflags(
542542
if _IS_WINDOWS:
543543
return ["O1" if min_optimize else "O2"]
544544
else:
545+
wrapper_opt_level = config.aot_inductor.compile_wrapper_opt_level
545546
cflags = (
546547
["O0", "g"]
547548
if config.aot_inductor.debug_compile
548-
else ["O1" if min_optimize else "O3", "DNDEBUG"]
549+
else [wrapper_opt_level if min_optimize else "O3", "DNDEBUG"]
549550
)
550551
cflags += _get_ffast_math_flags()
551552
cflags.append("fno-finite-math-only")

0 commit comments

Comments (0)