Commit 3568201

Update on "[reland] Make grad point to bucket buffer in DDP to save memory usage"
Reland of #41954. Adds an argument to the DDP API to enable or disable letting grads point to views. When it is disabled, behavior is the same as DDP today; when it is enabled, both variable.grad() and the grad in the dist autograd context point to the bucket buffer in DDP, saving memory. In that case grads are views of the bucket buffer tensors, so to keep this compatible with optimizer.zero_grad() we made changes in #41283. Also note that we cannot make variable.grad() point to the bucket buffer at construction time, because we want to keep grad undefined for unused parameters. Differential Revision: [D23588186](https://our.internmc.facebook.com/intern/diff/D23588186/) [ghstack-poisoned]
2 parents 6ff00f2 + f3cce29 commit 3568201
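
A minimal sketch of how the new toggle might be used from Python, assuming the DDP constructor argument is named gradient_as_bucket_view (the commit message does not name it, so treat the argument name as an assumption):

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def build_ddp_model(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    model = nn.Linear(1024, 1024).to(rank)
    # Assumed flag name: when True, each param.grad becomes a view into
    # DDP's flat bucket buffer instead of a separately allocated tensor,
    # so gradients are not stored twice.
    return DDP(model, device_ids=[rank], gradient_as_bucket_view=True)

Because grads become views, optimizer.zero_grad() can no longer detach and reallocate them and must zero them in place; that is the compatibility issue the #41283 change addresses.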

83 files changed, +3166 −978 lines changed

.circleci/cimodel/data/binary_build_data.py

Lines changed: 3 additions & 3 deletions

@@ -54,7 +54,7 @@ def get_processor_arch_name(gpu_version):
         )),
         # Skip CUDA-9.2 builds on Windows
         windows=(
-            [v for v in dimensions.GPU_VERSIONS if v not in ['cuda92', "rocm3.7"]],
+            [v for v in dimensions.GPU_VERSIONS if v not in ['cuda92'] + dimensions.ROCM_VERSION_LABELS],
             OrderedDict(
                 wheel=dimensions.STANDARD_PYTHON_VERSIONS,
                 conda=dimensions.STANDARD_PYTHON_VERSIONS,
@@ -142,11 +142,11 @@ def get_children(self):

         # XXX disabling conda rocm build since docker images are not there
         if self.find_prop("package_format") == 'conda':
-            gpu_versions = filter(lambda x: x != "rocm3.7", gpu_versions)
+            gpu_versions = filter(lambda x: x not in dimensions.ROCM_VERSION_LABELS, gpu_versions)

         # XXX libtorch rocm build is temporarily disabled
         if self.find_prop("package_format") == 'libtorch':
-            gpu_versions = filter(lambda x: x != "rocm3.7", gpu_versions)
+            gpu_versions = filter(lambda x: x not in dimensions.ROCM_VERSION_LABELS, gpu_versions)

         return [ArchConfigNode(self, v) for v in gpu_versions]

.circleci/cimodel/data/dimensions.py

Lines changed: 4 additions & 1 deletion

@@ -9,9 +9,12 @@

 ROCM_VERSIONS = [
     "3.7",
+    "3.8",
 ]

-GPU_VERSIONS = [None] + ["cuda" + v for v in CUDA_VERSIONS] + ["rocm" + v for v in ROCM_VERSIONS]
+ROCM_VERSION_LABELS = ["rocm" + v for v in ROCM_VERSIONS]
+
+GPU_VERSIONS = [None] + ["cuda" + v for v in CUDA_VERSIONS] + ROCM_VERSION_LABELS

 STANDARD_PYTHON_VERSIONS = [
     "3.6",

.circleci/cimodel/data/simple/docker_definitions.py

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@
     "pytorch-linux-xenial-py3.6-gcc7.2",
     "pytorch-linux-xenial-py3.6-gcc7",
     "pytorch-linux-bionic-rocm3.7-py3.6",
+    "pytorch-linux-bionic-rocm3.8-py3.6",
 ]

.circleci/config.yml

Lines changed: 159 additions & 0 deletions

@@ -2130,6 +2130,39 @@ workflows:
             only:
               - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/
           docker_image: "pytorch/manylinux-rocm:3.7"
+      - binary_linux_build:
+          name: binary_linux_manywheel_3_6m_rocm3_8_devtoolset7_nightly_build
+          build_environment: "manywheel 3.6m rocm3.8 devtoolset7"
+          filters:
+            branches:
+              only:
+                - /.*/
+            tags:
+              only:
+                - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/
+          docker_image: "pytorch/manylinux-rocm:3.8"
+      - binary_linux_build:
+          name: binary_linux_manywheel_3_7m_rocm3_8_devtoolset7_nightly_build
+          build_environment: "manywheel 3.7m rocm3.8 devtoolset7"
+          filters:
+            branches:
+              only:
+                - /.*/
+            tags:
+              only:
+                - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/
+          docker_image: "pytorch/manylinux-rocm:3.8"
+      - binary_linux_build:
+          name: binary_linux_manywheel_3_8m_rocm3_8_devtoolset7_nightly_build
+          build_environment: "manywheel 3.8m rocm3.8 devtoolset7"
+          filters:
+            branches:
+              only:
+                - /.*/
+            tags:
+              only:
+                - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/
+          docker_image: "pytorch/manylinux-rocm:3.8"
       - binary_linux_build:
           name: binary_linux_conda_3_6_cpu_devtoolset7_nightly_build
           build_environment: "conda 3.6 cpu devtoolset7"
@@ -3429,6 +3462,51 @@ workflows:
           docker_image: "pytorch/manylinux-rocm:3.7"
           use_cuda_docker_runtime: "1"
           resource_class: gpu.medium
+      - binary_linux_test:
+          name: binary_linux_manywheel_3_6m_rocm3_8_devtoolset7_nightly_test
+          build_environment: "manywheel 3.6m rocm3.8 devtoolset7"
+          filters:
+            branches:
+              only:
+                - /.*/
+            tags:
+              only:
+                - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/
+          requires:
+            - binary_linux_manywheel_3_6m_rocm3_8_devtoolset7_nightly_build
+          docker_image: "pytorch/manylinux-rocm:3.8"
+          use_cuda_docker_runtime: "1"
+          resource_class: gpu.medium
+      - binary_linux_test:
+          name: binary_linux_manywheel_3_7m_rocm3_8_devtoolset7_nightly_test
+          build_environment: "manywheel 3.7m rocm3.8 devtoolset7"
+          filters:
+            branches:
+              only:
+                - /.*/
+            tags:
+              only:
+                - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/
+          requires:
+            - binary_linux_manywheel_3_7m_rocm3_8_devtoolset7_nightly_build
+          docker_image: "pytorch/manylinux-rocm:3.8"
+          use_cuda_docker_runtime: "1"
+          resource_class: gpu.medium
+      - binary_linux_test:
+          name: binary_linux_manywheel_3_8m_rocm3_8_devtoolset7_nightly_test
+          build_environment: "manywheel 3.8m rocm3.8 devtoolset7"
+          filters:
+            branches:
+              only:
+                - /.*/
+            tags:
+              only:
+                - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/
+          requires:
+            - binary_linux_manywheel_3_8m_rocm3_8_devtoolset7_nightly_build
+          docker_image: "pytorch/manylinux-rocm:3.8"
+          use_cuda_docker_runtime: "1"
+          resource_class: gpu.medium
       - binary_linux_test:
           name: binary_linux_conda_3_6_cpu_devtoolset7_nightly_test
           build_environment: "conda 3.6 cpu devtoolset7"
@@ -4932,6 +5010,48 @@ workflows:
               - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/
           package_type: manywheel
          upload_subfolder: rocm3.7
+      - binary_upload:
+          name: binary_linux_manywheel_3_6m_rocm3_8_devtoolset7_nightly_upload
+          context: org-member
+          requires:
+            - binary_linux_manywheel_3_6m_rocm3_8_devtoolset7_nightly_test
+          filters:
+            branches:
+              only:
+                - nightly
+            tags:
+              only:
+                - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/
+          package_type: manywheel
+          upload_subfolder: rocm3.8
+      - binary_upload:
+          name: binary_linux_manywheel_3_7m_rocm3_8_devtoolset7_nightly_upload
+          context: org-member
+          requires:
+            - binary_linux_manywheel_3_7m_rocm3_8_devtoolset7_nightly_test
+          filters:
+            branches:
+              only:
+                - nightly
+            tags:
+              only:
+                - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/
+          package_type: manywheel
+          upload_subfolder: rocm3.8
+      - binary_upload:
+          name: binary_linux_manywheel_3_8m_rocm3_8_devtoolset7_nightly_upload
+          context: org-member
+          requires:
+            - binary_linux_manywheel_3_8m_rocm3_8_devtoolset7_nightly_test
+          filters:
+            branches:
+              only:
+                - nightly
+            tags:
+              only:
+                - /v[0-9]+(\.[0-9]+)*-rc[0-9]+/
+          package_type: manywheel
+          upload_subfolder: rocm3.8
       - binary_upload:
           name: binary_linux_conda_3_6_cpu_devtoolset7_nightly_upload
           context: org-member
@@ -6320,6 +6440,9 @@ workflows:
       - docker_build_job:
           name: "docker-pytorch-linux-bionic-rocm3.7-py3.6"
           image_name: "pytorch-linux-bionic-rocm3.7-py3.6"
+      - docker_build_job:
+          name: "docker-pytorch-linux-bionic-rocm3.8-py3.6"
+          image_name: "pytorch-linux-bionic-rocm3.8-py3.6"
       - pytorch_linux_build:
           name: pytorch_linux_xenial_py3_6_gcc5_4_build
           requires:
@@ -7455,6 +7578,42 @@ workflows:
           docker_image: "pytorch/manylinux-rocm:3.7"
           use_cuda_docker_runtime: "1"
           resource_class: gpu.medium
+      - smoke_linux_test:
+          name: smoke_linux_manywheel_3_6m_rocm3_8_devtoolset7_nightly
+          build_environment: "manywheel 3.6m rocm3.8 devtoolset7"
+          requires:
+            - update_s3_htmls
+          filters:
+            branches:
+              only:
+                - postnightly
+          docker_image: "pytorch/manylinux-rocm:3.8"
+          use_cuda_docker_runtime: "1"
+          resource_class: gpu.medium
+      - smoke_linux_test:
+          name: smoke_linux_manywheel_3_7m_rocm3_8_devtoolset7_nightly
+          build_environment: "manywheel 3.7m rocm3.8 devtoolset7"
+          requires:
+            - update_s3_htmls
+          filters:
+            branches:
+              only:
+                - postnightly
+          docker_image: "pytorch/manylinux-rocm:3.8"
+          use_cuda_docker_runtime: "1"
+          resource_class: gpu.medium
+      - smoke_linux_test:
+          name: smoke_linux_manywheel_3_8m_rocm3_8_devtoolset7_nightly
+          build_environment: "manywheel 3.8m rocm3.8 devtoolset7"
+          requires:
+            - update_s3_htmls
+          filters:
+            branches:
+              only:
+                - postnightly
+          docker_image: "pytorch/manylinux-rocm:3.8"
+          use_cuda_docker_runtime: "1"
+          resource_class: gpu.medium
       - smoke_linux_test:
           name: smoke_linux_conda_3_6_cpu_devtoolset7_nightly
           build_environment: "conda 3.6 cpu devtoolset7"

.circleci/docker/build.sh

Lines changed: 7 additions & 0 deletions

@@ -262,6 +262,13 @@ case "$image" in
     VISION=yes
     ROCM_VERSION=3.7
     ;;
+  pytorch-linux-bionic-rocm3.8-py3.6)
+    ANACONDA_PYTHON_VERSION=3.6
+    PROTOBUF=yes
+    DB=yes
+    VISION=yes
+    ROCM_VERSION=3.8
+    ;;
   *)
     # Catch-all for builds that are not hardcoded.
     PROTOBUF=yes

.circleci/docker/common/install_base.sh

Lines changed: 1 addition & 1 deletion

@@ -118,7 +118,7 @@ esac

 # Install Valgrind separately since the apt-get version is too old.
 mkdir valgrind_build && cd valgrind_build
-VALGRIND_VERSION=3.15.0
+VALGRIND_VERSION=3.16.1
 if ! wget http://valgrind.org/downloads/valgrind-${VALGRIND_VERSION}.tar.bz2
 then
   wget https://sourceware.org/ftp/valgrind/valgrind-${VALGRIND_VERSION}.tar.bz2

aten/src/ATen/Context.cpp

Lines changed: 23 additions & 0 deletions

@@ -230,4 +230,27 @@ Allocator* getCPUAllocator() {
   return getTHDefaultAllocator();
 }

+// override_allow_tf32_flag = true
+//    means the allow_tf32 flags are overridden and TF32 is force-disabled
+// override_allow_tf32_flag = false
+//    means the original allow_tf32 flags are followed
+thread_local bool override_allow_tf32_flag = false;
+
+NoTF32Guard::NoTF32Guard() {
+  if (!override_allow_tf32_flag) {
+    changed = true;
+    override_allow_tf32_flag = true;
+  }
+}
+
+NoTF32Guard::~NoTF32Guard() {
+  if (changed) {
+    override_allow_tf32_flag = false;
+  }
+}
+
+bool NoTF32Guard::should_disable_tf32() {
+  return override_allow_tf32_flag;
+}
+
 } // namespace at

aten/src/ATen/Context.h

Lines changed: 16 additions & 0 deletions

@@ -327,4 +327,20 @@ static inline void manual_seed(uint64_t seed) {
   }
 }

+// When the global flag `allow_tf32` is set to true, cuBLAS handles are
+// automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
+// For some operators, such as addmv, TF32 offers no performance improvement
+// but causes precision loss. To help this case, this class implements
+// a RAII guard that can be used to quickly disable TF32 within its scope.
+//
+// Usage:
+//     NoTF32Guard disable_tf32;
+struct TORCH_API NoTF32Guard {
+  NoTF32Guard();
+  ~NoTF32Guard();
+  static bool should_disable_tf32();
+private:
+  bool changed = false;
+};
+
 } // namespace at
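
For orientation, the global flag this guard temporarily overrides is the one exposed through PyTorch's Python-side TF32 controls. A brief sketch (the Python attribute name is an assumption based on the public TF32 API, not part of this diff):

import torch

# Assumed public control for TF32 in cuBLAS matmuls.
torch.backends.cuda.matmul.allow_tf32 = True

a = torch.randn(1024, 1024, device="cuda")
x = torch.randn(1024, device="cuda")

y = a @ a  # matmul: may run in TF32 for speed, at reduced precision
# Per the CUDABlas.cpp change below, gemv-style calls are wrapped in
# NoTF32Guard internally, so a matrix-vector product stays in full FP32:
z = a @ x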

aten/src/ATen/WrapDimUtils.h

Lines changed: 11 additions & 3 deletions

@@ -30,14 +30,15 @@ static inline int64_t maybe_wrap_dim(int64_t dim, const std::vector<std::vector<
   return maybe_wrap_dim(dim, tensor_sizes[0].size());
 }

-// wrap each of dims basing on dim_post_expr
-static inline void maybe_wrap_dims(std::vector<int64_t>& dims, int64_t dim_post_expr) {
+// wrap each dim in the dims array, taking dim_post_expr as the true number of dimensions
+static inline void maybe_wrap_dims_n(int64_t* dims, int64_t ndims, int64_t dim_post_expr) {
   if (dim_post_expr <= 0) {
     dim_post_expr = 1; // this will make range [-1, 0]
   }
   int64_t min = -dim_post_expr;
   int64_t max = dim_post_expr - 1;
-  for (auto& dim : dims) {
+  for (int64_t i = 0; i < ndims; ++i) {
+    auto& dim = dims[i];
     if (dim < min || dim > max) {
       TORCH_CHECK_INDEX(false,
         "Dimension out of range (expected to be in range of [",
@@ -47,6 +48,13 @@ static inline void maybe_wrap_dims(std::vector<int64_t>& dims, int64_t dim_post_
   }
 }

+// Wrap each dim in a contiguous container, taking dim_post_expr as the true number of dimensions
+// E.g. could also be std::array or c10::SmallVector
+template <typename Container>
+inline void maybe_wrap_dims(Container& dims, int64_t dim_post_expr) {
+  return maybe_wrap_dims_n(dims.data(), dims.size(), dim_post_expr);
+}
+
 // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible
 // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors
 // to be "skipped" (both for wrap dimension behavior and dimension size checking).

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 19 additions & 13 deletions

@@ -407,19 +407,22 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
 #endif

 #if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 210)
-template <>
-void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
-  // See Note [Writing Nondeterministic Operations]
-  globalContext().alertCuBLASConfigNotDeterministic();
-  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
-  cublasOperation_t op = _cublasOpFromChar(trans);
-  _cublasAdjustLdLevel2(m, n, &lda);
-  GEMV_CHECK_ARGVALUES(c10::complex<float>);
-  TORCH_CUDABLAS_CHECK(
-      cublasCgemv(handle, op, m, n, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
-      lda, reinterpret_cast<const cuComplex*>(x), incx, reinterpret_cast<const cuComplex*>(&beta),
-      reinterpret_cast<cuComplex*>(y), incy));
-}
+template <>
+void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
+  // gemv is bandwidth bound and does not benefit from TF32, but the precision
+  // loss still happens on TF32, so we disable it here.
+  NoTF32Guard disable_tf32;
+  // See Note [Writing Nondeterministic Operations]
+  globalContext().alertCuBLASConfigNotDeterministic();
+  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+  cublasOperation_t op = _cublasOpFromChar(trans);
+  _cublasAdjustLdLevel2(m, n, &lda);
+  GEMV_CHECK_ARGVALUES(c10::complex<float>);
+  TORCH_CUDABLAS_CHECK(
+      cublasCgemv(handle, op, m, n, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
+      lda, reinterpret_cast<const cuComplex*>(x), incx, reinterpret_cast<const cuComplex*>(&beta),
+      reinterpret_cast<cuComplex*>(y), incy));
+}
 #endif

 template <>
@@ -436,6 +439,9 @@ void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double)) {

 template <>
 void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float)) {
+  // gemv is bandwidth bound and does not benefit from TF32, but the precision
+  // loss still happens on TF32, so we disable it here.
+  NoTF32Guard disable_tf32;
   // See Note [Writing Nondeterministic Operations]
   globalContext().alertCuBLASConfigNotDeterministic();
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();