aot inductor intermediate tensor debug printing (setting 2) not working #145425

@exclamaforte

🐛 Describe the bug

Code:

from torch._inductor.fuzzer import ConfigFuzzer, visualize_results  # imported in the original script; unused in this repro
import torch

def create_simple_test_model_gpu():
    """Create a simple test model function for demonstration."""

    batch_size = 32
    seq_length = 50
    hidden_size = 768

    def test_fn():
        inp = torch.randn(batch_size, seq_length, hidden_size, device="cuda")
        weight = torch.randn(hidden_size, hidden_size, device="cuda")
        matmul_output = inp @ weight
        final_output = torch.nn.LayerNorm(hidden_size, device="cuda")(matmul_output)
        return True

    return test_fn

tf = create_simple_test_model_gpu()

comp = torch.compile(options={"aot_inductor.debug_intermediate_value_printer": "2"})(tf)
comp()

Error message:

Traceback (most recent call last):
  File "/home/gabeferns/org/debug/fuzzer-0/bug.py", line 23, in <module>
    comp()
  File "/home/gabeferns/pt-envs/fuzzer/torch/_dynamo/eval_frame.py", line 566, in _fn
    return fn(*args, **kwargs)
  File "/home/gabeferns/org/debug/fuzzer-0/bug.py", line 11, in test_fn
    def test_fn():
  File "/home/gabeferns/pt-envs/fuzzer/torch/_dynamo/eval_frame.py", line 745, in _fn
    return fn(*args, **kwargs)
  File "/home/gabeferns/pt-envs/fuzzer/torch/_functorch/aot_autograd.py", line 1199, in forward
    return compiled_fn(full_args)
  File "/home/gabeferns/pt-envs/fuzzer/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 326, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/home/gabeferns/pt-envs/fuzzer/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/home/gabeferns/pt-envs/fuzzer/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 687, in inner_fn
    outs = compiled_fn(args)
  File "/home/gabeferns/pt-envs/fuzzer/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 493, in wrapper
    return compiled_fn(runtime_args)
  File "/home/gabeferns/pt-envs/fuzzer/torch/_inductor/output_code.py", line 457, in __call__
    return self.current_callable(inputs)
  File "/tmp/torchinductor_gabeferns/us/cusdgx2jfgdi7skkxb27i4l7xuwe2afa2blsn3kgbqsuldogqqin.py", line 133, in call
    _print_debugging_tensor_value_info("inductor: before_launch - triton_poi_fused_randn_0 - 0", 0)
  File "/home/gabeferns/pt-envs/fuzzer/torch/_inductor/codegen/debug_utils.py", line 26, in _print_debugging_tensor_value_info
    numel = arg.float().numel()
AttributeError: 'int' object has no attribute 'float'

I have a fix incoming.
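
For illustration only (this is not the incoming fix): the traceback shows the generated wrapper passing the plain Python int `0` to `_print_debugging_tensor_value_info`, which then calls `arg.float()` on it. A minimal sketch of the kind of guard that would avoid the `AttributeError` is below. The helper name `summarize_debug_arg` is hypothetical, and the tensor check is duck-typed (an assumption made so the sketch runs without PyTorch installed); real code would more likely test `isinstance(arg, torch.Tensor)`.

```python
import numbers


def summarize_debug_arg(name, arg):
    """Hypothetical helper: build a one-line debug summary for a kernel argument.

    This is a sketch of a defensive guard, not code from torch._inductor.
    """
    # Scalars such as the int `0` in the traceback have no .float()/.numel(),
    # so describe them directly instead of treating them as tensors.
    if isinstance(arg, numbers.Number):
        return f"{name}: scalar {type(arg).__name__} = {arg!r}"

    # Assumption: anything exposing .float() and .numel() is tensor-like.
    if hasattr(arg, "float") and hasattr(arg, "numel"):
        numel = arg.float().numel()
        return f"{name}: tensor-like, {numel} elements"

    # Fall through for unexpected types rather than raising inside debug output.
    return f"{name}: unsupported type {type(arg).__name__}"


print(summarize_debug_arg("arg", 0))  # arg: scalar int = 0
```

The point of the sketch is only that a debug printer should degrade gracefully on non-tensor kernel arguments. Whether scalars should be printed or skipped entirely is a design choice this report does not settle.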

Versions

git hash: 40e27fb

cc @chauhang @penguinwu

Labels: oncall: pt2, triaged