🐛 Describe the bug
We have been successfully using torch.utils.flop_counter.FlopCounterMode up to torch-2.3.1; with torch-2.4 it breaks and is unusable.

It either warns:

```
The module hierarchy tracking seems to be messed up.Please file a bug to PyTorch
```

or crashes with:

```
The Module hierarchy tracking is wrong. Report a bug to PyTorch
```

The relevant part of the trace is:
```
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/_tensor.py", line 521, in backward
[:1]:[rank1]:     torch.autograd.backward(
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/autograd/__init__.py", line 289, in backward
[:1]:[rank1]:     _engine_run_backward(
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
[:1]:[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1116, in unpack_hook
[:1]:[rank1]:     frame.recompute_fn(*args)
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1400, in recompute_fn
[:1]:[rank1]:     fn(*args, **kwargs)
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[:1]:[rank1]:     return self._call_impl(*args, **kwargs)
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1592, in _call_impl
[:1]:[rank1]:     args_result = hook(self, args)
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/utils/module_tracker.py", line 120, in _fw_pre_hook
[:1]:[rank1]:     self._get_append_fn(name, False)()
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/utils/module_tracker.py", line 96, in fn
```
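To make the failure mode concrete, here is a minimal sketch of what I believe is the triggering interaction: FlopCounterMode's module tracking is active while non-reentrant activation checkpointing recomputes the forward pass during backward (the checkpoint.py unpack_hook/recompute_fn frames above). The toy model and sizes are made up, and I have not verified that this exact snippet reproduces the warning on every setup:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torch.utils.flop_counter import FlopCounterMode

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

with FlopCounterMode(display=False) as flop_counter:
    # use_reentrant=False selects the hook-based checkpoint implementation
    # whose unpack_hook/recompute_fn frames appear in the traceback above
    out = checkpoint(model, x, use_reentrant=False)
    # backward triggers the forward recompute while the counter is active
    out.sum().backward()

print(flop_counter.get_total_flops())
```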
Here is how we use it:
```python
from contextlib import contextmanager

from torch.utils.flop_counter import FlopCounterMode


class MatMulFlopCounter:
    def __init__(self, display=False, target_iter=2):
        self.target_iter = target_iter
        self.flop_counter = FlopCounterMode(display=display)
        self.mm_tflops = 0

    @contextmanager
    def __call__(self, current_iter):
        # count flops only on the target iteration
        if current_iter == self.target_iter:
            with self.flop_counter:
                yield
            self.mm_tflops = self.flop_counter.get_total_flops()
        else:
            yield

    def get_total_flops(self):
        return self.mm_tflops / 1e12
```
```python
[...]
mm_flop_counter = MatMulFlopCounter()
with mm_flop_counter(iter_since_job_start), self.accelerator.accumulate(self.model):
    (loss_total, output) = self.do_batch(...)
```
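For reference, a stripped-down, self-contained driver showing the intended pattern (toy model and loop; our real trainer goes through accelerate and do_batch as above):

```python
import torch
import torch.nn as nn

model = nn.Linear(32, 32)
mm_flop_counter = MatMulFlopCounter()
for iter_since_job_start in range(4):
    with mm_flop_counter(iter_since_job_start):
        loss = model(torch.randn(8, 32)).sum()
        loss.backward()

# flops were captured only on the default target_iter == 2
print(mm_flop_counter.get_total_flops(), "TFLOPs")
```

The point of target_iter is to pay the tracking overhead on a single, steady-state iteration rather than on every step.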
This happens with any HF transformers model I have tried (BERT, Llama, Mistral), so clearly the models themselves are fine.
Rolling back to torch-2.3.1 restores the functionality.
Questions:
- What is the workaround to unblock us so we can keep using FlopCounterMode with pt-2.4+?
- What is the long-term solution?
Suggestion:
If I may suggest: the warning/error is meaningless to the user. What does "messed up" mean?
In particular this one:
"The module hierarchy tracking maybe be messed up. Please file a bug to PyTorch, if it is the case" - how can a user tell whether "it is the case"?
Versions
The problem happens on multiple setups; the only common denominator is pt-2.4.0.