[Inductor] auto-chunker #136702
base: gh/shunting314/176/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136702
Note: links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures, 1 Unrelated Failure as of commit 4fb4e55 with merge base e8de914.
FLAKY: the following job failed but was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@shunting314 @jansel should we try to get this in? What would it take? I can take a pass reviewing it.
@v0i0 thanks for offering to review. That would be great! I just need to find bandwidth to resolve comments, do more tests, and get this in.
Yeah, I think we should get this landed. Re-request review when you want me to take another look. |
v0i0 left a comment
This is amazing! What does the generated code look like? Why is this a bunch of invoke_subgraph rather than vmap w/ chunk_size or hop.scan?
        aten.neg.default,
    ]
)
def propagate_general_copy_metadata(out_node, ignore_broadcast=False):
While I admittedly have no idea how to do it, this seems like a special case of propagating along vmap/batching rules. If they exist somewhere, you could probably cut this down a lot.
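(For context on the vmap alternative mentioned here: torch.vmap already exposes a chunk_size knob that evaluates the mapped function over the batch in fixed-size pieces. A minimal sketch below, with made-up shapes and function names; it only shows the knob, not an equivalent of the auto-chunker, which also has to chunk the backward pass and keep the chunks serial.)

```python
import torch

# Illustration of torch.vmap's chunk_size argument; names and shapes are made up.
def per_example_logits(x_row, w):
    return x_row @ w  # (V,) logits for a single example

x = torch.randn(8, 16)
w = torch.randn(16, 32)

# vmap processes the batch 2 examples at a time internally.
logits = torch.vmap(per_example_logits, in_dims=(0, None), chunk_size=2)(x, w)
print(logits.shape)  # torch.Size([8, 32])
```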
For linear + cross-entropy, the fx graph before chunking is https://gist.github.com/shunting314/231e4e5c15923f3d16e069768628daaf and the fx graph after chunking is https://gist.github.com/shunting314/2583db2ea0ea36bdad4ed1a5732bbc49
Thanks for mentioning the similarity with vmap. Yeah, maybe some functionality can be borrowed/shared; I'm not sure. But for the auto-chunker, the different chunks have to be handled serially so that we can reuse intermediate buffers. hop.scan was not available when I started prototyping the auto-chunker.
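To make the serial-chunking point concrete, here is a hand-written eager-mode sketch of chunked linear + cross-entropy (shapes and names are illustrative; this is not the code the pass generates). Each chunk finishes its forward and backward before the next chunk starts, so only one chunk's (chunk, vocab) logits buffer needs to stay alive for backward at any time, and gradients accumulate across chunks:

```python
import torch
import torch.nn.functional as F

# Hand-written illustration of serial chunking for linear + cross-entropy;
# not the graph this PR generates. Shapes are made up for the sketch.
B, H, V, num_chunks = 1024, 768, 50257, 8
x = torch.randn(B, H, requires_grad=True)
w = torch.randn(V, H, requires_grad=True)
targets = torch.randint(0, V, (B,))

total_loss = torch.zeros(())
for x_c, t_c in zip(x.split(B // num_chunks), targets.split(B // num_chunks)):
    logits_c = F.linear(x_c, w)  # (B // num_chunks, V): only one chunk of logits at a time
    loss_c = F.cross_entropy(logits_c, t_c, reduction="sum") / B
    loss_c.backward()            # backward for this chunk runs before the next chunk's forward;
                                 # grads accumulate into x.grad / w.grad
    total_loss += loss_c.detach()
```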
Amazing! Would it be possible to run this against a bunch of training code (e.g. torchtitan w/ num_gpus=1 & torchbench) to see if anything is missing / get some more memory savings data?
Sure, those are actually the tests I plan to do and share results for. torchtitan/torchtune are both good choices; I'll share more results. Our dashboard (HF/torchbench/TIMM) may be a stretch for testing the AutoChunker, since for such small workloads the AutoChunker will most likely be skipped.
Stack from ghstack (oldest at bottom):
The AutoChunker defines the following chunking metadata and propagates it through the subgraphs that we chunk (see the sketch after this list for an illustrative shape of these records):
1. scale_by: The AutoChunker is only enabled if there is a single scalar tangent. To decouple the bwd subgraph's dependency on the tangent so we can compute it in the fwd, we pretend the tangent is 1 and record the real tangent in scale_by. This metadata gets propagated, and when we cancel the chunking effect at the end of the bwd subgraph, we apply the scaling.
2. chunk_dim: records which dimension of the tensor gets chunked.
3. need_sum: if true, the original tensor is the sum (rather than the concat) of the chunked tensors.
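As a rough illustration of these per-tensor records (field names mirror the description above; the PR's actual internal representation may differ):

```python
from dataclasses import dataclass
from typing import List, Optional

import torch

# Hypothetical shape of the per-tensor chunking metadata described above.
@dataclass
class ChunkingMeta:
    scale_by: Optional[torch.Tensor] = None  # deferred scalar tangent, applied at the end of bwd
    chunk_dim: Optional[int] = None          # dimension along which this tensor is chunked
    need_sum: bool = False                   # True: full tensor is the sum of per-chunk results
                                             # False: full tensor is the concat along chunk_dim

def combine_chunks(chunks: List[torch.Tensor], meta: ChunkingMeta) -> torch.Tensor:
    """Reassemble the unchunked tensor from per-chunk results according to the metadata."""
    out = torch.stack(chunks).sum(0) if meta.need_sum else torch.cat(chunks, dim=meta.chunk_dim)
    if meta.scale_by is not None:
        out = out * meta.scale_by            # cancel the "pretend the tangent is 1" trick
    return out
```

For example, a weight gradient produced in the backward is a typical need_sum tensor (each chunk contributes a partial sum), while an activation chunked along the batch dimension is concatenated back along chunk_dim.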
One important implementation detail: we need to put each chunked subgraph in a HOP (invoke_subgraph here); otherwise Inductor fuses across these subgraphs, which results in no peak-memory saving.
This is still a prototype, since I assume all chunked inputs of the chunking subgraph are chunked along the same dimension, which may not always be true. As discussed offline with Jason and Horace, a principled way to resolve this is to propagate the chunking metadata in the backward direction.
Here are some early benchmarking results on GPT2:
- 64 chunks: final 19 iters avg 242.550ms; peak memory consumption 12603 MiB
- 32 chunks: final 19 iters avg 206.180ms; peak memory consumption 12880 MiB
- 16 chunks: final 19 iters avg 196.997ms; peak memory consumption 13267 MiB
- 8 chunks: final 19 iters avg 194.924ms; peak memory consumption 14049 MiB
With 64 chunks, our peak memory is smaller than llm.c's 13.4GB.
I also tried the AutoChunker on the PT2 OSS benchmarks to verify numerics. By default our accuracy tests pick a very small batch size, which makes the AutoChunker get skipped, so I forced batch_size to 16 for BertForMaskedLM to trigger the AutoChunker and verified numerical correctness.
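For reference, the kind of equivalence being checked can be reproduced standalone with a hand-written chunked loss (a toy check under made-up shapes, not the benchmark harness or BertForMaskedLM): the chunked loss and gradients should match the unchunked reference to floating-point tolerance.

```python
import torch
import torch.nn.functional as F

# Toy chunked-vs-unchunked equivalence check with illustrative shapes.
torch.manual_seed(0)
x = torch.randn(64, 32, requires_grad=True)
w = torch.randn(10, 32, requires_grad=True)
t = torch.randint(0, 10, (64,))

ref_loss = F.cross_entropy(F.linear(x, w), t)
ref_loss.backward()
ref_wgrad = w.grad.clone()
x.grad, w.grad = None, None

chunk_loss = torch.zeros(())
for x_c, t_c in zip(x.split(16), t.split(16)):
    loss_c = F.cross_entropy(F.linear(x_c, w), t_c, reduction="sum") / x.shape[0]
    loss_c.backward()
    chunk_loss += loss_c.detach()

torch.testing.assert_close(chunk_loss, ref_loss.detach(), rtol=1e-4, atol=1e-5)
torch.testing.assert_close(w.grad, ref_wgrad, rtol=1e-4, atol=1e-5)
```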
FYI @jansel @Chillee @eellison. Will send an update when this is fully ready for review, after I resolve the hacky things mentioned above.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @ColinPeppler @desertfire @ngimel