[ROCm] enable unit tests and other changes #10266
first round of changes to update PR
merge from upstream
This reverts commit 864dbe4.
next round of fixes to address comments
merge from upstream
After discussion in review, disable flake8 on pyHIPIFY for now.
Merge from pytorch upstream
This will hopefully save some grief in the future with overriding code.
…solete code for output directory not existing
Address more review comments
We already had a fallback.
Automatically handle transpilations inside device code only
Automatically pre-include CUDA headers just like NVCC.
Minor changes to pass flake8 tests
pip install -r requirements.txt || true

if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
  # This is necessary in order to cross compile (or else we'll have missing GPU device).
__launch_bounds__(nt, 4)
#ifdef __HIP_PLATFORM_HCC__
__global__ void elementwise_kernel(int N, const func_t& f) {
#else
HIP_INCLUDE_DIRECTORIES(${Caffe2_HIP_INCLUDES})
ENDIF()
if(BUILD_ATEN)
  # Get Compile Definitions from the directory (FindHIP.CMake bug)
with self.assertRaises(RuntimeError):
    b.add_(5)

@unittest.skipIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")
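For readers unfamiliar with the skip pattern in the excerpt above, here is a minimal sketch of how a ROCm-conditional skip can be wired up with the standard `unittest` module. The environment variable name, the `skip_if_rocm` helper, and the `ExampleTest` class are assumptions made for illustration; they are not taken from this PR's diff.

```python
import os
import unittest
from functools import wraps

# Assumption: ROCm test mode is signalled through an environment variable.
TEST_WITH_ROCM = os.getenv("PYTORCH_TEST_WITH_ROCM", "0") == "1"


def skip_if_rocm(fn):
    """Hypothetical decorator: skip the wrapped test when running on ROCm."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if TEST_WITH_ROCM:
            raise unittest.SkipTest("test doesn't currently work on the ROCm stack")
        return fn(*args, **kwargs)
    return wrapper


class ExampleTest(unittest.TestCase):
    # Inline form, as in the excerpt above.
    @unittest.skipIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")
    def test_inline_skip(self):
        self.assertEqual(1 + 1, 2)

    # Reusable decorator form, convenient for table-driven test lists.
    @skip_if_rocm
    def test_decorator_skip(self):
        self.assertEqual(2 * 2, 4)
```

Both forms report the test as skipped rather than failed, so tests marked as broken on ROCm stay visible in the test report until they are fixed.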
"""Generalization for finding a balancing closure group
e.g. if group = ["(", ")"], then finds the first balanced parentheses.
if group = ["{", "}"], then finds the first balanced bracket.
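The docstring above describes a search for the first balanced pair of delimiters. A minimal sketch of that idea follows, assuming single-character delimiters and ignoring string literals and comments. The function name `first_balanced_group` and its return convention are hypothetical and may differ from what the pyHIPIFY script actually does.

```python
def first_balanced_group(text, start, group):
    """Return (open_index, close_index) of the first balanced pair in `text`
    at or after `start`, or (None, None) if no balanced pair exists.

    `group` is a two-element sequence such as ["(", ")"] or ["{", "}"].
    """
    open_tok, close_tok = group
    depth = 0
    open_index = None
    for i in range(start, len(text)):
        ch = text[i]
        if ch == open_tok:
            if depth == 0:
                open_index = i  # remember where the outermost group starts
            depth += 1
        elif ch == close_tok and depth > 0:
            depth -= 1
            if depth == 0:
                return open_index, i  # outermost group is now closed
    return None, None


# Usage: locate the argument list of the first call in a line of code.
line = "f(a, (b)) + g(c)"
print(first_balanced_group(line, 0, ["(", ")"]))
```

Tracking nesting depth, rather than searching for the next closing delimiter, is what lets the helper skip over inner groups such as `(b)` in the usage line.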
"""If the file makes kernel builtin calls and does not include the cuda_runtime.h header,
then automatically add an #include to match the "magic" includes provided by NVCC.
TODO:
    Update logic to ignore cases where the cuda_runtime.h is included by another file.
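The behaviour described in the docstring above can be sketched as follows. This is an illustration under stated assumptions, not the script's actual code: the builtin list is a small illustrative subset, and the helper name `add_cuda_runtime_include` is hypothetical. Like the original, it does not handle the TODO case where `cuda_runtime.h` arrives through another header.

```python
import re

# Assumption: a small, illustrative subset of kernel builtins to look for.
KERNEL_BUILTINS = ("threadIdx", "blockIdx", "blockDim", "gridDim", "__syncthreads")

# Matches a direct include of cuda_runtime.h in either <> or "" form.
INCLUDE_RE = re.compile(r'^\s*#\s*include\s*[<"]cuda_runtime\.h[>"]', re.MULTILINE)


def add_cuda_runtime_include(source):
    """Prepend '#include <cuda_runtime.h>' when `source` uses kernel builtins
    but does not include the header directly."""
    uses_builtin = any(
        re.search(r"\b" + re.escape(name) + r"\b", source)
        for name in KERNEL_BUILTINS
    )
    if uses_builtin and not INCLUDE_RE.search(source):
        return "#include <cuda_runtime.h>\n" + source
    return source


# Usage: a kernel that relies on NVCC's implicit include gets an explicit one.
kernel = "__global__ void k(float* x) { x[threadIdx.x] = 0; }\n"
print(add_cuda_runtime_include(kernel))
```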
else
  LDFLAGS="$LDFLAGS -Wl,-rpath,\$ORIGIN"
  if [[ $USE_ROCM -eq 1 ]]; then
    LDFLAGS="$LDFLAGS -Wl,-rpath,\\\\\\\$ORIGIN"
ezyang
left a comment
Some nits but I don't see any show stoppers.
facebook-github-bot
left a comment
ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: This PR for the ROCm target does the following:
* enable some unit tests on ROCm
* fix a missing static_cast that breaks BatchNorm call on ROCm
* fix BatchNorm to work on ROCm w/ ROCm warp sizes etc
* improve the pyhipify script by introducing kernel scope to some transpilations and other improvements
* fix a linking issue on ROCm
* for more unit test sets: mark currently broken tests broken (to be fixed)
* enable THINLTO (phase one) to parallelize linking
* address the first failing of the elementwise kernel by removing non-working ROCm specialization

Pull Request resolved: pytorch/pytorch#10266
Differential Revision: D9184178
Pulled By: ezyang
fbshipit-source-id: 03bcd1fe4ca4dd3241f09634dbd42b6a4c350297
…RAND_PR While there, add the remaining changes requested in upstream PR pytorch#10266
('gesv', (2, 3, S, S), ((2, 3, S, S),), 'batched_dims', NO_ARGS, [skipIfNoLapack]),
('gesv', (2, 2, S, S), ((1, S, S),), 'batched_broadcast_A', NO_ARGS, [skipIfNoLapack]),
('gesv', (1, S, S), ((2, 2, S, S),), 'batched_broadcast_b', NO_ARGS, [skipIfNoLapack]),
('gesv', (S, S, S), ((S, S, S),), 'batched', NO_ARGS, [skipIfNoLapack, skipIfRocm]),