Conversation


tohtana commented Jun 14, 2025

This PR keeps some of the real inputs given to the custom backend for DeepCompile.

DeepCompile expects that the custom backend at the TorchFX graph level is always called when recompilation happens. In some cases, however, only the Aten-level backend is called. As the Aten-level backend uses real inputs saved by the TorchFX-level backend, we need to keep the real inputs for recompilation.

Currently we discard the real inputs after the Aten-level backend uses them, as they are often too large to keep in GPU memory. This causes an error when recompilation calls only the Aten-level backends, because we never get a chance to record new real inputs in the TorchFX-level backend.

This PR always keeps tensor metadata and non-tensor data on CPU and materializes the tensors when needed (i.e. when recompilation happens and only Aten-level backends are called without real inputs). As we use dummy data to materialize tensors, this solution might still not work in some cases, but it improves coverage.
The new module `InputStorage` keeps tensor metadata and non-tensor data for this purpose and materializes the tensors. A minimal sketch of the idea follows.
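
Below is a minimal sketch of the idea behind `InputStorage` (class and method names here are illustrative, not the actual DeepSpeed implementation): record only metadata for tensor inputs plus the raw values of non-tensor inputs, then rebuild dummy tensors on demand.

```python
import torch

class InputStorageSketch:
    """Illustrative stand-in for the InputStorage module added by this PR."""

    def __init__(self):
        self._records = []

    def save(self, args):
        self._records = []
        for a in args:
            if isinstance(a, torch.Tensor):
                # Keep only metadata; the real data is discarded to save memory.
                self._records.append(("tensor", a.shape, a.dtype, a.device))
            else:
                # Non-tensor inputs (ints, bools, ...) are cheap to keep as-is.
                self._records.append(("value", a))

    def materialize(self):
        # Rebuild inputs for recompilation; tensors are filled with ones,
        # which is why index-consuming ops can still misbehave.
        args = []
        for rec in self._records:
            if rec[0] == "tensor":
                _, shape, dtype, device = rec
                args.append(torch.ones(shape, dtype=dtype, device=device))
            else:
                args.append(rec[1])
        return args
```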

tohtana requested review from loadams and tjruwase as code owners June 14, 2025 02:15
@sfc-gh-truwase

> …recompilation happens and only Aten-level backends are called without real inputs). As we use dummy data to materialize tensors, this solution might still not work but improves the coverage.

@tohtana, what are the known failure cases of this solution?


tohtana commented Jun 15, 2025

> @tohtana, what are the known failure cases of this solution?

@sfc-gh-truwase I can think of these two cases:

  • Operators that take indices (e.g. embedding, scatter): they will throw an error if given indices that exceed the target tensor's size. In this PR, we fill the dummy tensors with 1 (like torch.ones), so embedding should be okay in most cases. If we encounter this issue, it would be good to add an option for finer-grained control of the behavior (e.g. saving only int tensors, expecting that large activation tensors are mostly float).
  • torch.where, torch.nonzero: the output shapes of these operators can change depending on the input values. We will need a different approach to address this (e.g. gather inputs from all ranks and run through them until we get a stable graph). A short PyTorch illustration of both failure modes is shown below.
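
A quick illustration of both failure modes in plain PyTorch (this snippet is for explanation only and is not part of the PR):

```python
import torch

# 1) Index-consuming ops: an index >= num_embeddings raises at runtime.
emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=4)
emb(torch.ones(3, dtype=torch.long))   # fine: dummy index 1 is in range
# emb(torch.tensor([42]))              # IndexError: index out of range in self

# 2) Data-dependent output shapes: nonzero's result size depends on values,
#    so dummy inputs can produce shapes the real run would never see.
print(torch.nonzero(torch.tensor([0, 1, 0, 1])).shape)  # torch.Size([2, 1])
print(torch.nonzero(torch.ones(4)).shape)               # torch.Size([4, 1])
```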

@sfc-gh-truwase

> @sfc-gh-truwase I can think of these two cases:

Thanks for the explanation. I have two thoughts:

  1. My understanding is that these limitations stem from runtime properties of the execution rather than static properties of the code. If so, they are normal limitations of the compiler approach, and we might want to skip compilation for these cases instead of failing silently.
  2. If my understanding is incorrect, can we document these limitations in the code until we are able to handle them?


tohtana commented Jun 17, 2025

@sfc-gh-truwase, thank you for your comment!
I currently don't think the dummy values affect correctness, as they are used only for profiling.

For the two cases I mentioned, the following can happen:

  • Operators that take indices (e.g. embedding, scatter): if an invalid value is given, they throw an error.
  • Operators that produce dynamic shapes depending on inputs (torch.where, torch.nonzero): the output might cause an error in subsequent operators during profiling, or yield inaccurate profiling results that lead to non-optimal graph modification.

An alternative approach is to offload the real inputs and keep them. As there is a tradeoff between CPU memory consumption and the stability/accuracy of profiling, we could give the user the choice. A rough sketch of the offloading path is shown below.
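
A rough sketch of that alternative, under the assumption that the host copy is pinned for faster transfers (the helper names are hypothetical):

```python
import torch

def offload_to_host(t: torch.Tensor):
    # Keep the real values on the host; pinning speeds up the copy back.
    host = torch.empty(t.shape, dtype=t.dtype, device="cpu",
                       pin_memory=torch.cuda.is_available())
    host.copy_(t)
    return host, t.device

def restore(host: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Copy the saved values back to the original device for profiling.
    return host.to(device, non_blocking=True)
```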


tohtana commented Jun 19, 2025

@sfc-gh-truwase Thanks for the feedback! I've extended `InputStorage` to cover a wider variety of scenarios.

Here is the summary:

  • Enhanced InputStorage to keep real values for integer tensors by default:
    • Added keep_int_input_tensors: bool = True config option (enabled by default)
    • Integer tensors (indices, masks, etc.) now preserve their actual values instead of using dummy ones
    • This addresses correctness issues with operators like embedding and scatter that rely on valid indices
  • Added option to keep all input tensors:
    • Added keep_all_input_tensors: bool = False config option for comprehensive tensor preservation
    • Useful for debugging or cases where dummy values cause issues with any tensor type

For both options, we offload the real tensors to host memory. A rough sketch of how these options could drive the keep/discard decision is shown below.
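
Only the option names below come from this PR; the helper itself is a hypothetical illustration of the decision logic:

```python
import torch

def should_keep_real_values(t: torch.Tensor,
                            keep_int_input_tensors: bool = True,
                            keep_all_input_tensors: bool = False) -> bool:
    if keep_all_input_tensors:
        return True
    # Non-float tensors (indices, masks) are usually small but
    # value-sensitive, so their real data is kept rather than faked.
    return keep_int_input_tensors and not t.is_floating_point()

def save_input(t: torch.Tensor, **opts):
    if should_keep_real_values(t, **opts):
        return ("real", t.detach().to("cpu"))      # offload real values to host
    return ("meta", t.shape, t.dtype, t.device)    # metadata only; dummy later
```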

tjruwase merged commit 6f1a1c0 into master Jun 19, 2025
9 of 10 checks passed
tjruwase deleted the tohtana/keep_real_inputs_for_recompile branch June 19, 2025 12:52
Antlera pushed a commit to Antlera/DeepSpeed that referenced this pull request Jun 27, 2025
mauryaavinash95 pushed a commit to DataStates/DeepSpeed that referenced this pull request Oct 4, 2025