
Conversation

shunting314 (Contributor) commented Sep 26, 2024

Stack from ghstack (oldest at bottom):

The AutoChunker defines the following chunking metadata and propagates it through the subgraphs that we chunk (see the sketch after this list):

  1. scale_by: The AutoChunker is only enabled if there is a single scalar tangent. To decouple the bwd subgraph's dependency on the tangent so it can be computed in fwd, we first pretend the tangent is 1 and record the real value in 'scale_by'. This metadata gets propagated, and when we cancel the chunking effect at the end of the bwd subgraph, we apply the scaling.
  2. chunk_dim: records which dimension of the tensor gets chunked.
  3. need_sum: if true, the original Tensor is the sum (rather than the concat) of the chunked tensors.
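
To make the metadata concrete, here is a minimal, hypothetical sketch of how per-chunk results could be recombined using chunk_dim, need_sum, and scale_by (ChunkMeta and recombine are illustrative names, not the classes or functions in this PR):

```python
# Hypothetical illustration only; ChunkMeta / recombine are not the PR's actual code.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ChunkMeta:
    chunk_dim: Optional[int]                  # which dimension of the tensor got chunked
    need_sum: bool                            # True: original tensor is the sum of the chunks
    scale_by: Optional[torch.Tensor] = None   # real scalar tangent, applied at the end of bwd


def recombine(chunks: list, meta: ChunkMeta) -> torch.Tensor:
    if meta.need_sum:
        # e.g. weight gradients: each chunk contributes a partial sum
        out = torch.stack(chunks).sum(dim=0)
    else:
        # e.g. activations / input gradients: chunks tile the chunked dimension
        out = torch.cat(chunks, dim=meta.chunk_dim)
    if meta.scale_by is not None:
        # cancel the "pretend the tangent is 1" trick at the end of the bwd subgraph
        out = out * meta.scale_by
    return out
```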

One important implementation detail: we need to put the chunked subgraph in a HOP (invoke_subgraph is used here). Otherwise Inductor fuses across these subgraphs, resulting in no peak memory saving.

This is still a prototype, since I assume all chunked inputs for the chunking subgraph are chunked along the same dimension, which may not hold in general. As discussed offline with Jason and Horace, a principled way to resolve this is to propagate the chunking metadata in the backward direction.

Here are some early benchmarking results on GPT2.

  • 64 chunks:
    • final 19 iters avg: 242.550ms
    • peak memory consumption: 12603 MiB
  • 32 chunks:
    • final 19 iters avg: 206.180ms
    • peak memory consumption: 12880 MiB
  • 16 chunks:
    • final 19 iters avg: 196.997ms
    • peak memory consumption: 13267 MiB
  • 8 chunks:
    • final 19 iters avg: 194.924ms
    • peak memory consumption: 14049 MiB

With 64 chunks, our peak memory is smaller than llm.c's 13.4GB.

I also tried the AutoChunker on the PT2 OSS benchmarks to verify numerics. By default our accuracy test picks a very small batch size, which makes the AutoChunker get skipped. I forced batch_size to 16 for BertForMaskedLM to trigger the AutoChunker and verified numerical correctness.

FYI @jansel @Chillee @eellison. I will send an update when this is fully ready for review, after I resolve the hacky things mentioned above.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @ColinPeppler @desertfire @ngimel

[ghstack-poisoned]
pytorch-bot bot commented Sep 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136702

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Unrelated Failure

As of commit 4fb4e55 with merge base e8de914:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

shunting314 added a commit that referenced this pull request Sep 26, 2024
ghstack-source-id: 642041d
Pull Request resolved: #136702
shunting314 added a commit that referenced this pull request Sep 27, 2024
ghstack-source-id: 93ef4fd
Pull Request resolved: #136702
shunting314 added a commit that referenced this pull request Sep 27, 2024
ghstack-source-id: 9042d76
Pull Request resolved: #136702
shunting314 added a commit that referenced this pull request Oct 20, 2024
ghstack-source-id: 3e45a50
Pull Request resolved: #136702
shunting314 added a commit that referenced this pull request Oct 21, 2024
ghstack-source-id: be119ea
Pull Request resolved: #136702
shunting314 added a commit that referenced this pull request Oct 22, 2024
ghstack-source-id: cf853af
Pull Request resolved: #136702
shunting314 added a commit that referenced this pull request Oct 22, 2024
ghstack-source-id: 6c741b5
Pull Request resolved: #136702
shunting314 added a commit that referenced this pull request Oct 23, 2024
ghstack-source-id: c68ca89
Pull Request resolved: #136702
shunting314 added a commit that referenced this pull request Oct 29, 2024
ghstack-source-id: 1c30b64
Pull Request resolved: #136702
shunting314 added a commit that referenced this pull request Oct 29, 2024
ghstack-source-id: 6cc3cdb
Pull Request resolved: #136702
shunting314 added a commit that referenced this pull request Oct 30, 2024
ghstack-source-id: ca8f475
Pull Request resolved: #136702
v0i0 (Contributor) commented Nov 14, 2025

@shunting314 @jansel should we try to get this in? What would it take? I can take a pass reviewing it.

shunting314 (Contributor, Author):

@v0i0 thanks for offering review. That would be great!

I just need to find bandwidth to resolve comments, do more tests, and get this in.

jansel (Contributor) commented Nov 16, 2025

Yeah, I think we should get this landed. Re-request review when you want me to take another look.

v0i0 (Contributor) left a comment


This is amazing! What does the generated code look like? Why is this a bunch of invoke_subgraph rather than vmap w/ chunk_size or hop.scan?

aten.neg.default,
]
)
def propagate_general_copy_metadata(out_node, ignore_broadcast=False):

While I admittedly have no idea how to do it, this seems like a special case of propagating along vmap/batching rules. If they exist somewhere, you could probably cut this down a lot.

shunting314 (Contributor, Author):

what does the generated code look like?

For linear + cross entropy, the fx graph before chunking is https://gist.github.com/shunting314/231e4e5c15923f3d16e069768628daaf and the fx graph after chunking is https://gist.github.com/shunting314/2583db2ea0ea36bdad4ed1a5732bbc49.
A subgraph is extracted and invoked 4 times because I set the number of chunks to 4 (this can be autotuned or chosen more intelligently later).
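
For intuition, here is a rough eager-mode sketch of what the chunked linear + cross-entropy computation boils down to (the shapes and the explicit loop are illustrative assumptions, not the actual generated code; the real graphs are in the gists above):

```python
# Split along the token dimension into 4 chunks so only one (T/4, V) chunk of
# logits is alive at a time; gradients accumulate across per-chunk backwards.
import torch
import torch.nn.functional as F

T, H, V, num_chunks = 1024, 768, 50257, 4
x = torch.randn(T, H, requires_grad=True)   # hidden states feeding the lm head
w = torch.randn(V, H, requires_grad=True)   # lm-head weight
targets = torch.randint(0, V, (T,))

total_loss = torch.zeros(())
for x_c, t_c in zip(x.split(T // num_chunks), targets.split(T // num_chunks)):
    logits_c = x_c @ w.t()                  # (T/num_chunks, V) chunk of logits
    # sum per chunk, divide by T so the total matches mean cross-entropy
    loss_c = F.cross_entropy(logits_c, t_c, reduction="sum") / T
    loss_c.backward()                       # accumulates into x.grad and w.grad
    total_loss += loss_c.detach()

# total_loss matches F.cross_entropy(x @ w.t(), targets) up to floating-point error
```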

shunting314 (Contributor, Author):

why is this a bunch of invoke_subgraph rather than vmap w/ chunk_size or hop.scan?

Thanks for mentioning the similarity with vmap. Yeah, maybe some functionality can be borrowed/shared; I'm not sure. But for the auto-chunker, the different chunks have to be handled serially so we can reuse intermediate buffers.

hop.scan was not available when I started prototyping the auto-chunker.
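
As a minimal sketch of the peak-memory argument for serial processing (illustrative only; in the PR the chunks stay as separate invoke_subgraph calls precisely so Inductor does not fuse them back together):

```python
# With a serial loop, at most one chunk's logits buffer is live, so peak memory
# is roughly (T / num_chunks) * V floats instead of T * V.
import torch

def serial_chunks(x, w, num_chunks=4):
    outs = []
    for x_c in x.split(x.size(0) // num_chunks):
        logits_c = x_c @ w.t()                # only this chunk's logits are alive
        outs.append(logits_c.amax(dim=-1))    # reduce before the next chunk runs
    return torch.cat(outs)                    # full T x V logits never materialized at once

def unchunked(x, w):
    return (x @ w.t()).amax(dim=-1)           # materializes all T x V logits at peak
```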

shunting314 requested a review from drisspg as a code owner December 13, 2025
shunting314 added a commit that referenced this pull request Dec 13, 2025
ghstack-source-id: 192d121
Pull Request resolved: #136702
v0i0 (Contributor) commented Dec 13, 2025

what does the generated code look like?

For linear + cross entropy, the fx graph before chunking is https://gist.github.com/shunting314/231e4e5c15923f3d16e069768628daaf and the fx graph after chunking is https://gist.github.com/shunting314/2583db2ea0ea36bdad4ed1a5732bbc49. A subgraph is extracted and invoked 4 times because I set the number of chunks to 4 (this can be autotuned or chosen more intelligently later).

Amazing! Would it be possible to run this against a bunch of training code (e.g. torchtitan w/ num_gpus=1 & torchbench) to see if anything is missing / get some more memory savings data?

shunting314 (Contributor, Author) commented Dec 13, 2025

Amazing! Would it be possible to run this against a bunch of training code (e.g. torchtitan w/ num_gpus=1 & torchbench) to see if anything is missing / get some more memory savings data?

Sure. Those are actually the tests I plan to do and share results for; torchtitan/torchtune are both good choices. I'll share more results.

Our dashboard (HF/torchbench/TIMM) may be a stretch for testing the AutoChunker, since for small workloads the AutoChunker will most likely be skipped.

shunting314 added a commit that referenced this pull request Dec 13, 2025
ghstack-source-id: eb1c257
Pull Request resolved: #136702
shunting314 added a commit that referenced this pull request Dec 15, 2025
ghstack-source-id: 4fb560d
Pull Request resolved: #136702
shunting314 added a commit that referenced this pull request Dec 16, 2025
ghstack-source-id: 6944e01
Pull Request resolved: #136702
shunting314 added a commit that referenced this pull request Dec 16, 2025
ghstack-source-id: 7aadd4f
Pull Request resolved: #136702
shunting314 added a commit that referenced this pull request Dec 17, 2025
ghstack-source-id: 23734a8
Pull Request resolved: #136702
shunting314 added a commit that referenced this pull request Dec 17, 2025
ghstack-source-id: 90b5989
Pull Request resolved: #136702
shunting314 added a commit that referenced this pull request Dec 18, 2025
ghstack-source-id: e8cf647
Pull Request resolved: #136702
