
Conversation

bdhirsh (Contributor) commented Aug 1, 2022

When functionalization is turned on in AOT Autograd, we want to hide input mutations in the graph so that the backend compiler doesn't need to worry about seeing copy_() ops in the graph. This PR does that by hiding them in an opaque submodule.

Right now this logic happens after the partitioning, and we're relying on partitioning to always leave the copy_() nodes in the forward graph (which... probably needs some more testing, but I think is fine?).

I added light testing for this pass by including it in the existing test_functionalization.py tests, but I'm planning to try hooking this into the torchbench suite, which will let us get rid of this code: https://github.com/pytorch/torchdynamo/blob/5040d49795dde35f0112e27a6744015d44318deb/torchdynamo/optimizations/training.py#L59
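
Roughly, the before/after shape of the idea (a simplified, hand-written stand-in for the traced graphs, not the actual output of this pass):

import torch

# "before": the traced forward still contains the input mutation, so the backend
# compiler has to reason about a copy_() op appearing in the graph.
def forward_with_mutation(primals_1):
    mul = primals_1 * 2
    primals_1.copy_(mul)              # input mutation, visible to the backend
    return (mul,)

# "after": the graph handed to the backend is purely functional; the copy_()
# is hidden away (here, in an opaque submodule) and applied separately.
def forward_functional(primals_1):
    mul = primals_1 * 2
    return (mul,)

def hidden_input_mutations(primals_1, mul):
    primals_1.copy_(mul)

x = torch.ones(3)
(out,) = forward_functional(x)
hidden_input_mutations(x, out)        # same end state as forward_with_mutation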

Stack from ghstack:

facebook-github-bot commented Aug 1, 2022

🔗 Helpful links

❌ 14 New Failures

As of commit d592f5d (more details on the Dr. CI page):

  • 14/14 failures introduced in this PR

🕵️ 14 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / linux-bionic-py3.7-clang9 / test (dynamo, 1, 2, linux.2xlarge) (1/14)

Step: "Test" (full log | diagnosis details)

2022-08-11T18:15:16.9677251Z FAIL [0.035s]: tes..._surface (__main__.TestFXAPIBackwardCompatibility)
2022-08-11T18:15:11.8689273Z   test_type_check_reshape_false (fx.test_gradual_type.TypeCheckerTest) ... ok (0.005s)
2022-08-11T18:15:11.8751912Z   test_type_check_reshape_true (fx.test_gradual_type.TypeCheckerTest) ... ok (0.006s)
2022-08-11T18:15:11.8984559Z   test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten (fx.test_gradual_type.TypeCheckerTest) ... ok (0.023s)
2022-08-11T18:15:11.9042334Z   test_type_check_transpose_False (fx.test_gradual_type.TypeCheckerTest) ... ok (0.006s)
2022-08-11T18:15:11.9106902Z   test_type_check_transpose_true (fx.test_gradual_type.TypeCheckerTest) ... ok (0.006s)
2022-08-11T18:15:16.9004843Z   test_type_maxpool2d_fully_static (fx.test_gradual_type.TypeCheckerTest) ... ok (4.990s)
2022-08-11T18:15:16.9102909Z   test_type_typechecl_maxpool2d_3dinput (fx.test_gradual_type.TypeCheckerTest) ... ok (0.010s)
2022-08-11T18:15:16.9674620Z   test_typecheck_basicblock (fx.test_gradual_type.TypeCheckerTest) ... ok (0.057s)
2022-08-11T18:15:16.9676239Z 
2022-08-11T18:15:16.9676730Z ======================================================================
2022-08-11T18:15:16.9677251Z FAIL [0.035s]: test_public_api_surface (__main__.TestFXAPIBackwardCompatibility)
2022-08-11T18:15:16.9678240Z ----------------------------------------------------------------------
2022-08-11T18:15:16.9678764Z Traceback (most recent call last):
2022-08-11T18:15:16.9679287Z   File "test_fx.py", line 3782, in test_public_api_surface
2022-08-11T18:15:16.9679686Z     def test_public_api_surface(self):
2022-08-11T18:15:16.9680090Z   File "test_fx.py", line 3782, in test_public_api_surface
2022-08-11T18:15:16.9680490Z     def test_public_api_surface(self):
2022-08-11T18:15:16.9680873Z   File "test_fx.py", line 3782, in test_public_api_surface
2022-08-11T18:15:16.9681255Z     def test_public_api_surface(self):
2022-08-11T18:15:16.9681656Z   File "test_fx.py", line 3814, in test_public_api_surface
2022-08-11T18:15:16.9682151Z     raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a "

See GitHub Actions build pull / win-vs2019-cpu-py3 / test (default, 1, 2, windows.4xlarge) (2/14)

Step: "Test" (full log | diagnosis details)

2022-08-11T17:26:28.7042992Z FAIL [0.000s]: tes..._surface (__main__.TestFXAPIBackwardCompatibility)
2022-08-11T17:26:28.6519512Z   test_type_check_reshape_false (fx.test_gradual_type.TypeCheckerTest) ... ok (0.000s)
2022-08-11T17:26:28.6556238Z   test_type_check_reshape_true (fx.test_gradual_type.TypeCheckerTest) ... ok (0.016s)
2022-08-11T17:26:28.6578868Z   test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten (fx.test_gradual_type.TypeCheckerTest) ... skip: no sympy (0.000s)
2022-08-11T17:26:28.6609690Z   test_type_check_transpose_False (fx.test_gradual_type.TypeCheckerTest) ... ok (0.000s)
2022-08-11T17:26:28.6645070Z   test_type_check_transpose_true (fx.test_gradual_type.TypeCheckerTest) ... ok (0.000s)
2022-08-11T17:26:28.6910857Z   test_type_maxpool2d_fully_static (fx.test_gradual_type.TypeCheckerTest) ... ok (0.034s)
2022-08-11T17:26:28.6954309Z   test_type_typechecl_maxpool2d_3dinput (fx.test_gradual_type.TypeCheckerTest) ... ok (0.000s)
2022-08-11T17:26:28.7041660Z   test_typecheck_basicblock (fx.test_gradual_type.TypeCheckerTest) ... ok (0.016s)
2022-08-11T17:26:28.7042183Z 
2022-08-11T17:26:28.7042388Z ======================================================================
2022-08-11T17:26:28.7042992Z FAIL [0.000s]: test_public_api_surface (__main__.TestFXAPIBackwardCompatibility)
2022-08-11T17:26:28.7043701Z ----------------------------------------------------------------------
2022-08-11T17:26:28.7044230Z Traceback (most recent call last):
2022-08-11T17:26:28.7044860Z   File "test_fx.py", line 3814, in test_public_api_surface
2022-08-11T17:26:28.7045657Z     raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a "
2022-08-11T17:26:28.7046726Z AssertionError: Public FX API(s) ['torch.fx.passes.reinplace.reinplace'] introduced but not given a backwards-compatibility classification! Please decorate these API(s) with `@torch.fx._compatibility.compatibility` to specify BC guarantees.
2022-08-11T17:26:28.7047376Z 
2022-08-11T17:26:28.8301456Z ----------------------------------------------------------------------
2022-08-11T17:26:28.8301800Z Ran 1002 tests in 8.219s
2022-08-11T17:26:28.8301920Z 
2022-08-11T17:26:28.8302042Z FAILED (failures=1, skipped=208, expected failures=5)

See GitHub Actions build pull / linux-bionic-cuda11.6-py3.10-gcc7 / test (default, 3, 4, linux.4xlarge.nvidia.gpu) (3/14)

Step: "Test" (full log | diagnosis details)

2022-08-11T16:26:18.2687847Z FAIL [0.003s]: tes..._surface (__main__.TestFXAPIBackwardCompatibility)
2022-08-11T16:26:18.1996159Z   test_type_check_reshape_false (fx.test_gradual_type.TypeCheckerTest) ... ok (0.003s)
2022-08-11T16:26:18.2028634Z   test_type_check_reshape_true (fx.test_gradual_type.TypeCheckerTest) ... ok (0.003s)
2022-08-11T16:26:18.2213449Z   test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten (fx.test_gradual_type.TypeCheckerTest) ... ok (0.018s)
2022-08-11T16:26:18.2246146Z   test_type_check_transpose_False (fx.test_gradual_type.TypeCheckerTest) ... ok (0.003s)
2022-08-11T16:26:18.2279090Z   test_type_check_transpose_true (fx.test_gradual_type.TypeCheckerTest) ... ok (0.003s)
2022-08-11T16:26:18.2562426Z   test_type_maxpool2d_fully_static (fx.test_gradual_type.TypeCheckerTest) ... ok (0.028s)
2022-08-11T16:26:18.2602878Z   test_type_typechecl_maxpool2d_3dinput (fx.test_gradual_type.TypeCheckerTest) ... ok (0.004s)
2022-08-11T16:26:18.2686271Z   test_typecheck_basicblock (fx.test_gradual_type.TypeCheckerTest) ... ok (0.008s)
2022-08-11T16:26:18.2686812Z 
2022-08-11T16:26:18.2687054Z ======================================================================
2022-08-11T16:26:18.2687847Z FAIL [0.003s]: test_public_api_surface (__main__.TestFXAPIBackwardCompatibility)
2022-08-11T16:26:18.2688855Z ----------------------------------------------------------------------
2022-08-11T16:26:18.2689226Z Traceback (most recent call last):
2022-08-11T16:26:18.2689619Z   File "/var/lib/jenkins/workspace/test/test_fx.py", line 3814, in test_public_api_surface
2022-08-11T16:26:18.2690329Z     raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a "
2022-08-11T16:26:18.2691272Z AssertionError: Public FX API(s) ['torch.fx.passes.reinplace.reinplace'] introduced but not given a backwards-compatibility classification! Please decorate these API(s) with `@torch.fx._compatibility.compatibility` to specify BC guarantees.
2022-08-11T16:26:18.2697323Z 
2022-08-11T16:26:18.2698719Z ----------------------------------------------------------------------
2022-08-11T16:26:18.2699360Z Ran 1009 tests in 4.163s
2022-08-11T16:26:18.2699542Z 
2022-08-11T16:26:18.2699743Z FAILED (failures=1, skipped=598, expected failures=3)

See GitHub Actions build pull / linux-bionic-py3.7-clang9 / test (functorch, 1, 1, linux.2xlarge) (4/14)

Step: "Test" (full log | diagnosis details)

2022-08-11T16:05:19.0650494Z ModuleNotFoundError: No module named 'functorch'
2022-08-11T16:05:18.1408444Z prioritized: []
2022-08-11T16:05:18.1409578Z the rest: ['/var/lib/jenkins/workspace/functorch/test/test_compile_cache', '/var/lib/jenkins/workspace/functorch/test/test_dims', '/var/lib/jenkins/workspace/functorch/test/test_eager_transforms', '/var/lib/jenkins/workspace/functorch/test/test_functionalize', '/var/lib/jenkins/workspace/functorch/test/test_memory_efficient_fusion', '/var/lib/jenkins/workspace/functorch/test/test_minifier', '/var/lib/jenkins/workspace/functorch/test/test_ops', '/var/lib/jenkins/workspace/functorch/test/test_pythonkey', '/var/lib/jenkins/workspace/functorch/test/test_vmap']
2022-08-11T16:05:18.1410424Z 
2022-08-11T16:05:18.1410832Z Downloading https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/slow-tests.json to /var/lib/jenkins/workspace/test/.pytorch-slow-tests.json
2022-08-11T16:05:18.1588225Z Downloading https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/disabled-tests.json to /var/lib/jenkins/workspace/test/.pytorch-disabled-tests.json
2022-08-11T16:05:18.1889801Z Running /var/lib/jenkins/workspace/functorch/test/test_compile_cache ... [2022-08-11 16:05:18.188532]
2022-08-11T16:05:18.1890516Z Executing ['/opt/conda/bin/python', '-bb', '/var/lib/jenkins/workspace/functorch/test/test_compile_cache.py', '-v', '--import-slow-tests', '--import-disabled-tests'] ... [2022-08-11 16:05:18.188590]
2022-08-11T16:05:19.0648744Z Traceback (most recent call last):
2022-08-11T16:05:19.0649382Z   File "/var/lib/jenkins/workspace/functorch/test/test_compile_cache.py", line 5, in <module>
2022-08-11T16:05:19.0649883Z     import functorch
2022-08-11T16:05:19.0650494Z ModuleNotFoundError: No module named 'functorch'
2022-08-11T16:05:19.2230455Z Traceback (most recent call last):
2022-08-11T16:05:19.2230961Z   File "test/run_test.py", line 974, in <module>
2022-08-11T16:05:19.2232461Z     main()
2022-08-11T16:05:19.2232679Z   File "test/run_test.py", line 952, in main
2022-08-11T16:05:19.2235121Z     raise RuntimeError(err_message)
2022-08-11T16:05:19.2235451Z RuntimeError: /var/lib/jenkins/workspace/functorch/test/test_compile_cache failed!
2022-08-11T16:05:19.5049790Z ##[error]Process completed with exit code 1.
2022-08-11T16:05:19.5084132Z Prepare all required actions
2022-08-11T16:05:19.5084440Z Getting action download info
2022-08-11T16:05:19.6623880Z ##[group]Run ./.github/actions/get-workflow-job-id

See GitHub Actions build pull / linux-focal-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge) (5/14)

Step: "Test" (full log | diagnosis details)

2022-08-11T16:10:35.0440172Z FAIL [0.002s]: tes..._surface (__main__.TestFXAPIBackwardCompatibility)
2022-08-11T16:10:34.9771850Z   test_type_check_reshape_false (fx.test_gradual_type.TypeCheckerTest) ... ok (0.003s)
2022-08-11T16:10:34.9803935Z   test_type_check_reshape_true (fx.test_gradual_type.TypeCheckerTest) ... ok (0.003s)
2022-08-11T16:10:34.9974235Z   test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten (fx.test_gradual_type.TypeCheckerTest) ... ok (0.017s)
2022-08-11T16:10:35.0012552Z   test_type_check_transpose_False (fx.test_gradual_type.TypeCheckerTest) ... ok (0.004s)
2022-08-11T16:10:35.0047206Z   test_type_check_transpose_true (fx.test_gradual_type.TypeCheckerTest) ... ok (0.003s)
2022-08-11T16:10:35.0311763Z   test_type_maxpool2d_fully_static (fx.test_gradual_type.TypeCheckerTest) ... ok (0.026s)
2022-08-11T16:10:35.0359229Z   test_type_typechecl_maxpool2d_3dinput (fx.test_gradual_type.TypeCheckerTest) ... ok (0.005s)
2022-08-11T16:10:35.0437423Z   test_typecheck_basicblock (fx.test_gradual_type.TypeCheckerTest) ... ok (0.008s)
2022-08-11T16:10:35.0439585Z 
2022-08-11T16:10:35.0439784Z ======================================================================
2022-08-11T16:10:35.0440172Z FAIL [0.002s]: test_public_api_surface (__main__.TestFXAPIBackwardCompatibility)
2022-08-11T16:10:35.0441050Z ----------------------------------------------------------------------
2022-08-11T16:10:35.0441368Z Traceback (most recent call last):
2022-08-11T16:10:35.0441693Z   File "test_fx.py", line 3814, in test_public_api_surface
2022-08-11T16:10:35.0442087Z     raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a "
2022-08-11T16:10:35.0443041Z AssertionError: Public FX API(s) ['torch.fx.passes.reinplace.reinplace'] introduced but not given a backwards-compatibility classification! Please decorate these API(s) with `@torch.fx._compatibility.compatibility` to specify BC guarantees.
2022-08-11T16:10:35.0443533Z 
2022-08-11T16:10:35.0443796Z ----------------------------------------------------------------------
2022-08-11T16:10:35.0444098Z Ran 1100 tests in 210.877s
2022-08-11T16:10:35.0444244Z 
2022-08-11T16:10:35.0444466Z FAILED (failures=1, skipped=194, expected failures=5)

See GitHub Actions build pull / linux-focal-py3.7-gcc7 / test (default, 2, 2, linux.2xlarge) (6/14)

Step: "Test" (full log | diagnosis details)

2022-08-11T16:14:16.4381420Z RuntimeError: test_functionalization failed!
2022-08-11T16:14:15.1476461Z Executing ['/opt/conda/bin/python', '-bb', 'test_functionalization.py', '-v', '--import-slow-tests', '--import-disabled-tests'] ... [2022-08-11 16:14:15.147266]
2022-08-11T16:14:16.2395332Z Traceback (most recent call last):
2022-08-11T16:14:16.2395801Z   File "test_functionalization.py", line 8, in <module>
2022-08-11T16:14:16.2396118Z     from functorch.experimental import functionalize
2022-08-11T16:14:16.2397148Z ModuleNotFoundError: No module named 'functorch'
2022-08-11T16:14:16.4377117Z Traceback (most recent call last):
2022-08-11T16:14:16.4377597Z   File "test/run_test.py", line 974, in <module>
2022-08-11T16:14:16.4379081Z     main()
2022-08-11T16:14:16.4379447Z   File "test/run_test.py", line 952, in main
2022-08-11T16:14:16.4381021Z     raise RuntimeError(err_message)
2022-08-11T16:14:16.4381420Z RuntimeError: test_functionalization failed!
2022-08-11T16:14:16.7508011Z 
2022-08-11T16:14:16.7508356Z real	7m22.494s
2022-08-11T16:14:16.7508710Z user	8m47.952s
2022-08-11T16:14:16.7509016Z sys	0m4.428s
2022-08-11T16:14:16.7539056Z ##[error]Process completed with exit code 1.
2022-08-11T16:14:16.7573677Z Prepare all required actions
2022-08-11T16:14:16.7573968Z Getting action download info
2022-08-11T16:14:16.9469757Z ##[group]Run ./.github/actions/get-workflow-job-id
2022-08-11T16:14:16.9469967Z with:
2022-08-11T16:14:16.9470302Z   github-token: ***

See GitHub Actions build pull / linux-bionic-py3.7-clang9 / test (crossref, 2, 2, linux.2xlarge) (7/14)

Step: "Test" (full log | diagnosis details)

2022-08-11T16:14:14.7157935Z RuntimeError: test_functionalization failed!
2022-08-11T16:14:13.4683786Z Executing ['/opt/conda/bin/python', '-bb', 'test_functionalization.py', '-v', '--import-slow-tests', '--import-disabled-tests'] ... [2022-08-11 16:14:13.467976]
2022-08-11T16:14:14.5228363Z Traceback (most recent call last):
2022-08-11T16:14:14.5228732Z   File "test_functionalization.py", line 8, in <module>
2022-08-11T16:14:14.5229013Z     from functorch.experimental import functionalize
2022-08-11T16:14:14.5229468Z ModuleNotFoundError: No module named 'functorch'
2022-08-11T16:14:14.7153163Z Traceback (most recent call last):
2022-08-11T16:14:14.7153491Z   File "test/run_test.py", line 974, in <module>
2022-08-11T16:14:14.7155841Z     main()
2022-08-11T16:14:14.7156031Z   File "test/run_test.py", line 952, in main
2022-08-11T16:14:14.7157696Z     raise RuntimeError(err_message)
2022-08-11T16:14:14.7157935Z RuntimeError: test_functionalization failed!
2022-08-11T16:14:15.0120680Z 
2022-08-11T16:14:15.0121116Z real	7m45.428s
2022-08-11T16:14:15.0121495Z user	11m12.250s
2022-08-11T16:14:15.0121696Z sys	0m12.321s
2022-08-11T16:14:15.0151408Z ##[error]Process completed with exit code 1.
2022-08-11T16:14:15.0187156Z Prepare all required actions
2022-08-11T16:14:15.0187459Z Getting action download info
2022-08-11T16:14:15.2119133Z ##[group]Run ./.github/actions/get-workflow-job-id
2022-08-11T16:14:15.2119359Z with:
2022-08-11T16:14:15.2119698Z   github-token: ***

See GitHub Actions build pull / linux-focal-py3.7-gcc7 / test (functorch, 1, 1, linux.2xlarge) (8/14)

Step: "Test" (full log | diagnosis details)

2022-08-11T16:05:42.0550027Z ModuleNotFoundError: No module named 'functorch'
2022-08-11T16:05:41.0606545Z prioritized: []
2022-08-11T16:05:41.0607807Z the rest: ['/var/lib/jenkins/workspace/functorch/test/test_compile_cache', '/var/lib/jenkins/workspace/functorch/test/test_dims', '/var/lib/jenkins/workspace/functorch/test/test_eager_transforms', '/var/lib/jenkins/workspace/functorch/test/test_functionalize', '/var/lib/jenkins/workspace/functorch/test/test_memory_efficient_fusion', '/var/lib/jenkins/workspace/functorch/test/test_minifier', '/var/lib/jenkins/workspace/functorch/test/test_ops', '/var/lib/jenkins/workspace/functorch/test/test_pythonkey', '/var/lib/jenkins/workspace/functorch/test/test_vmap']
2022-08-11T16:05:41.0608446Z 
2022-08-11T16:05:41.0608876Z Downloading https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/slow-tests.json to /var/lib/jenkins/workspace/test/.pytorch-slow-tests.json
2022-08-11T16:05:41.0960487Z Downloading https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/disabled-tests.json to /var/lib/jenkins/workspace/test/.pytorch-disabled-tests.json
2022-08-11T16:05:41.1344658Z Running /var/lib/jenkins/workspace/functorch/test/test_compile_cache ... [2022-08-11 16:05:41.134122]
2022-08-11T16:05:41.1345367Z Executing ['/opt/conda/bin/python', '-bb', '/var/lib/jenkins/workspace/functorch/test/test_compile_cache.py', '-v', '--import-slow-tests', '--import-disabled-tests'] ... [2022-08-11 16:05:41.134175]
2022-08-11T16:05:42.0549009Z Traceback (most recent call last):
2022-08-11T16:05:42.0549354Z   File "/var/lib/jenkins/workspace/functorch/test/test_compile_cache.py", line 5, in <module>
2022-08-11T16:05:42.0549636Z     import functorch
2022-08-11T16:05:42.0550027Z ModuleNotFoundError: No module named 'functorch'
2022-08-11T16:05:42.2204796Z Traceback (most recent call last):
2022-08-11T16:05:42.2205091Z   File "test/run_test.py", line 974, in <module>
2022-08-11T16:05:42.2207262Z     main()
2022-08-11T16:05:42.2207500Z   File "test/run_test.py", line 952, in main
2022-08-11T16:05:42.2209557Z     raise RuntimeError(err_message)
2022-08-11T16:05:42.2209931Z RuntimeError: /var/lib/jenkins/workspace/functorch/test/test_compile_cache failed!
2022-08-11T16:05:42.5170764Z ##[error]Process completed with exit code 1.
2022-08-11T16:05:42.5208710Z Prepare all required actions
2022-08-11T16:05:42.5209029Z Getting action download info
2022-08-11T16:05:42.7160700Z ##[group]Run ./.github/actions/get-workflow-job-id

See GitHub Actions build pull / linux-focal-py3.7-clang7-asan / test (default, 3, 5, linux.2xlarge) (9/14)

Step: "Test" (full log | diagnosis details)

2022-08-11T16:05:12.7214661Z RuntimeError: test_functionalization failed!
2022-08-11T16:05:09.2802242Z Executing ['/opt/conda/bin/python', '-bb', 'test_functionalization.py', '-v', '--import-slow-tests', '--import-disabled-tests'] ... [2022-08-11 16:05:09.279760]
2022-08-11T16:05:12.2589962Z Traceback (most recent call last):
2022-08-11T16:05:12.2590421Z   File "test_functionalization.py", line 8, in <module>
2022-08-11T16:05:12.2590690Z     from functorch.experimental import functionalize
2022-08-11T16:05:12.2591110Z ModuleNotFoundError: No module named 'functorch'
2022-08-11T16:05:12.7206685Z Traceback (most recent call last):
2022-08-11T16:05:12.7206974Z   File "test/run_test.py", line 974, in <module>
2022-08-11T16:05:12.7210539Z     main()
2022-08-11T16:05:12.7210796Z   File "test/run_test.py", line 952, in main
2022-08-11T16:05:12.7214352Z     raise RuntimeError(err_message)
2022-08-11T16:05:12.7214661Z RuntimeError: test_functionalization failed!
2022-08-11T16:05:13.3085465Z 
2022-08-11T16:05:13.3085874Z real	0m28.952s
2022-08-11T16:05:13.3086115Z user	0m26.979s
2022-08-11T16:05:13.3086293Z sys	0m5.238s
2022-08-11T16:05:13.3112620Z ##[error]Process completed with exit code 1.
2022-08-11T16:05:13.3147554Z Prepare all required actions
2022-08-11T16:05:13.3147882Z Getting action download info
2022-08-11T16:05:13.5140861Z ##[group]Run ./.github/actions/get-workflow-job-id
2022-08-11T16:05:13.5141083Z with:
2022-08-11T16:05:13.5141409Z   github-token: ***

See GitHub Actions build pull / linux-bionic-py3.7-clang9 / test (default, 1, 2, linux.2xlarge) (10/14)

Step: "Test" (full log | diagnosis details)

2022-08-11T16:10:04.6376457Z FAIL [0.002s]: tes..._surface (__main__.TestFXAPIBackwardCompatibility)
2022-08-11T16:10:04.5757277Z   test_type_check_reshape_false (fx.test_gradual_type.TypeCheckerTest) ... ok (0.003s)
2022-08-11T16:10:04.5788086Z   test_type_check_reshape_true (fx.test_gradual_type.TypeCheckerTest) ... ok (0.003s)
2022-08-11T16:10:04.5951905Z   test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten (fx.test_gradual_type.TypeCheckerTest) ... ok (0.016s)
2022-08-11T16:10:04.5981231Z   test_type_check_transpose_False (fx.test_gradual_type.TypeCheckerTest) ... ok (0.003s)
2022-08-11T16:10:04.6012666Z   test_type_check_transpose_true (fx.test_gradual_type.TypeCheckerTest) ... ok (0.003s)
2022-08-11T16:10:04.6261372Z   test_type_maxpool2d_fully_static (fx.test_gradual_type.TypeCheckerTest) ... ok (0.025s)
2022-08-11T16:10:04.6304056Z   test_type_typechecl_maxpool2d_3dinput (fx.test_gradual_type.TypeCheckerTest) ... ok (0.004s)
2022-08-11T16:10:04.6374725Z   test_typecheck_basicblock (fx.test_gradual_type.TypeCheckerTest) ... ok (0.007s)
2022-08-11T16:10:04.6375390Z 
2022-08-11T16:10:04.6375846Z ======================================================================
2022-08-11T16:10:04.6376457Z FAIL [0.002s]: test_public_api_surface (__main__.TestFXAPIBackwardCompatibility)
2022-08-11T16:10:04.6377796Z ----------------------------------------------------------------------
2022-08-11T16:10:04.6379623Z Traceback (most recent call last):
2022-08-11T16:10:04.6380075Z   File "test_fx.py", line 3814, in test_public_api_surface
2022-08-11T16:10:04.6380600Z     raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a "
2022-08-11T16:10:04.6382085Z AssertionError: Public FX API(s) ['torch.fx.passes.reinplace.reinplace'] introduced but not given a backwards-compatibility classification! Please decorate these API(s) with `@torch.fx._compatibility.compatibility` to specify BC guarantees.
2022-08-11T16:10:04.6382801Z 
2022-08-11T16:10:04.6383146Z ----------------------------------------------------------------------
2022-08-11T16:10:04.6383543Z Ran 1100 tests in 196.020s
2022-08-11T16:10:04.6383742Z 
2022-08-11T16:10:04.6383935Z FAILED (failures=1, skipped=194, expected failures=5)

See GitHub Actions build pull / linux-bionic-py3.7-clang9 / test (crossref, 1, 2, linux.2xlarge) (11/14)

Step: "Test" (full log | diagnosis details)

2022-08-11T16:10:03.1952869Z FAIL [0.002s]: tes..._surface (__main__.TestFXAPIBackwardCompatibility)
2022-08-11T16:10:03.1311795Z   test_type_check_reshape_false (fx.test_gradual_type.TypeCheckerTest) ... ok (0.003s)
2022-08-11T16:10:03.1342862Z   test_type_check_reshape_true (fx.test_gradual_type.TypeCheckerTest) ... ok (0.003s)
2022-08-11T16:10:03.1511389Z   test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten (fx.test_gradual_type.TypeCheckerTest) ... ok (0.017s)
2022-08-11T16:10:03.1543639Z   test_type_check_transpose_False (fx.test_gradual_type.TypeCheckerTest) ... ok (0.003s)
2022-08-11T16:10:03.1577731Z   test_type_check_transpose_true (fx.test_gradual_type.TypeCheckerTest) ... ok (0.003s)
2022-08-11T16:10:03.1833388Z   test_type_maxpool2d_fully_static (fx.test_gradual_type.TypeCheckerTest) ... ok (0.025s)
2022-08-11T16:10:03.1871567Z   test_type_typechecl_maxpool2d_3dinput (fx.test_gradual_type.TypeCheckerTest) ... ok (0.004s)
2022-08-11T16:10:03.1949577Z   test_typecheck_basicblock (fx.test_gradual_type.TypeCheckerTest) ... ok (0.008s)
2022-08-11T16:10:03.1950087Z 
2022-08-11T16:10:03.1952340Z ======================================================================
2022-08-11T16:10:03.1952869Z FAIL [0.002s]: test_public_api_surface (__main__.TestFXAPIBackwardCompatibility)
2022-08-11T16:10:03.1953619Z ----------------------------------------------------------------------
2022-08-11T16:10:03.1954059Z Traceback (most recent call last):
2022-08-11T16:10:03.1954474Z   File "test_fx.py", line 3814, in test_public_api_surface
2022-08-11T16:10:03.1954986Z     raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a "
2022-08-11T16:10:03.1956426Z AssertionError: Public FX API(s) ['torch.fx.passes.reinplace.reinplace'] introduced but not given a backwards-compatibility classification! Please decorate these API(s) with `@torch.fx._compatibility.compatibility` to specify BC guarantees.
2022-08-11T16:10:03.1957090Z 
2022-08-11T16:10:03.1957431Z ----------------------------------------------------------------------
2022-08-11T16:10:03.1957826Z Ran 1100 tests in 200.292s
2022-08-11T16:10:03.1958017Z 
2022-08-11T16:10:03.1958208Z FAILED (failures=1, skipped=194, expected failures=5)

See GitHub Actions build pull / linux-bionic-py3.7-clang9 / test (dynamo, 2, 2, linux.2xlarge) (12/14)

Step: "Test" (full log | diagnosis details)

2022-08-11T16:06:40.3463354Z RuntimeError: test_functionalization failed!
2022-08-11T16:06:39.1496106Z Executing ['/opt/conda/bin/python', '-bb', 'test_functionalization.py', '-v', '--import-slow-tests', '--import-disabled-tests'] ... [2022-08-11 16:06:39.149231]
2022-08-11T16:06:40.1904168Z Traceback (most recent call last):
2022-08-11T16:06:40.1904654Z   File "test_functionalization.py", line 8, in <module>
2022-08-11T16:06:40.1904941Z     from functorch.experimental import functionalize
2022-08-11T16:06:40.1905339Z ModuleNotFoundError: No module named 'functorch'
2022-08-11T16:06:40.3458378Z Traceback (most recent call last):
2022-08-11T16:06:40.3458660Z   File "test/run_test.py", line 974, in <module>
2022-08-11T16:06:40.3460484Z     main()
2022-08-11T16:06:40.3460684Z   File "test/run_test.py", line 952, in main
2022-08-11T16:06:40.3463121Z     raise RuntimeError(err_message)
2022-08-11T16:06:40.3463354Z RuntimeError: test_functionalization failed!
2022-08-11T16:06:40.5913230Z 
2022-08-11T16:06:40.5913550Z real	0m17.392s
2022-08-11T16:06:40.5913806Z user	0m16.961s
2022-08-11T16:06:40.5913993Z sys	0m3.188s
2022-08-11T16:06:40.5943049Z ##[error]Process completed with exit code 1.
2022-08-11T16:06:40.5977739Z Prepare all required actions
2022-08-11T16:06:40.5978026Z Getting action download info
2022-08-11T16:06:40.7893433Z ##[group]Run ./.github/actions/get-workflow-job-id
2022-08-11T16:06:40.7893647Z with:
2022-08-11T16:06:40.7893971Z   github-token: ***

See GitHub Actions build pull / linux-focal-py3.7-clang7-asan / test (default, 4, 5, linux.2xlarge) (13/14)

Step: "Test" (full log | diagnosis details)

2022-08-11T16:05:12.9719245Z FAIL [0.003s]: tes..._surface (__main__.TestFXAPIBackwardCompatibility)
2022-08-11T16:05:12.8929672Z   test_type_check_reshape_false (fx.test_gradual_type.TypeCheckerTest) ... ok (0.003s)
2022-08-11T16:05:12.8967550Z   test_type_check_reshape_true (fx.test_gradual_type.TypeCheckerTest) ... ok (0.004s)
2022-08-11T16:05:12.9165247Z   test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten (fx.test_gradual_type.TypeCheckerTest) ... ok (0.020s)
2022-08-11T16:05:12.9206865Z   test_type_check_transpose_False (fx.test_gradual_type.TypeCheckerTest) ... ok (0.004s)
2022-08-11T16:05:12.9247151Z   test_type_check_transpose_true (fx.test_gradual_type.TypeCheckerTest) ... ok (0.004s)
2022-08-11T16:05:12.9565174Z   test_type_maxpool2d_fully_static (fx.test_gradual_type.TypeCheckerTest) ... ok (0.032s)
2022-08-11T16:05:12.9610121Z   test_type_typechecl_maxpool2d_3dinput (fx.test_gradual_type.TypeCheckerTest) ... ok (0.005s)
2022-08-11T16:05:12.9718039Z   test_typecheck_basicblock (fx.test_gradual_type.TypeCheckerTest) ... ok (0.011s)
2022-08-11T16:05:12.9718397Z 
2022-08-11T16:05:12.9718809Z ======================================================================
2022-08-11T16:05:12.9719245Z FAIL [0.003s]: test_public_api_surface (__main__.TestFXAPIBackwardCompatibility)
2022-08-11T16:05:12.9719999Z ----------------------------------------------------------------------
2022-08-11T16:05:12.9720425Z Traceback (most recent call last):
2022-08-11T16:05:12.9720813Z   File "test_fx.py", line 3814, in test_public_api_surface
2022-08-11T16:05:12.9721224Z     raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a "
2022-08-11T16:05:12.9722264Z AssertionError: Public FX API(s) ['torch.fx.passes.reinplace.reinplace'] introduced but not given a backwards-compatibility classification! Please decorate these API(s) with `@torch.fx._compatibility.compatibility` to specify BC guarantees.
2022-08-11T16:05:12.9722746Z 
2022-08-11T16:05:12.9722999Z ----------------------------------------------------------------------
2022-08-11T16:05:12.9723244Z Ran 1002 tests in 12.146s
2022-08-11T16:05:12.9723365Z 
2022-08-11T16:05:12.9723507Z FAILED (failures=1, skipped=198, expected failures=5)

See GitHub Actions build pull / linux-bionic-cuda11.6-py3.10-gcc7 / test (functorch, 1, 1, linux.4xlarge.nvidia.gpu) (14/14)

Step: "Test" (full log | diagnosis details)

2022-08-11T16:24:17.3806269Z ModuleNotFoundError: No module named 'functorch'
2022-08-11T16:24:15.9545715Z prioritized: []
2022-08-11T16:24:15.9547664Z the rest: ['/var/lib/jenkins/workspace/functorch/test/test_compile_cache', '/var/lib/jenkins/workspace/functorch/test/test_dims', '/var/lib/jenkins/workspace/functorch/test/test_eager_transforms', '/var/lib/jenkins/workspace/functorch/test/test_functionalize', '/var/lib/jenkins/workspace/functorch/test/test_memory_efficient_fusion', '/var/lib/jenkins/workspace/functorch/test/test_minifier', '/var/lib/jenkins/workspace/functorch/test/test_ops', '/var/lib/jenkins/workspace/functorch/test/test_pythonkey', '/var/lib/jenkins/workspace/functorch/test/test_vmap']
2022-08-11T16:24:15.9548589Z 
2022-08-11T16:24:15.9549135Z Downloading https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/slow-tests.json to /var/lib/jenkins/workspace/test/.pytorch-slow-tests.json
2022-08-11T16:24:15.9766456Z Downloading https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/disabled-tests.json to /var/lib/jenkins/workspace/test/.pytorch-disabled-tests.json
2022-08-11T16:24:16.0167509Z Running /var/lib/jenkins/workspace/functorch/test/test_compile_cache ... [2022-08-11 16:24:16.016288]
2022-08-11T16:24:16.0168725Z Executing ['/opt/conda/bin/python', '-bb', '/var/lib/jenkins/workspace/functorch/test/test_compile_cache.py', '-v', '--import-slow-tests', '--import-disabled-tests'] ... [2022-08-11 16:24:16.016351]
2022-08-11T16:24:17.3804351Z Traceback (most recent call last):
2022-08-11T16:24:17.3805341Z   File "/var/lib/jenkins/workspace/functorch/test/test_compile_cache.py", line 5, in <module>
2022-08-11T16:24:17.3805801Z     import functorch
2022-08-11T16:24:17.3806269Z ModuleNotFoundError: No module named 'functorch'
2022-08-11T16:24:17.5790252Z Traceback (most recent call last):
2022-08-11T16:24:17.5791285Z   File "/var/lib/jenkins/workspace/test/run_test.py", line 974, in <module>
2022-08-11T16:24:17.5792232Z     main()
2022-08-11T16:24:17.5792966Z   File "/var/lib/jenkins/workspace/test/run_test.py", line 952, in main
2022-08-11T16:24:17.5794117Z     raise RuntimeError(err_message)
2022-08-11T16:24:17.5794992Z RuntimeError: /var/lib/jenkins/workspace/functorch/test/test_compile_cache failed!
2022-08-11T16:24:17.8474900Z ##[error]Process completed with exit code 1.
2022-08-11T16:24:17.8512254Z Prepare all required actions
2022-08-11T16:24:17.8512629Z Getting action download info
2022-08-11T16:24:18.0205279Z ##[group]Run ./.github/actions/get-workflow-job-id

This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


bdhirsh added a commit that referenced this pull request Aug 1, 2022
ghstack-source-id: 3749d21
Pull Request resolved: #82602
bdhirsh added a commit that referenced this pull request Aug 1, 2022
ghstack-source-id: 1c27041
Pull Request resolved: #82602
@bdhirsh bdhirsh requested review from Chillee and ezyang August 1, 2022 21:53
ezyang (Contributor) commented Aug 2, 2022

It would be really nice to see a before/after print.

# these mutations into an opaque submodule
# so our graph infra can assume a functional graph.
if config.use_functionalize:
    move_input_mutations_into_submodule(fw_module)
Contributor

So if you do this after, partition_fn must be able to deal with mutations in fx_g. Can it, @Chillee?

Contributor Author

I re-remembered that it can't, haha, so I'll have to change this. The min-cut partitioning code calls eliminate_dead_code(), so we want the copy_() ops to be hidden away before then.
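
(For illustration, a minimal sketch of the DCE hazard using a hand-built FX graph rather than the real partitioner input:)

import torch.fx as fx

g = fx.Graph()
x = g.placeholder("x")
y = g.placeholder("y")
# copy_ mutates x in place, but its result is never read, so FX's DCE
# treats the node as dead even though it has a side effect.
g.call_method("copy_", (x, y))
g.output(y)

g.eliminate_dead_code()
print(g)  # the copy_ node is gone, and the input mutation along with it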

@ezyang ezyang requested a review from SherlockNoMad August 2, 2022 01:09
module_node = fx_g.graph.call_module(
    submodule_name,
    args=tuple(node_to_placeholder.keys()),
    kwargs=None)
Contributor

I know we debated possible representations a bit, but this particular rep is actually not one that I had been thinking about. Having the mutation live in a submodule in the original graph is not great, because it still means the outer graph is not functional! (If you call a mutating function inside a graph, that makes your graph mutating.) I feel like there's an obligation for this submodule call to (somehow) not live in the Graph itself, as otherwise it's vulnerable to DCE again.

Contributor Author

Hmmmm I see. The alternative is probably just to leave it out of the fx.Graph (but still keep it in the GraphModule), and then manually call the submodule later on in AOTAutograd, after the compiled forward gets executed? That doesn't seem too bad.

I guess my thought was that anything that operates on the fx.Graph (DCE + compilers) would just see a custom submodule and know to treat it as an opaque object. But that's probably not totally right - we can't really enforce that a graph pass won't move the submodule around earlier in the graph (which would be wrong).

bdhirsh (Contributor Author) commented Aug 2, 2022

Although, if I want to be able to invoke the submodule separately outside of the graph, I'll need to somehow get the inputs for the submodule - probably by updating the original graph to make them additional outputs. Extra complexity (since I'll also need to remove the pytree stuff in the graph to do that), but it doesn't seem too bad.

Contributor

You can always write passes so they operate correctly in the presence of mutation, but we don't want to force that on pass writers. Opaque submodules at final lowering are unlikely to be doing heavy optimization, so it's easier to deal with arbitrary side effects there. Even, e.g., finding fusion groups is currently not correct in the presence of mutation (though mostly this is because of DCE calls). The problem here is that you're adding this module in all the way at the beginning, before all of the optimization passes, so you're forcing them to, at the very least, know not to remove your module.

Extra outputs is what I would expect to see if it's external.

Contributor

You could also have it not actually be external, but stored on the GraphModule and not actually part of the graph. But passes would need to know to propagate it.
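
A minimal, self-contained sketch of that alternative (the module and attribute names here are hypothetical, not what this PR uses):

import torch
import torch.fx as fx

class InputMutationEpilogue(torch.nn.Module):
    # applies the hidden input mutation: copy the updated value back into the input
    def forward(self, inpt, updated_inpt):
        inpt.copy_(updated_inpt)

def fn(x):
    return x * 2

gm = fx.symbolic_trace(fn)                             # purely functional graph
gm.input_mutation_epilogue = InputMutationEpilogue()   # attribute only; gm.graph never calls it

def runtime_wrapper(x):
    out = gm(x)                           # run the functional graph
    gm.input_mutation_epilogue(x, out)    # apply the copy_() outside the graph
    return out

x = torch.ones(3)
out = runtime_wrapper(x)
print(out, x)  # x was mutated to match out, with no copy_() node in gm.graph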

Contributor Author

hmm yeah, having it actually be external feels better. I'll try that.

bdhirsh added a commit that referenced this pull request Aug 3, 2022
ghstack-source-id: cfa370a
Pull Request resolved: #82602
input_clone = inpt.clone()
input_clone2 = inpt.clone()
input_clone3 = inpt.clone()
input_clone4 = inpt.clone()
Contributor Author

The changes in this file are just minor QoL changes

bdhirsh (Contributor Author) commented Aug 3, 2022

Ok, I beefed up this PR (so mutations get taken out before partitioning) and added better testing. @Chillee lmk what you think of the AOTAutograd changes.

I also dumped my tests in test_compile_cache.py since it was the only file I found that directly tests aot_function(), but I can move them somewhere else if preferred.

I think what we probably want is that once this lands, we should turn this + re-inplacing on by default in the benchmark testing on torchbench + timm + hugging face (cc @anijain2305). I'm still working on getting that set up so I can actually run it locally and confirm that nothing breaks.

bdhirsh (Contributor Author) commented Aug 3, 2022

It looks like this pass isn't playing well with the partitioning code - it fails python test/test_pythonkey.py TestAOTAutograd.test_batchnorm when functionalization is toggled on. I think the partitioning code needs some beefing up but I'm taking a closer look.

bdhirsh added a commit that referenced this pull request Aug 11, 2022
ghstack-source-id: 7abca74
Pull Request resolved: #82602
pytorch-bot bot commented Sep 26, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/82602

Note: Links to docs will display an error until the docs builds have been completed.

❌ 7 Failures, 2 Pending

As of commit 83bdef4:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

bdhirsh added a commit that referenced this pull request Sep 26, 2022
ghstack-source-id: a72bc8d
Pull Request resolved: #82602
bdhirsh (Contributor Author) commented Sep 26, 2022

I'm pushing some more on this based on #85036, since adding an epilogue to AOTAutograd should unblock a few models that were previously hitting dynamo's fallback. This isn't ready for review yet though - waiting to sanity-check that some tests pass first.

@bdhirsh bdhirsh mentioned this pull request Sep 29, 2022
bdhirsh added a commit that referenced this pull request Sep 30, 2022
ghstack-source-id: 1df3abb
Pull Request resolved: #82602
bdhirsh (Contributor Author) commented Sep 30, 2022

@Chillee @ezyang this should be ready for another round. Some high level questions/notes:

(a) What do you all think of the partitioner calling convention changes?

(b) Should functionalization (and the same epilogue infra) be refactored to also run in the aot_dispatch_base() case? It looks like... functionalization doesn't even run there today (I guess I'm surprised that it isn't breaking anything?)

(c) There are a bunch of test failures that are due to fake tensors not being turned on (we need to know requires_grad-ness of the outputs to know what to mark properly in the autograd.function, but that information is wrong if you're using make_fx() with normal tracing, since we store a TensorMeta object in the FX graph that never gets its requires_grad field properly fixed up later)

ezyang (Contributor) commented Sep 30, 2022

I'm deferring to @Chillee for this.

facebook-github-bot commented

/easycla

As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details.

This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.

bdhirsh (Contributor Author) commented Oct 21, 2022

@Chillee (and also @wconstab, since you're familiar with the partitioner) friendly bump.

I still need to rebase this though, on top of the recent partitioner changes.

mutated_input_args = [x for pair in zip(original_inputs_needing_mutation, mutated_inputs) for x in pair]
# TODO: this epilogue should also be responsible for generating outputs
# that are aliases of inputs.
input_mutation_epilogue(*mutated_input_args)
Collaborator

So after the pass, all mutation would be in the opaque mutation epilogue, and backends lose visibility there.
We are missing out on fusion opportunities here.

Would we be able to opt-in to inline the mutation epilogue back into the main graph?

Contributor Author

So after the pass, all mutation would be in the opaque mutation epilogue

TBC, this is only for captured graphs with input mutations. Intermediate mutations in a graph would have already been removed by functionalization.

The idea in this PR as it stands is that instead of the backend seeing an operator like:

native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(a!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)

It'll see a purely functional version that returns the updated inputs instead of mutating them directly:

native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor, Tensor)

And if mutated inputs to that operator happen to correspond to graph inputs (which is true for running_mean/running_var variables for batch norm), then at the end of the graph we'll have an epilogue that copies the "updated inputs" back to the original inputs:

def compiled_fn(inpt1, inpt2):
     # ... first run the entirely functional compiled function
     outs, mutated_inpt1, mutated_inpt2 = real_compiled_fn(inpt1, inpt2)
     inpt1.copy_(mutated_inpt1)
     inpt2.copy_(mutated_inpt2)
     return outs

Would we be able to opt-in to inline the mutation epilogue back into the main graph?

I remember @Chillee bringing this up before - we probably can? Although for now, this PR doesn't do that and just ensures that we do the "correct" thing first.

It's also not clear to me - why would this prevent fusions? The main disadvantage as I see it is that you can end up using more memory - you have to keep the buffer for both the original input and the updated input around while the compiled function is running.

Collaborator

Thanks for the explanation, I think we are mostly on the same page here. I'm mostly asking about mutation on inputs, since that's what normalization layers use for running-stats updates.

It's also not clear to me - why would this prevent fusions?

It's the epilogue in-place copies that we are missing out on, since those are cheap and easy to handle in normalization kernels; i.e. if we keep them in the epilogue, they won't be visible to the fuser backend.

wconstab (Contributor) left a comment

Looks correct to me. Thanks, @bdhirsh

bdhirsh (Contributor Author) commented Nov 7, 2022

I'm actually going to close this PR and create a fresh one. Thanks for the stamp, and sorry about that, Will. This is mostly because:

(1) Ed's "trace with functionalization in one pass" PR has landed and changed how we want this PR to work pretty substantially - we no longer have to bake copy_() and as_strided_() into the graph and then analyze the graph. We can just directly detect the input mutation case from inside of AOTAutograd and add the proper handling.

(2) There are a bunch of other edge cases that are worth thinking about more holistically. Ed has a great doc on them here: https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit?usp=sharing

github-actions bot commented Jan 6, 2023

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jan 6, 2023
@github-actions github-actions bot closed this Feb 5, 2023
@facebook-github-bot facebook-github-bot deleted the gh/bdhirsh/289/head branch June 8, 2023 15:41