
[CUDA12] Make PyTorch compatible with CUDA 12 #91118

Closed
jianyuh wants to merge 1 commit into pytorch:master from jianyuh:pytorch_cuda12


Conversation

@jianyuh
Member

@jianyuh jianyuh commented Dec 19, 2022

Fixes the failures seen when building PyTorch from source with CUDA 12:

In file included from /home/jianyuhuang/Work/Github/pytorch/c10/cuda/CUDAFunctions.h:12,
                 from /home/jianyuhuang/Work/Github/pytorch/c10/cuda/CUDAStream.h:10,
                 from /home/jianyuhuang/Work/Github/pytorch/c10/cuda/CUDAGraphsC10Utils.h:3,
                 from /home/jianyuhuang/Work/Github/pytorch/aten/src/ATen/cuda/CUDAGraph.h:5,
                 from /home/jianyuhuang/Work/Github/pytorch/aten/src/ATen/cuda/CUDAGraph.cpp:2:
/home/jianyuhuang/Work/Github/pytorch/aten/src/ATen/cuda/CUDAGraph.cpp: In member function ‘void at::cuda::CUDAGraph::capture_end()’:
/home/jianyuhuang/Work/Github/pytorch/aten/src/ATen/cuda/CUDAGraph.cpp:168:75: warning: converting to non-pointer type ‘long long unsigned int’ from NULL [-Wconversion-null]
     AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
                                                                           ^
/home/jianyuhuang/Work/Github/pytorch/c10/cuda/CUDAException.h:31:42: note: in definition of macro ‘C10_CUDA_CHECK’
     C10_UNUSED const cudaError_t __err = EXPR;                           \
                                          ^~~~
/home/jianyuhuang/Work/Github/pytorch/aten/src/ATen/cuda/CUDAGraph.cpp:168:5: note: in expansion of macro ‘AT_CUDA_CHECK’
     AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
     ^~~~~~~~~~~~~
/home/jianyuhuang/Work/Github/pytorch/aten/src/ATen/cuda/CUDAGraph.cpp:168:75: error: too many arguments to function ‘cudaError_t cudaGraphInstantiate(CUgraphExec_st**, cudaGraph_t, long long unsigned int)’
     AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
                                                                           ^
/home/jianyuhuang/Work/Github/pytorch/c10/cuda/CUDAException.h:31:42: note: in definition of macro ‘C10_CUDA_CHECK’
     C10_UNUSED const cudaError_t __err = EXPR;                           \
                                          ^~~~
/home/jianyuhuang/Work/Github/pytorch/aten/src/ATen/cuda/CUDAGraph.cpp:168:5: note: in expansion of macro ‘AT_CUDA_CHECK’
     AT_CUDA_CHECK(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
     ^~~~~~~~~~~~~
In file included from /home/jianyuhuang/Work/Github/pytorch/c10/cuda/CUDAStream.h:6,
                 from /home/jianyuhuang/Work/Github/pytorch/c10/cuda/CUDAGraphsC10Utils.h:3,
                 from /home/jianyuhuang/Work/Github/pytorch/aten/src/ATen/cuda/CUDAGraph.h:5,
                 from /home/jianyuhuang/Work/Github/pytorch/aten/src/ATen/cuda/CUDAGraph.cpp:2:
/usr/local/cuda/include/cuda_runtime_api.h:11439:39: note: declared here
 extern __host__ cudaError_t CUDARTAPI cudaGraphInstantiate(cudaGraphExec_t *pGraphExec, cudaGraph_t graph, unsigned long long flags __dv(0));
                                       ^~~~~~~~~~~~~~~~~~~~
ninja: build stopped: subcommand failed.
/home/jianyuhuang/Work/Github/pytorch/torch/csrc/cuda/shared/cudart.cpp: In function ‘void torch::cuda::shared::initCudartBindings(PyObject*)’:
/home/jianyuhuang/Work/Github/pytorch/torch/csrc/cuda/shared/cudart.cpp:34:13: error: ‘cudaOutputMode_t’ was not declared in this scope
   py::enum_<cudaOutputMode_t>(
             ^~~~~~~~~~~~~~~~
/home/jianyuhuang/Work/Github/pytorch/torch/csrc/cuda/shared/cudart.cpp:34:13: note: suggested alternative: ‘cudaGraphNode_t’
   py::enum_<cudaOutputMode_t>(
             ^~~~~~~~~~~~~~~~
             cudaGraphNode_t
/home/jianyuhuang/Work/Github/pytorch/torch/csrc/cuda/shared/cudart.cpp:34:29: error: template argument 1 is invalid
   py::enum_<cudaOutputMode_t>(
                             ^
/home/jianyuhuang/Work/Github/pytorch/torch/csrc/cuda/shared/cudart.cpp:38:30: error: ‘cudaKeyValuePair’ was not declared in this scope
       .value("KeyValuePair", cudaKeyValuePair)
                              ^~~~~~~~~~~~~~~~
/home/jianyuhuang/Work/Github/pytorch/torch/csrc/cuda/shared/cudart.cpp:39:21: error: ‘cudaCSV’ was not declared in this scope
       .value("CSV", cudaCSV);
                     ^~~~~~~
/home/jianyuhuang/Work/Github/pytorch/torch/csrc/cuda/shared/cudart.cpp:39:21: note: suggested alternative: ‘cudart’
       .value("CSV", cudaCSV);
                     ^~~~~~~
                     cudart
/home/jianyuhuang/Work/Github/pytorch/torch/csrc/cuda/shared/cudart.cpp:99:7: error: ‘cudaProfilerInitialize’ was not declared in this scope
       cudaProfilerInitialize);
       ^~~~~~~~~~~~~~~~~~~~~~
/home/jianyuhuang/Work/Github/pytorch/torch/csrc/cuda/shared/cudart.cpp:99:7: note: suggested alternative: ‘cudaProfilerStart’
       cudaProfilerInitialize);
       ^~~~~~~~~~~~~~~~~~~~~~
       cudaProfilerStart
ninja: build stopped: subcommand failed.

With these fixes, PyTorch builds successfully against CUDA 12 following the standard OSS build instructions:

USE_CUDA=1 python setup.py develop 2>&1 | tee compile.log

cc @ngimel

@pytorch-bot

pytorch-bot bot commented Dec 19, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91118

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 35f7265:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jianyuh jianyuh requested a review from ngimel December 19, 2022 21:17
@ngimel
Collaborator

ngimel commented Dec 19, 2022

fyi @ptrblck

Collaborator

does cudaProfilerInitialize not work anymore?

Member Author

Thanks! Just added a fix.

Collaborator

and this would prevent profiler initialization?

Member Author

Thanks! Just added a fix.

@ngimel ngimel left a comment (Collaborator)

Stamp to unblock, but we should have a proper task listing the things that need to be done for CUDA 12.

@ngimel ngimel changed the title Make PyTorch compatible with CUDA 12 [CUDA 12] Make PyTorch compatible with CUDA 12 Dec 20, 2022
@ngimel ngimel changed the title [CUDA 12] Make PyTorch compatible with CUDA 12 [CUDA12] Make PyTorch compatible with CUDA 12 Dec 20, 2022
@facebook-github-bot
Contributor

@jianyuh has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


@jianyuh
Copy link
Member Author

jianyuh commented Dec 20, 2022

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 20, 2022
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.

@jianyuh
Member Author

jianyuh commented Dec 20, 2022

@pytorchbot merge

@facebook-github-bot
Contributor

@jianyuh has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorchmergebot
Collaborator

Merge started

@brad-mengchi brad-mengchi self-requested a review December 20, 2022 06:02
@pytorchmergebot
Collaborator

The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.

@jianyuh
Member Author

jianyuh commented Dec 20, 2022

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

@facebook-github-bot
Contributor

@jianyuh has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@jianyuh jianyuh added the "module: cuda" (Related to torch.cuda, and CUDA support in general) and "release notes: cuda" (release notes category) labels Feb 26, 2023