Commit a7d4fae

jeffdaily authored and malfet committed
[ROCm] set hipblas workspace (#138791)
Fixes #138532.

This brings hipblas behavior in line with cublas behavior with respect to setting the workspace to an allocation from the caching allocator, as well as the env var HIPBLAS_WORKSPACE_CONFIG.

Pull Request resolved: #138791
Approved by: https://github.com/naromero77amd, https://github.com/eqy, https://github.com/malfet

Co-authored-by: Nikita Shulga <[email protected]>
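
For orientation, a minimal usage sketch (not part of the commit), assuming a ROCm build of PyTorch; sizes in the config string are in KiB:

import os

# Configure the hipBLAS workspace before the first hipBLAS call; the format
# mirrors CUBLAS_WORKSPACE_CONFIG. ":4096:2:16:8" requests two 4096 KiB
# buffers plus eight 16 KiB buffers per handle/stream combination.
os.environ["HIPBLAS_WORKSPACE_CONFIG"] = ":4096:2:16:8"

import torch

a = torch.randn(512, 512, device="cuda")
b = a @ a  # the first GEMM on this handle/stream allocates the workspace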
1 parent cc59e91 commit a7d4fae

4 files changed: +86 -17 lines changed

aten/src/ATen/cuda/CublasHandlePool.cpp

Lines changed: 49 additions & 4 deletions
@@ -48,6 +48,39 @@ void destroyCublasLtHandle(cublasLtHandle_t handle) {
 }
 
 using CuBlasLtPoolType = DeviceThreadHandlePool<cublasLtHandle_t, createCublasLtHandle, destroyCublasLtHandle>;
+
+// ugly hack until hipblasSetWorkspace exists
+#include <rocblas/rocblas.h>
+
+static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error) {
+  switch(error) {
+    case rocblas_status_size_unchanged:
+    case rocblas_status_size_increased:
+    case rocblas_status_success:
+      return HIPBLAS_STATUS_SUCCESS;
+    case rocblas_status_invalid_handle:
+      return HIPBLAS_STATUS_NOT_INITIALIZED;
+    case rocblas_status_not_implemented:
+      return HIPBLAS_STATUS_NOT_SUPPORTED;
+    case rocblas_status_invalid_pointer:
+    case rocblas_status_invalid_size:
+    case rocblas_status_invalid_value:
+      return HIPBLAS_STATUS_INVALID_VALUE;
+    case rocblas_status_memory_error:
+      return HIPBLAS_STATUS_ALLOC_FAILED;
+    case rocblas_status_internal_error:
+      return HIPBLAS_STATUS_INTERNAL_ERROR;
+  }
+  TORCH_CHECK(false, "HIPBLAS_STATUS_INVALID_ENUM");
+}
+
+static hipblasStatus_t hipblasSetWorkspace_replacement(hipblasHandle_t handle, void* addr, size_t size) {
+  return rocBLASStatusToHIPStatus(rocblas_set_workspace((rocblas_handle)handle, addr, size));
+}
+
+// hipify mappings file correctly maps this but the function doesn't exist yet
+#define hipblasSetWorkspace hipblasSetWorkspace_replacement
+
 #endif
 
 std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
@@ -77,17 +110,29 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
 } // namespace
 
 void clearCublasWorkspaces() {
-#if !defined(USE_ROCM)
-  cublas_handle_stream_to_workspace().clear();
-#endif
+  cublas_handle_stream_to_workspace().clear();
 }
 
 size_t parseChosenWorkspaceSize() {
   const char * val = getenv("CUBLAS_WORKSPACE_CONFIG");
+#ifdef USE_ROCM
+  if (!val) {
+    val = getenv("HIPBLAS_WORKSPACE_CONFIG");
+  }
+  if (!val) {
+    // for extra convenience
+    val = getenv("ROCBLAS_WORKSPACE_CONFIG");
+  }
+  /* 32MiB default, 128MiB for MI300 */
+  cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
+  const bool gfx94 = properties != nullptr && properties->major == 9 && properties->minor == 4;
+  const size_t default_size = gfx94 ? 1024 * 128 * 1024 : 1024 * 32 * 1024;
+#else
   /* :4096:2:16:8 default, 32MiB for Hopper */
   cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
   const bool sm90 = properties != nullptr && properties->major == 9 && properties->minor == 0;
   const size_t default_size = sm90 ? 4096 * 8 * 1024 : 4096 * 1024 * 2 + 16 * 1024 * 8;
+#endif
 
   if (val) {
     size_t total_size = 0;
@@ -156,7 +201,6 @@ cublasHandle_t getCurrentCUDABlasHandle() {
   auto handle = myPoolWindow->reserve(device);
   auto stream = c10::cuda::getCurrentCUDAStream();
   TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream));
-#if !defined(USE_ROCM)
   // We explicitly set the cublas workspace even though CUDA 12.2+ fixed the
   // issue where memory usage increased during graph capture.
   // original issue: https://github.com/pytorch/pytorch/pull/83461
@@ -171,6 +215,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
     workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
   }
   TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
+#if !defined(USE_ROCM)
   // On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
   // FP32 data type calculations based on the value of the allow_tf32 flag.
   // To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.
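
As a cross-check of the arithmetic above, a small Python sketch (not part of the commit) of the :[SIZE]:[COUNT] parsing; parse_workspace_config is a hypothetical helper name:

def parse_workspace_config(val: str) -> int:
    # Sizes are given in KiB; the total is the sum of size * count over each
    # ":SIZE:COUNT" pair, mirroring parseChosenWorkspaceSize above.
    fields = [int(f) for f in val.split(":") if f]
    pairs = zip(fields[0::2], fields[1::2])  # (size_kib, count) pairs
    return sum(size * 1024 * count for size, count in pairs)

# CUDA default ":4096:2:16:8" -> 2 * 4096 KiB + 8 * 16 KiB
assert parse_workspace_config(":4096:2:16:8") == 4096 * 1024 * 2 + 16 * 1024 * 8
# ROCm defaults from the hunk above: 32 MiB, or 128 MiB on gfx94x (MI300)
assert parse_workspace_config(":1024:32") == 1024 * 32 * 1024
assert parse_workspace_config(":1024:128") == 1024 * 128 * 1024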

docs/source/notes/hip.rst

Lines changed: 18 additions & 1 deletion
@@ -103,7 +103,24 @@ complete snapshot of the memory allocator state via
 underlying allocation patterns produced by your code.
 
 To debug memory errors, set
-``PYTORCH_NO_CUDA_MEMORY_CACHING=1`` in your environment to disable caching.
+``PYTORCH_NO_HIP_MEMORY_CACHING=1`` in your environment to disable caching.
+``PYTORCH_NO_CUDA_MEMORY_CACHING=1`` is also accepted for ease of porting.
+
+.. _hipblas-workspaces:
+
+hipBLAS workspaces
+------------------
+
+For each combination of hipBLAS handle and HIP stream, a hipBLAS workspace will be allocated if that
+handle and stream combination executes a hipBLAS kernel that requires a workspace. In order to
+avoid repeatedly allocating workspaces, these workspaces are not deallocated unless
+``torch._C._cuda_clearCublasWorkspaces()`` is called; note that it is the same function for both CUDA
+and HIP. The workspace size per allocation can be specified via the environment variable
+``HIPBLAS_WORKSPACE_CONFIG`` with the format ``:[SIZE]:[COUNT]``. As an example, the environment
+variable ``HIPBLAS_WORKSPACE_CONFIG=:4096:2:16:8`` specifies a total size of ``2 * 4096 + 8 * 16
+KiB``, or 8 MiB. The default workspace size is 32 MiB; MI300 and newer default to 128 MiB. To force
+hipBLAS to avoid using workspaces, set ``HIPBLAS_WORKSPACE_CONFIG=:0:0``. For convenience,
+``CUBLAS_WORKSPACE_CONFIG`` is also accepted.
 
 .. _hipfft-plan-cache:
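
A brief sketch (not part of the commit) of the clearing behavior the new docs section describes, assuming a ROCm build and a GEMM large enough to require a workspace:

import torch

a = torch.rand(512, 512, device="cuda")
b = a @ a  # may allocate a hipBLAS workspace for this handle/stream pair
used = torch.cuda.memory_allocated()

# Same entry point for CUDA and HIP; returns workspaces to the caching allocator.
torch._C._cuda_clearCublasWorkspaces()
assert torch.cuda.memory_allocated() <= used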

test/test_cuda.py

Lines changed: 18 additions & 12 deletions
@@ -31,7 +31,6 @@
 from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast
 from torch.testing._internal.common_cuda import (
     _create_scaling_case,
-    _get_torch_cuda_version,
     TEST_CUDNN,
     TEST_MULTIGPU,
 )
@@ -63,6 +62,7 @@
     parametrize,
     run_tests,
     serialTest,
+    setBlasBackendsToDefaultFinally,
     skipCUDAMemoryLeakCheckIf,
     skipCUDANonDefaultStreamIf,
     skipIfRocm,
@@ -417,19 +417,23 @@ def test_serialization_array_with_storage(self):
         q_copy[1].fill_(10)
         self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))
 
-    @unittest.skipIf(
-        TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "temporarily disabled for async"
-    )
-    @unittest.skipIf(
-        _get_torch_cuda_version() >= (12, 2),
-        "skipped as explicit workspace allocation is removed",
-    )
+    @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled for async")
+    @setBlasBackendsToDefaultFinally
     def test_cublas_workspace_explicit_allocation(self):
+        torch.backends.cuda.preferred_blas_library("cublas")
         a = torch.randn(7, 7, device="cuda", requires_grad=False)
-        default_workspace_size = 4096 * 2 * 1024 + 16 * 8 * 1024  # :4096:2:16:8
-        # different size (32 MiB) expected on Hopper GPU
-        if torch.cuda.get_device_capability() == (9, 0):
-            default_workspace_size = 4096 * 8 * 1024
+        if torch.version.hip:
+            default_workspace_size = 1024 * 32 * 1024  # :1024:32 32MiB
+            # different size (128 MiB) expected on MI300 GPU
+            if torch.cuda.get_device_capability() >= (9, 4):
+                default_workspace_size = 1024 * 128 * 1024  # :1024:128
+        else:
+            default_workspace_size = (
+                4096 * 2 * 1024 + 16 * 8 * 1024
+            )  # :4096:2:16:8 8MiB
+            # different size (32 MiB) expected on Hopper GPU
+            if torch.cuda.get_device_capability() == (9, 0):
+                default_workspace_size = 4096 * 8 * 1024
 
         def check_workspace_size(inp):
             torch._C._cuda_clearCublasWorkspaces()
@@ -1919,7 +1923,9 @@ def test_graph_capture_oom(self):
         not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
     )
     @serialTest()
+    @setBlasBackendsToDefaultFinally
     def test_repeat_graph_capture_cublas_workspace_memory(self):
+        torch.backends.cuda.preferred_blas_library("cublas")
         (x, y, z) = 1024, 512, 64
         a = torch.rand((x, y), device="cuda")
         b = torch.rand((y, z), device="cuda")
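
Outside the test suite, the same measurement idea can be phrased as a rough probe (probe_workspace_bytes is a hypothetical helper, not part of the commit); it assumes no other allocations intervene:

import torch

def probe_workspace_bytes(n: int = 7) -> int:
    # Mirror check_workspace_size: clear cached workspaces, then measure the
    # allocation delta around a GEMM, minus the output tensor itself.
    torch._C._cuda_clearCublasWorkspaces()
    a = torch.randn(n, n, device="cuda")
    torch.cuda.synchronize()
    before = torch.cuda.memory_allocated()
    b = a @ a
    torch.cuda.synchronize()
    return torch.cuda.memory_allocated() - before - b.nelement() * b.element_size()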

torch/utils/hipify/cuda_to_hip_mappings.py

Lines changed: 1 addition & 0 deletions
@@ -6693,6 +6693,7 @@
         "cublasGetVersion_v2",
         ("hipblasGetVersion_v2", CONV_MATH_FUNC, API_BLAS, HIP_UNSUPPORTED),
     ),
+    ("cublasSetWorkspace", ("hipblasSetWorkspace", CONV_MATH_FUNC, API_BLAS)),
     ("cublasSetStream", ("hipblasSetStream", CONV_MATH_FUNC, API_BLAS)),
     ("cublasGetStream", ("hipblasGetStream", CONV_MATH_FUNC, API_BLAS)),
     ("cublasSetStream_v2", ("hipblasSetStream_v2", CONV_MATH_FUNC, API_BLAS)),
