🐛 Describe the bug
Triton 3.2 will bring in a change to the software pipeliner which unifies `num_stages` behaviour with the NVIDIA backend. This means that enabling pipelining will increase shared-memory usage, causing flex attention UTs to error out on ROCm.
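For intuition, here is a rough back-of-the-envelope sketch of how the pipelined shared-memory footprint grows with `num_stages`. The tile shapes and the bytes-per-stage formula below are assumptions made purely for illustration (the real configuration is whatever Inductor picks for this kernel), but they show how a requirement like the reported 81920 bytes can end up above the 65536-byte LDS limit once more K/V tiles are kept in flight:

```python
# Illustrative only: assumed tile sizes and a simplified footprint model, not
# the actual configuration Inductor chose for the failing kernel.
BLOCK_M, BLOCK_N, HEAD_DIM = 128, 64, 64   # assumed flex attention tile sizes
BYTES = 2                                  # bfloat16
LDS_LIMIT = 64 * 1024                      # 65536-byte shared memory (LDS) limit per workgroup

q_tile = BLOCK_M * HEAD_DIM * BYTES                 # Q tile, staged once
kv_tile_per_stage = 2 * BLOCK_N * HEAD_DIM * BYTES  # one K tile + one V tile per pipeline stage

for num_stages in (1, 2, 3, 4):
    total = q_tile + num_stages * kv_tile_per_stage
    status = "fits" if total <= LDS_LIMIT else "exceeds the 64 KiB limit"
    print(f"num_stages={num_stages}: ~{total} bytes ({status})")
```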
Failing test: `inductor.test_flex_attention TestFlexAttention test_builtin_score_mods_bfloat16_score_mod4`
Error:

```
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
OutOfResources: out of resource: shared memory, Required: 81920, Hardware limit: 65536. Reducing block sizes or `num_stages` may help.
```

Full traceback:

```
Traceback (most recent call last):
  File "/tmp/pytorch/test/inductor/test_flex_attention.py", line 667, in test_builtin_score_mods
    self.run_test(score_mod, dtype)
  File "/tmp/pytorch/test/inductor/test_flex_attention.py", line 381, in run_test
    compiled_out = compiled_sdpa(q, k, v)
  File "/tmp/pytorch/torch/_dynamo/eval_frame.py", line 554, in _fn
    return fn(*args, **kwargs)
  File "/tmp/pytorch/torch/_dynamo/convert_frame.py", line 1401, in __call__
    return self._torchdynamo_orig_callable(
  File "/tmp/pytorch/torch/_dynamo/convert_frame.py", line 1184, in __call__
    result = self._inner_convert(
  File "/tmp/pytorch/torch/_dynamo/convert_frame.py", line 546, in __call__
    return _compile(
  File "/tmp/pytorch/torch/_dynamo/convert_frame.py", line 979, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/tmp/pytorch/torch/_dynamo/convert_frame.py", line 705, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/tmp/pytorch/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/tmp/pytorch/torch/_dynamo/convert_frame.py", line 740, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/tmp/pytorch/torch/_dynamo/bytecode_transformation.py", line 1337, in transform_code_object
    transformations(instructions, code_options)
  File "/tmp/pytorch/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/tmp/pytorch/torch/_dynamo/convert_frame.py", line 659, in transform
    tracer.run()
  File "/tmp/pytorch/torch/_dynamo/symbolic_convert.py", line 2909, in run
    super().run()
  File "/tmp/pytorch/torch/_dynamo/symbolic_convert.py", line 1115, in run
    while self.step():
  File "/tmp/pytorch/torch/_dynamo/symbolic_convert.py", line 1027, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/tmp/pytorch/torch/_dynamo/symbolic_convert.py", line 3100, in RETURN_VALUE
    self._return(inst)
  File "/tmp/pytorch/torch/_dynamo/symbolic_convert.py", line 3085, in _return
    self.output.compile_subgraph(
  File "/tmp/pytorch/torch/_dynamo/output_graph.py", line 1131, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/tmp/pytorch/torch/_dynamo/output_graph.py", line 1401, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/tmp/pytorch/torch/_dynamo/output_graph.py", line 1448, in call_user_compiler
    return self._call_user_compiler(gm)
  File "/tmp/pytorch/torch/_dynamo/output_graph.py", line 1497, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/tmp/pytorch/torch/_dynamo/output_graph.py", line 1478, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/tmp/pytorch/torch/_dynamo/repro/after_dynamo.py", line 130, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/tmp/pytorch/torch/__init__.py", line 2278, in __call__
    return compile_fx(model, inputs, config_patches=self.config)
  File "/tmp/pytorch/torch/_inductor/compile_fx.py", line 1686, in compile_fx
    return aot_autograd(
  File "/tmp/pytorch/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/tmp/pytorch/torch/_functorch/aot_autograd.py", line 1105, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
  File "/tmp/pytorch/torch/_functorch/aot_autograd.py", line 1081, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/tmp/pytorch/torch/_functorch/aot_autograd.py", line 528, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/tmp/pytorch/torch/_functorch/aot_autograd.py", line 780, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
  File "/tmp/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 640, in aot_dispatch_autograd
    compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
  File "/tmp/pytorch/torch/_inductor/compile_fx.py", line 1506, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
  File "/tmp/pytorch/torch/_inductor/compile_fx.py", line 1575, in _fw_compiler_base
    return inner_compile(
  File "/tmp/pytorch/torch/_inductor/compile_fx.py", line 578, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
  File "/tmp/pytorch/torch/_dynamo/repro/after_aot.py", line 100, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/tmp/pytorch/torch/_inductor/compile_fx.py", line 735, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
  File "/tmp/pytorch/torch/_inductor/codecache.py", line 1479, in load
    compiled_graph = compile_fx_fn(
  File "/tmp/pytorch/torch/_inductor/compile_fx.py", line 642, in codegen_and_compile
    compiled_graph = fx_codegen_and_compile(gm, example_inputs, **fx_kwargs)
  File "/tmp/pytorch/torch/_inductor/compile_fx.py", line 953, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
  File "/tmp/pytorch/torch/_inductor/graph.py", line 2028, in compile_to_fn
    return self.compile_to_module().call
  File "/tmp/pytorch/torch/_inductor/graph.py", line 1950, in compile_to_module
    return self._compile_to_module()
  File "/tmp/pytorch/torch/_inductor/graph.py", line 1982, in _compile_to_module
    mod = PyCodeCache.load_by_key_path(
  File "/tmp/pytorch/torch/_inductor/codecache.py", line 3025, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/tmp/pytorch/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/tmp9sfrrhe3/o6/co6gnr5cder4xexqym4tvda2raucakithc5fb5gan6zqxkwbo6if.py", line 683, in <module>
  File "/tmp/pytorch/torch/_inductor/async_compile.py", line 308, in wait
    scope[key] = result.result()
  File "/tmp/pytorch/torch/_inductor/codecache.py", line 3495, in result
    self.kernel.precompile()
  File "/tmp/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 271, in precompile
    raise e
  File "/tmp/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 265, in precompile
    compiled_binary, launcher = self._precompile_config(
  File "/tmp/pytorch/torch/_inductor/runtime/triton_heuristics.py", line 488, in _precompile_config
    binary._init_handles()
  File "/root/triton/python/triton/compiler/compiler.py", line 390, in _init_handles
    raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
OutOfResources: out of resource: shared memory, Required: 81920, Hardware limit: 65536. Reducing block sizes or `num_stages` may help.

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

To execute this test, run the following from the base repo dir:
    PYTORCH_TEST_WITH_ROCM=1 python test/inductor/test_flex_attention.py TestFlexAttention.test_builtin_score_mods_bfloat16_score_mod4

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
```
We will need to tune `num_stages` in flex attention and some other UTs to avoid this issue.
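As a stopgap at the UT or call-site level, the tile sizes (and, on versions of the lowering that forward it, `num_stages`) can be capped through `flex_attention`'s `kernel_options`. This is only a sketch of that kind of tuning: the concrete values are illustrative, the `score_mod` below is a stand-in for the built-in one used by the failing test, and whether `num_stages` is honored via `kernel_options` depends on the PyTorch version.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def half_even(score, b, h, q_idx, kv_idx):
    # Stand-in score_mod; the failing UT exercises one of the built-in score mods.
    return torch.where((q_idx + kv_idx) % 2 == 0, score, score * 0.5)

q, k, v = [
    torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
    for _ in range(3)
]

compiled = torch.compile(flex_attention)

# Smaller tiles (and fewer pipeline stages, where that key is forwarded) shrink
# the per-workgroup shared-memory footprint below the 64 KiB LDS limit.
out = compiled(
    q, k, v,
    score_mod=half_even,
    kernel_options={
        "BLOCK_M": 64,
        "BLOCK_N": 32,
        # "num_stages": 1,  # assumption: only honored if the flex attention
        #                   # lowering forwards this key to the Triton kernel
    },
)
```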
Master Tracker: #139175
Versions
PyTorch nightly, top-of-tree (ToT) Triton.
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @hongxiayang @naromero77amd @ezyang @chauhang @penguinwu