
Commit aedbe33

Update on "Preserve reshapes in AOTAutograd"
This commit adds a configuration knob to AOTAutograd that makes it transform reshape calls into reshape_copy by default, so the reshape keeps working even if the striding of the tensors changes on subsequent runs. This is not sound if the input/output of the reshape is mutated afterwards, so for safety we use the new functionalization "freeze storage" feature to detect when that happens.

TODO:
- Plumb this as a configuration option so backends can pick what they want.
- Figure out what the fallback strategy should be if the user actually did mutate the input/output of reshape. One possibility is to trace again, but this time without preserving reshapes.
- Teach backends how to compile _reshape_copy.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
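For context, a minimal sketch in plain PyTorch (not the AOTAutograd change itself) of why reshape's aliasing depends on input striding, and what an always-copy variant buys. The reshape_copy helper below is an illustrative stand-in, not the real _reshape_copy operator.

# Minimal sketch: whether reshape aliases its input depends on the input's striding.
import torch

a = torch.arange(6).reshape(2, 3)      # contiguous input
b = a.reshape(3, 2)                    # view: shares storage with a
print(b.data_ptr() == a.data_ptr())    # True

c = a.t()                              # transposed, non-contiguous input
d = c.reshape(6)                       # must copy into new storage
print(d.data_ptr() == c.data_ptr())    # False

def reshape_copy(t, shape):
    # Hypothetical stand-in for _reshape_copy: always return fresh storage,
    # so a traced graph no longer depends on how later inputs are strided.
    return t.reshape(shape).clone()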
2 parents 9693710 + 4a47e01 commit aedbe33

107 files changed: +2966 −644 lines


.github/merge_rules.yaml

Lines changed: 4 additions & 4 deletions
@@ -325,10 +325,10 @@
 - torch/csrc/lazy/**
 - test/cpp/lazy/**
 - test/lazy/**
-- codegen/api/lazy.py
-- codegen/dest/lazy_ir.py
-- codegen/dest/lazy_ts_lowering.py
-- codegen/gen_lazy_tensor.py
+- torchgen/api/lazy.py
+- torchgen/dest/lazy_ir.py
+- torchgen/dest/lazy_ts_lowering.py
+- torchgen/gen_lazy_tensor.py
 - aten/src/ATen/native/ts_native_functions.yaml
 approved_by:
 - alanwaketan

.github/scripts/install_nvidia_utils_linux.sh

Lines changed: 21 additions & 1 deletion
@@ -59,8 +59,28 @@ install_nvidia_driver_amzn2() {
     sudo yum install -y "kernel-devel-uname-r == $(uname -r)"
     sudo modprobe backlight
     sudo curl -fsL -o /tmp/nvidia_driver "https://s3.amazonaws.com/ossci-linux/nvidia_driver/$DRIVER_FN"
-    sudo /bin/bash /tmp/nvidia_driver -s --no-drm || (sudo cat /var/log/nvidia-installer.log && false)
+
+    set +e
+    sudo /bin/bash /tmp/nvidia_driver -s --no-drm
+    NVIDIA_INSTALLATION_STATUS=$?
+
+    if [ "$NVIDIA_INSTALLATION_STATUS" -ne 0 ]; then
+      sudo cat /var/log/nvidia-installer.log
+
+      NVIDIA_DEVICES=$(lspci -D | grep -i NVIDIA | cut -d' ' -f1)
+      # The GPU can get stuck in a failure state if somehow the test crashs the GPU microcode. When this
+      # happens, we'll try to reset all NVIDIA devices https://github.com/pytorch/pytorch/issues/88388
+      for PCI_ID in "$NVIDIA_DEVICES"; do
+        DEVICE_ENABLED=$(cat /sys/bus/pci/devices/$PCI_ID/enable)
+
+        echo "Reseting $PCI_ID (enabled state: $DEVICE_ENABLED)"
+        echo "1" > /sys/bus/pci/devices/$PCI_ID/reset
+        sleep 1
+      done
+    fi
+
     sudo rm -fv /tmp/nvidia_driver
+    set -e
     fi
 
     sudo modprobe nvidia || true

.jenkins/pytorch/test.sh

Lines changed: 1 addition & 0 deletions
@@ -733,6 +733,7 @@ elif [[ "$TEST_CONFIG" == deploy ]]; then
 elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
   install_filelock
   install_triton
+  install_huggingface
   test_inductor_distributed
 elif [[ "${TEST_CONFIG}" == *dynamo* && "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then
   test_without_numpy

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -285,6 +285,7 @@ if(NOT USE_XNNPACK AND CMAKE_VERSION VERSION_LESS ${XNNPACK_MIN_CMAKE_VER})
 endif()
 option(USE_ZMQ "Use ZMQ" OFF)
 option(USE_ZSTD "Use ZSTD" OFF)
+option(TORCH_DISABLE_GPU_ASSERTS "Disable GPU asserts by default" OFF)
 # Ensure that an ITT build is the default for x86 CPUs
 cmake_dependent_option(
   USE_ITT "Use Intel(R) VTune Profiler ITT functionality" ON

Makefile

Lines changed: 4 additions & 0 deletions
@@ -31,3 +31,7 @@ lint:
 
 quicklint:
 	lintrunner
+
+triton:
+	$(PIP) uninstall -y triton
+	$(PIP) install -U "git+https://github.com/openai/triton@$(shell cat .github/ci_commit_pins/triton.txt)#subdirectory=python"

aten/src/ATen/native/mkldnn/TensorShape.cpp

Lines changed: 5 additions & 1 deletion
@@ -1,7 +1,8 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
-#include <ATen/core/Tensor.h>
 #include <ATen/Config.h>
 #include <ATen/InferSize.h>
+#include <ATen/WrapDimUtils.h>
+#include <ATen/core/Tensor.h>
 #include <c10/core/SymIntArrayRef.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
@@ -78,6 +79,9 @@ Tensor mkldnn_clone(const Tensor& self, c10::optional<c10::MemoryFormat> optiona
 }
 
 Tensor mkldnn_transpose(const Tensor& self, int64_t dim0, int64_t dim1) {
+  auto ndims = self.dim();
+  dim0 = maybe_wrap_dim(dim0, ndims);
+  dim1 = maybe_wrap_dim(dim1, ndims);
   const ideep::tensor& x = itensor_from_mkldnn(self);
   ideep::tensor y;
   std::vector<int> axes(x.ndims());
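The added maybe_wrap_dim calls let mkldnn_transpose accept negative dimension indices, matching the dense-tensor behavior. A small hedged sketch of the user-facing effect (requires a build with MKL-DNN support; otherwise to_mkldnn() raises):

import torch

x = torch.randn(2, 3, 4)
assert torch.equal(x.transpose(-2, -1), x.transpose(1, 2))  # dense tensors already wrap negative dims

if torch.backends.mkldnn.is_available():
    y = x.to_mkldnn()
    # With the wrapping above, negative dims now work on MKL-DNN tensors too.
    z = y.transpose(-2, -1).to_dense()
    assert torch.equal(z, x.transpose(1, 2))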

aten/src/ATen/native/mps/operations/UnaryOps.mm

Lines changed: 45 additions & 1 deletion
@@ -7,6 +7,13 @@
 #include <ATen/native/mps/OperationUtils.h>
 #include <torch/library.h>
 
+// TODO: Remove me when moved to MacOS 13
+@interface MPSGraph (VenturaOps)
+- (MPSGraphTensor *)cumulativeSumWithTensor:(MPSGraphTensor *)tensor
+                                       axis:(NSInteger)axis
+                                       name:(NSString *)name;
+@end
+
 namespace at {
 namespace native {
 namespace mps {
@@ -30,7 +37,7 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
 }
 MPSGraphCache* cache_ = MPSGraphCache::getInstance();
 @autoreleasepool {
-  string key = op_name + getTensorsStringKey({self}, /*use_scalar_value*/ false);
+  string key = op_name + getTensorsStringKey({self, output}, /*use_scalar_value*/ false);
   auto cachedGraph = cache_->LookUpAs<MPSUnaryCachedGraph>(key);
 
   if(!cachedGraph) {
@@ -263,5 +270,42 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
   });
 }
 
+
+static bool mpsSupportsCumsum() {
+  id mpsCD = NSClassFromString(@"MPSGraph");
+  return [mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)] == YES;
+}
+
+
+TORCH_IMPL_FUNC(cumsum_out_mps)
+(const Tensor& self,
+ int64_t dim,
+ c10::optional<ScalarType> dtype,
+ const Tensor& result) {
+  TORCH_CHECK(dim >=0 && dim < std::max(1LL, self.ndimension()), "Expected dim to be between 0 and ", self.ndimension(), " but got ", dim);
+  if (!mpsSupportsCumsum()) {
+    TORCH_WARN_ONCE("torch.cumsum supported by MPS on MacOS 13+, please upgrade");
+    auto cpu_result = self.to(at::Device(kCPU)).cumsum(dim, dtype);
+    at::_copy_from_and_resize(cpu_result, result);
+    return;
+  }
+  auto input = dtype.has_value() ? self.to(dtype.value()) : self;
+  mps::unary_op(input, result, "cumsum_out_mp" + std::to_string(dim),
+                ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
+    // cumsum is horribly broken for int8, int16 and as chances for overflow is pretty high, cast to int32
+    if (isIntegralType(input.scalar_type()) && input.scalar_type() !=ScalarType::Int) {
+      inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, result.scalar_type());
+    }
+    auto rc = [mpsGraph cumulativeSumWithTensor: inputTensor
+                                           axis: dim
+                                           name: nil];
+    if (result.scalar_type()!= input.scalar_type() ||
+        (isIntegralType(input.scalar_type()) && input.scalar_type() !=ScalarType::Int)) {
+      return mps::castMPSTensor(mpsGraph, rc, result.scalar_type());
+    }
+    return rc;
+  });
+}
+
 } // namespace native
 } // namespace at
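A hedged usage sketch of the new kernel's behavior: on MacOS 13+ cumsum runs natively on MPS, while on older systems the branch above warns once and computes through the CPU; integer inputs are cast up internally as in the block above.

import torch

if torch.backends.mps.is_available():            # requires a Metal-capable Mac
    x = torch.randint(0, 10, (4, 5), dtype=torch.int16, device="mps")
    y = torch.cumsum(x, dim=1)                    # int16 input is cast up internally to avoid overflow
    assert torch.equal(y.cpu(), torch.cumsum(x.cpu(), dim=1))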

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 0 deletions
@@ -1776,6 +1776,7 @@
   device_check: NoCheck # TensorIterator
   dispatch:
     CPU, CUDA: cumsum_out
+    MPS: cumsum_out_mps
 
 - func: cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
   device_check: NoCheck # TensorIterator
@@ -4247,6 +4248,7 @@
   dispatch:
     SparseCPU, SparseCUDA: neg_sparse
     SparseCsrCPU, SparseCsrCUDA: neg_sparse_csr
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_neg
   tags: canonical
 
 - func: neg_(Tensor(a!) self) -> Tensor(a!)
@@ -4256,6 +4258,7 @@
   dispatch:
     SparseCPU, SparseCUDA: neg_sparse_
     SparseCsrCPU, SparseCsrCUDA: neg_sparse_csr_
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_neg_
 
 - func: neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck # TensorIterator

aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp

Lines changed: 12 additions & 0 deletions
@@ -58,5 +58,17 @@ Tensor NestedTensor_tanh(const Tensor& self) {
   return map_nt(self, at::tanh);
 }
 
+Tensor& NestedTensor_neg_(Tensor& self) {
+  auto self_ptr = get_nested_tensor_impl(self);
+  check_numel_equals_buffer_size(self_ptr);
+  auto buffer = self_ptr->get_buffer();
+  at::neg_(buffer);
+  return self;
+}
+
+Tensor NestedTensor_neg(const Tensor& self) {
+  return map_nt(self, at::neg);
+}
+
 } // namespace native
 } // namespace at
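The native_functions.yaml entries above route neg and neg_ on nested tensors to these kernels. A short hedged usage sketch (nested tensors were still a prototype API at this point):

import torch

nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])
out = torch.neg(nt)   # dispatches to NestedTensor_neg
nt.neg_()             # in-place variant, NestedTensor_neg_
for a, b in zip(nt.unbind(), out.unbind()):
    assert torch.equal(a, b)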

aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp

Lines changed: 1 addition & 0 deletions
@@ -560,6 +560,7 @@ int register_embedding_params() {
         return PackedEmbeddingBagWeight::prepack(weight);
       })
       .def("bit_rate", &EmbeddingPackedParamsBase::bit_rate)
+      .def("unpack", &EmbeddingPackedParamsBase::unpack)
       .def("version", &EmbeddingPackedParamsBase::version);
 
   return 0;
