
torch.utils.flop_counter.FlopCounterMode broke with torch-2.4 #134242

@stas00

Description


🐛 Describe the bug

We have been successfully using torch.utils.flop_counter.FlopCounterMode with torch versions before 2.4, but with torch-2.4 it breaks and is impossible to use.

It either warns:
The module hierarchy tracking seems to be messed up.Please file a bug to PyTorch
or crashes with:
The Module hierarchy tracking is wrong. Report a bug to PyTorch

The relevant part of the trace is:

[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/_tensor.py", line 521, in backward
[:1]:[rank1]:     torch.autograd.backward(
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/autograd/__init__.py", line 289, in backward
[:1]:[rank1]:     _engine_run_backward(
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
[:1]:[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1116, in unpack_hook
[:1]:[rank1]:     frame.recompute_fn(*args)
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1400, in recompute_fn
[:1]:[rank1]:     fn(*args, **kwargs)
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[:1]:[rank1]:     return self._call_impl(*args, **kwargs)
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1592, in _call_impl
[:1]:[rank1]:     args_result = hook(self, args)
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/utils/module_tracker.py", line 120, in _fw_pre_hook
[:1]:[rank1]:     self._get_append_fn(name, False)()
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/utils/module_tracker.py", line 96, in fn
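This trace goes through torch/utils/checkpoint.py. During `backward`, the checkpoint `unpack_hook` calls `recompute_fn`, which runs the module's forward a second time. That second forward fires the module's forward pre-hook in module_tracker.py (`_fw_pre_hook`). The toy model below is a simplified, hypothetical illustration, not PyTorch's actual module_tracker code. It assumes a tracker that keeps a set of "active" module names, marks a module active again when its backward starts, and warns when a name is added while it is already present. Under those assumptions, re-running a module's forward inside its own backward is enough to trip the duplicate-entry check.

```python
# Toy, hypothetical model of set-based module hierarchy tracking.
# This is NOT PyTorch's implementation. It only illustrates why a module whose
# forward is re-run during backward (activation checkpoint recompute)
# can look as if it was entered twice.


class ToyHierarchyTracker:
    def __init__(self):
        self.parents = set()  # names of modules currently considered "active"
        self.warnings = []  # collected instead of printed, for inspection

    def enter(self, name, phase):
        # Assumed to run on the forward pre-hook (phase="fw") or when
        # backward reaches the module's outputs (phase="bw").
        if name in self.parents:
            self.warnings.append(f"{name} entered again during {phase}")
        self.parents.add(name)

    def exit(self, name, phase):
        # Assumed to run on the forward post-hook or when backward
        # reaches the module's inputs.
        self.parents.discard(name)


def simulate(checkpointed):
    tracker = ToyHierarchyTracker()
    # Forward pass: the block runs once.
    tracker.enter("block", "fw")
    tracker.exit("block", "fw")
    # Backward pass: the block becomes active again.
    tracker.enter("block", "bw")
    if checkpointed:
        # Recompute re-runs the block's forward while its backward is active.
        tracker.enter("block", "fw")
        tracker.exit("block", "fw")
    tracker.exit("block", "bw")
    return tracker.warnings


print(simulate(checkpointed=False))  # []
print(simulate(checkpointed=True))  # ['block entered again during fw']
```

Whether this is the actual root cause in torch-2.4 is an assumption. The toy model only shows that the pattern visible in the trace (forward re-entered from backward) is incompatible with a tracker that expects each module to be entered once at a time.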

Here is how we use it:

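from contextlib import contextmanager

from torch.utils.flop_counter import FlopCounterMode


class MatMulFlopCounter:
    """Count FLOPs for a single target iteration only."""

    def __init__(self, display=False, target_iter=2):
        self.target_iter = target_iter
        self.flop_counter = FlopCounterMode(display=display)
        self.mm_tflops = 0  # holds the raw FLOP count; converted to TFLOPs in get_total_flops()

    @contextmanager
    def __call__(self, current_iter):
        if current_iter == self.target_iter:
            # only the target iteration runs under FlopCounterMode
            with self.flop_counter:
                yield
            self.mm_tflops = self.flop_counter.get_total_flops()
        else:
            # all other iterations run without FLOP counting
            yield

    def get_total_flops(self):
        # convert FLOPs to TFLOPs
        return self.mm_tflops / 1e12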
[...]
                    mm_flop_counter = MatMulFlopCounter()
                    with mm_flop_counter(iter_since_job_start), self.accelerator.accumulate(self.model):
                        (loss_total, output) = self.do_batch(...)

This happens with every HF transformers model I tried (Bert, Llama, Mistral), so the problem does not appear to be in the models themselves.

Rolling back to 2.3.1 restores the functionality.

Questions:

  1. What workaround would let us keep using FlopCounterMode with pt-2.4+?
  2. What is the long-term solution?

Suggestion:
If I may suggest: the warning/error is meaningless to the user. What does "messed up" mean?
In particular this one:

https://github.com/pytorch/pytorch/blob/3c5b246d3c6461ef59fa38e8c4265b2c6b223412/torch/distributed/_tools/mod_tracker.py#L175C10-L177C72

"The module hierarchy tracking maybe be messed up. Please file a bug to PyTorch, if it is the case": how can a user tell whether "it is the case"?
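To make the suggestion concrete, here is a hypothetical sketch of a diagnostic that would let a user judge "if it is the case". It names the module that was re-entered, the phase in which that happened, and the modules the tracker believed were active at that moment. The function name `describe_reentry` and the message wording are illustrative assumptions, not an existing PyTorch API.

```python
def describe_reentry(module_name, phase, active_modules):
    """Build an actionable message for a module entered while already active.

    Hypothetical helper: neither the name nor the wording comes from PyTorch.
    """
    # Sort for a stable, readable listing; show "<none>" for an empty set.
    active = ", ".join(sorted(active_modules)) or "<none>"
    return (
        f"Module '{module_name}' was entered during {phase} while it was "
        f"already marked active (active modules: {active}). One way this "
        f"happens is when a module's forward is re-run during backward, "
        f"for example by activation checkpointing."
    )


msg = describe_reentry(
    "model.layers.0", "backward (checkpoint recompute)", {"model", "model.layers.0"}
)
print(msg)
```

A message along these lines tells the user which module and which phase are involved, which is the information needed to decide whether the report is a real tracking bug or an expected consequence of recomputation.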

Versions

The problem happens on multiple setups; the only thing they have in common is pt-2.4.0.

@albanD, @Chillee

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim
