Closed
Labels: module: dynamic shapes, module: inductor, oncall: pt2, triaged
Description
With this fix, compilation gets further, but it breaks down the line with `dynamic=True`, producing the following error:
[2023-03-06 21:46:13,516] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing generate
[2023-03-06 21:46:13,761] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _validate_model_class
[2023-03-06 21:46:13,766] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing can_generate
[2023-03-06 21:46:13,770] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing __repr__
[2023-03-06 21:46:13,783] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing __getitem__
[2023-03-06 21:46:13,792] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _get_abs_string_index
[2023-03-06 21:46:13,794] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing __len__
[2023-03-06 21:46:13,801] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in __repr__>
[2023-03-06 21:46:13,843] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in __repr__>
[2023-03-06 21:46:13,847] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing __repr__
[2023-03-06 21:46:13,860] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in __repr__>
[2023-03-06 21:46:13,894] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in _validate_model_class>
[2023-03-06 21:46:13,905] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in generate>
[2023-03-06 21:46:13,972] torch._dynamo.convert_frame: [INFO] converting frame raised unsupported, leaving it unconverted
[2023-03-06 21:46:13,972] torch._dynamo.convert_frame: [INFO] converting frame raised unsupported, leaving it unconverted
[2023-03-06 21:46:13,985] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in generate>
[2023-03-06 21:46:14,163] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _validate_model_kwargs
[2023-03-06 21:46:14,173] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _prepare_model_inputs
[2023-03-06 21:46:14,183] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _can_retrieve_inputs_from_name
[2023-03-06 21:46:14,187] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _prepare_attention_mask_for_generation
[2023-03-06 21:46:14,205] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function inductor
[2023-03-06 21:46:15,921] torch._dynamo.output_graph: [INFO] Step 2: done compiler function inductor
[2023-03-06 21:46:15,941] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _prepare_encoder_decoder_kwargs_for_generation
[2023-03-06 21:46:22,830] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing _prepare_encoder_decoder_kwargs_for_generation (RETURN_VALUE)
[2023-03-06 21:46:22,938] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function inductor
[2023-03-06 21:46:37,143] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 1
[2023-03-06 21:46:42,955] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 1
[2023-03-06 21:46:42,956] torch._dynamo.output_graph: [INFO] Step 2: done compiler function inductor
[2023-03-06 21:46:45,283] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _prepare_decoder_input_ids_for_generation
[2023-03-06 21:46:45,292] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing _prepare_decoder_input_ids_for_generation (RETURN_VALUE)
[2023-03-06 21:46:45,294] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function inductor
[2023-03-06 21:46:45,308] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 2
[2023-03-06 21:46:45,530] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 2
[2023-03-06 21:46:45,530] torch._dynamo.output_graph: [INFO] Step 2: done compiler function inductor
[2023-03-06 21:46:45,538] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing device
[2023-03-06 21:46:46,697] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing get_parameter_device
[2023-03-06 21:46:47,423] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _get_logits_processor
[2023-03-06 21:46:47,444] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in _get_logits_processor>
[2023-03-06 21:46:47,475] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in _get_logits_processor>
[2023-03-06 21:46:47,504] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _merge_criteria_processor_list
[2023-03-06 21:46:47,510] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _get_stopping_criteria
[2023-03-06 21:46:47,513] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in _get_stopping_criteria>
[2023-03-06 21:46:47,520] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in _get_stopping_criteria>
[2023-03-06 21:46:47,528] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _merge_criteria_processor_list
[2023-03-06 21:46:47,532] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing __init__
[2023-03-06 21:46:47,547] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing __init__ (RETURN_VALUE)
[2023-03-06 21:46:47,550] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function inductor
[2023-03-06 21:46:47,562] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 3
[2023-03-06 21:46:47,674] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 3
[2023-03-06 21:46:47,675] torch._dynamo.output_graph: [INFO] Step 2: done compiler function inductor
[2023-03-06 21:46:47,692] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _expand_inputs_for_generation
[2023-03-06 21:46:47,726] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in _expand_inputs_for_generation>
[2023-03-06 21:46:47,749] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in _expand_inputs_for_generation>
[2023-03-06 21:46:47,779] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in _expand_inputs_for_generation>
[2023-03-06 21:46:47,816] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in _expand_inputs_for_generation>
[2023-03-06 21:46:47,840] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in _expand_inputs_for_generation>
[2023-03-06 21:46:47,863] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in _expand_inputs_for_generation>
[2023-03-06 21:46:47,879] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing beam_search
[2023-03-06 21:46:47,921] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in beam_search>
[2023-03-06 21:46:47,964] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in beam_search>
[2023-03-06 21:47:00,847] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing prepare_inputs_for_generation
[2023-03-06 21:47:00,857] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing prepare_inputs_for_generation (RETURN_VALUE)
[2023-03-06 21:47:00,859] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function inductor
[2023-03-06 21:47:00,884] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 4
[2023-03-06 21:47:00,889] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 4
[2023-03-06 21:47:00,889] torch._dynamo.output_graph: [INFO] Step 2: done compiler function inductor
[2023-03-06 21:47:00,908] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-03-06 21:47:13,469] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing forward (RETURN_VALUE)
[2023-03-06 21:47:13,655] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function inductor
[2023-03-06 21:47:41,584] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 5
[2023-03-06 21:47:51,938] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 5
[2023-03-06 21:47:51,940] torch._dynamo.output_graph: [INFO] Step 2: done compiler function inductor
[2023-03-06 21:47:53,168] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing adjust_logits_during_generation
[2023-03-06 21:47:53,175] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing __call__
[2023-03-06 21:47:53,184] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing __call__
[2023-03-06 21:47:53,200] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing __call__ (RETURN_VALUE)
[2023-03-06 21:47:53,201] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function inductor
[2023-03-06 21:47:53,266] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 6
[2023-03-06 21:47:53,519] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 6
[2023-03-06 21:47:53,520] torch._dynamo.output_graph: [INFO] Step 2: done compiler function inductor
[2023-03-06 21:47:53,530] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing torch_int_div
[2023-03-06 21:47:53,561] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing torch_int_div (RETURN_VALUE)
[2023-03-06 21:47:53,562] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function inductor
[2023-03-06 21:47:53,643] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 7
[2023-03-06 21:47:53,840] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 7
[2023-03-06 21:47:53,841] torch._dynamo.output_graph: [INFO] Step 2: done compiler function inductor
[2023-03-06 21:47:53,858] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing process
[2023-03-06 21:47:53,927] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _update_model_kwargs_for_generation
[2023-03-06 21:47:53,937] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing _extract_past_from_model_output
[2023-03-06 21:47:53,941] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing __call__
[2023-03-06 21:47:53,944] torch._dynamo.convert_frame: [INFO] converting frame raised unsupported, leaving it unconverted
[2023-03-06 21:47:53,948] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing __call__
[2023-03-06 21:47:53,952] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing __call__ (RETURN_VALUE)
[2023-03-06 21:47:53,953] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function inductor
[2023-03-06 21:47:53,964] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 8
[2023-03-06 21:47:53,965] torch._inductor.graph: [ERROR] Error from lowering
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 354, in call_function
out = lowerings[target](*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/lowering.py", line 229, in wrapped
validate_ir(out)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/ir.py", line 103, in validate_ir
_check_tensorbox(node_or_nodes)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/ir.py", line 88, in _check_tensorbox
assert isinstance(
AssertionError: Found <class 'sympy.core.relational.GreaterThan'>, which is not a supported top level IR node. See [Note: Inductor IR]
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 708, in call_user_compiler
compiled_fn = compiler_fn(gm, self.fake_example_inputs())
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/backends/inductor.py", line 9, in inductor
return compile_fx(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 488, in compile_fx
return aot_autograd(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2818, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 164, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 2511, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1715, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config)
File "/opt/conda/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1326, in aot_dispatch_base
compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 164, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 462, in fw_compiler
return inner_compile(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
compiled_fn = compiler_fn(gm, example_inputs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/debug.py", line 239, in inner
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 179, in compile_fx_inner
graph.run(*example_inputs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 164, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 211, in run
return super().run(*args)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/interpreter.py", line 136, in run
self.env[node] = self.run_node(node)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 434, in run_node
result = super().run_node(n)
File "/opt/conda/lib/python3.10/site-packages/torch/fx/interpreter.py", line 177, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_inductor/graph.py", line 358, in call_function
raise LoweringException(e, target, args, kwargs) from e
torch._inductor.exc.LoweringException: AssertionError: Found <class 'sympy.core.relational.GreaterThan'>, which is not a supported top level IR node. See [Note: Inductor IR]
target: <built-in function ge>
args[0]: s1
args[1]: 150
While executing %ge : [#users=1] = call_function[target=operator.ge](args = (%sym_size, 150), kwargs = {})
Original traceback:
None
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/workspace/foundation-model-stack/nlp/scripts/inference/t5_summarization.py", line 556, in <module>
source_data, predictions, actuals = validate(
File "/workspace/foundation-model-stack/nlp/scripts/inference/t5_summarization.py", line 146, in validate
generated_ids = model.generate2(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 231, in _fn
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 1177, in generate
self._validate_model_class()
File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 1184, in <graph break in generate>
new_generation_config = GenerationConfig.from_model_config(self.config)
File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 1474, in <graph break in generate>
return self.beam_search(
File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 2652, in beam_search
"`max_length` is deprecated in this function, use"
File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 2803, in <graph break in beam_search>
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
File "/opt/conda/lib/python3.10/site-packages/transformers/generation/stopping_criteria.py", line 113, in __call__
return any(criteria(input_ids, scores) for criteria in self)
File "/opt/conda/lib/python3.10/site-packages/transformers/generation/stopping_criteria.py", line 113, in <genexpr>
return any(criteria(input_ids, scores) for criteria in self)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 368, in catch_errors
return callback(frame, cache_size, hooks)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
result = inner_convert(frame, cache_size, hooks)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
return fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
return _compile(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 164, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
out_code = transform_code_object(code, transform)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 530, in transform_code_object
transformations(instructions, code_options)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
tracer.run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1862, in run
super().run()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 619, in run
and self.step()
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 583, in step
getattr(self, inst.opname)(inst)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1941, in RETURN_VALUE
self.output.compile_subgraph(
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 579, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 626, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 164, in time_wrapper
r = func(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 713, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: inductor raised LoweringException: AssertionError: Found <class 'sympy.core.relational.GreaterThan'>, which is not a supported top level IR node. See [Note: Inductor IR]
target: <built-in function ge>
args[0]: s1
args[1]: 150
While executing %ge : [#users=1] = call_function[target=operator.ge](args = (%sym_size, 150), kwargs = {})
Original traceback:
None
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
Originally posted by @ani300 in #96130 (comment)
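The traceback shows the failure happening while Dynamo compiles a stopping criterion's `__call__`, which is invoked from the generator `any(criteria(input_ids, scores) for criteria in self)` in transformers' `stopping_criteria.py` (line 113). The FX node being lowered, `operator.ge(sym_size, 150)`, has the shape of a max-length check: the current sequence length compared against a fixed limit. The sketch below is a simplified, hypothetical pure-Python stand-in. The class names `MaxLengthCheck` and `CriteriaList` are made up, plain lists replace tensors, and the assumption that 150 is the configured `max_length` is ours, not confirmed by the log.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class MaxLengthCheck:
    """Hypothetical stand-in for a max-length stopping criterion."""

    max_length: int

    def __call__(self, input_ids: Sequence[Sequence[int]], scores: Optional[object] = None) -> bool:
        # Same comparison as the failing FX node `operator.ge(sym_size, 150)`:
        # the current sequence length (last dimension) against max_length.
        return len(input_ids[0]) >= self.max_length


class CriteriaList(list):
    """Hypothetical stand-in for a list of stopping criteria."""

    def __call__(self, input_ids: Sequence[Sequence[int]], scores: Optional[object] = None) -> bool:
        # Mirrors `any(criteria(input_ids, scores) for criteria in self)`.
        return any(criteria(input_ids, scores) for criteria in self)


# Usage: a batch with one sequence of 149 tokens, then one of 150 tokens.
stop = CriteriaList([MaxLengthCheck(max_length=150)])
print(stop([[0] * 149]))  # False
print(stop([[0] * 150]))  # True
```

In eager mode the comparison yields an ordinary `bool`; under `dynamic=True` the length is a symbolic size instead, so the same `>=` becomes a graph node whose output is not a tensor.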
cc @soumith @msaroufim @wconstab @ngimel @bdhirsh @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire
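The object Inductor rejects, `sympy.core.relational.GreaterThan`, is what SymPy returns when a symbolic value is compared with a constant. The sketch below assumes only that Inductor represents the dynamic dimension as a SymPy symbol named `s1` (as the `args[0]: s1` line in the log suggests). It reproduces the comparison in plain SymPy, not through Inductor, to show that the result is an unevaluated relational rather than a tensor or a Python `bool`.

```python
import sympy

# Hypothetical stand-in for the symbolic size the log calls `s1`
# (the traced sequence length under dynamic=True).
s1 = sympy.Symbol("s1", integer=True, positive=True)

# Comparing a symbol to a constant does not produce a Python bool;
# it produces an unevaluated relational expression.
expr = s1 >= 150
print(type(expr))  # <class 'sympy.core.relational.GreaterThan'>

# It only collapses to a concrete truth value once s1 is known.
print(expr.subs(s1, 149))  # False
print(expr.subs(s1, 150))  # True

# Asking for its truth value directly fails, because s1 is unknown.
try:
    bool(expr)
except TypeError as err:
    print("cannot branch on it:", err)
```

Because the lowered result is a symbolic boolean rather than a tensor, it appears consistent with `validate_ir` refusing it as a top-level IR node, which is the assertion reported above.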