Skip to content

Conversation

@iotamudelta
Copy link
Contributor

This PR for the ROCm target does the following:

  • enable some unit tests on ROCm
  • fix a missing static_cast that breaks BatchNorm call on ROCm
  • fix BatchNorm to work on ROCm w/ ROCm warp sizes etc
  • improve the pyhipify script by introducing kernel scope to some transpilations and other improvements
  • fix a linking issue on ROCm
  • for more unit test sets: mark currently broken tests broken (to be fixed)
  • enable THINLTO (phase one) to parallelize linking
  • address the first failing of the elementwise kernel by removing non-working ROCm specialization

iotamudelta and others added 30 commits July 2, 2018 13:46
first round of changes to update PR
This reverts commit 864dbe4.
next round of fixes to address comments
After discussion in review, disable flake8 on pyHIPIFY for now.
Merge from pytorch upstream
This will hopefully safe some grief in the future with overriding code.
…solete code for output directory not existing
Address more review comments
iotamudelta and others added 3 commits August 6, 2018 12:55
Automatically handle transpilations inside device code only
Automatically pre-include CUDA headers just like NVCC.
pip install -r requirements.txt || true

if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
# This is necessary in order to cross compile (or else we'll have missing GPU device).

This comment was marked as off-topic.

This comment was marked as off-topic.

__launch_bounds__(nt, 4)
#ifdef __HIP_PLATFORM_HCC__
__global__ void elementwise_kernel(int N, const func_t& f) {
#else

This comment was marked as off-topic.

HIP_INCLUDE_DIRECTORIES(${Caffe2_HIP_INCLUDES})
ENDIF()
if(BUILD_ATEN)
# Get Compile Definitions from the directory (FindHIP.CMake bug)

This comment was marked as off-topic.

with self.assertRaises(RuntimeError):
b.add_(5)

@unittest.skipIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

"""Generalization for finding a balancing closure group
e.g. if group = ["(", ")"], then finds the first balanced parantheses.
if group = ["{", "}"], then finds the first balanced bracket.

This comment was marked as off-topic.

This comment was marked as off-topic.

"""If the file makes kernel builtin calls and does not include the cuda_runtime.h header,
then automatically add an #include to match the "magic" includes provided by NVCC.
TODO:
Update logic to ignore cases where the cuda_runtime.h is included by another file.

This comment was marked as off-topic.

This comment was marked as off-topic.

else
LDFLAGS="$LDFLAGS -Wl,-rpath,\$ORIGIN"
if [[ $USE_ROCM -eq 1 ]]; then
LDFLAGS="$LDFLAGS -Wl,-rpath,\\\\\\\$ORIGIN"

This comment was marked as off-topic.

This comment was marked as off-topic.

Copy link
Contributor

@ezyang ezyang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some nits but I don't see any show stoppers.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Aug 6, 2018
Summary:
This PR for the ROCm target does the following:
* enable some unit tests on ROCm
* fix a missing static_cast that breaks BatchNorm call on ROCm
* fix BatchNorm to work on ROCm w/ ROCm warp sizes etc
* improve the pyhipify script by introducing kernel scope to some transpilations and other improvements
* fix a linking issue on ROCm
* for more unit test sets: mark currently broken tests broken (to be fixed)
* enable THINLTO (phase one) to parallelize linking
* address the first failing of the elementwise kernel by removing non-working ROCm specialization
Pull Request resolved: pytorch/pytorch#10266

Differential Revision: D9184178

Pulled By: ezyang

fbshipit-source-id: 03bcd1fe4ca4dd3241f09634dbd42b6a4c350297
iotamudelta added a commit to iotamudelta/pytorch that referenced this pull request Aug 7, 2018
…RAND_PR

While there, add the remaining changes requested in upstream PR pytorch#10266
('gesv', (2, 3, S, S), ((2, 3, S, S),), 'batched_dims', NO_ARGS, [skipIfNoLapack]),
('gesv', (2, 2, S, S), ((1, S, S),), 'batched_broadcast_A', NO_ARGS, [skipIfNoLapack]),
('gesv', (1, S, S), ((2, 2, S, S),), 'batched_broadcast_b', NO_ARGS, [skipIfNoLapack]),
('gesv', (S, S, S), ((S, S, S),), 'batched', NO_ARGS, [skipIfNoLapack, skipIfRocm]),

This comment was marked as off-topic.

goodlux pushed a commit to goodlux/pytorch that referenced this pull request Aug 15, 2018
Summary:
This PR for the ROCm target does the following:
* enable some unit tests on ROCm
* fix a missing static_cast that breaks BatchNorm call on ROCm
* fix BatchNorm to work on ROCm w/ ROCm warp sizes etc
* improve the pyhipify script by introducing kernel scope to some transpilations and other improvements
* fix a linking issue on ROCm
* for more unit test sets: mark currently broken tests broken (to be fixed)
* enable THINLTO (phase one) to parallelize linking
* address the first failing of the elementwise kernel by removing non-working ROCm specialization
Pull Request resolved: pytorch#10266

Differential Revision: D9184178

Pulled By: ezyang

fbshipit-source-id: 03bcd1fe4ca4dd3241f09634dbd42b6a4c350297
@jithunnair-amd jithunnair-amd deleted the enableunittests branch September 25, 2025 16:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants