Skip to content

Commit 06e2877

Browse files
committed
Update on "Improve assert failure message in test_get_torch_func_signature_exhaustive"
cc mruberry [ghstack-poisoned]
2 parents f0cc7a0 + c03cae3 commit 06e2877

File tree

113 files changed

+5661
-1698
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

113 files changed

+5661
-1698
lines changed

.circleci/config.yml

Lines changed: 4 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.circleci/verbatim-sources/job-specs/job-specs-custom.yml

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,8 @@
4343
set -ex
4444
export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}:build-${DOCKER_TAG}-${CIRCLE_SHA1}
4545
echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE}
46-
tag=${CIRCLE_TAG:1:5}
46+
# turn v1.12.0rc3 into 1.12.0
47+
tag=$(echo $CIRCLE_TAG | sed -e 's/v*\([0-9.]*\).*/\1/')
4748
target=${tag:-master}
4849
echo "building for ${target}"
4950
time docker pull ${COMMIT_DOCKER_IMAGE} >/dev/null
@@ -88,6 +89,8 @@
8889
set -ex
8990
export COMMIT_DOCKER_IMAGE=${DOCKER_IMAGE}:build-${DOCKER_TAG}-${CIRCLE_SHA1}
9091
echo "DOCKER_IMAGE: "${COMMIT_DOCKER_IMAGE}
92+
# turn v1.12.0rc3 into 1.12.0
93+
tag=$(echo $CIRCLE_TAG | sed -e 's/v*\([0-9.]*\).*/\1/')
9194
tag=${CIRCLE_TAG:1:5}
9295
target=${tag:-master}
9396
echo "building for ${target}"

aten/src/ATen/TensorIterator.cpp

Lines changed: 40 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,19 @@ TensorIteratorConfig& TensorIteratorConfig::add_borrowed_input(const TensorBase&
128128

129129
TensorIteratorConfig& TensorIteratorConfig::declare_static_dtype_and_device(ScalarType dtype, Device device) {
130130
TORCH_CHECK(!check_all_same_dtype_, "check_all_same_dtype(false) must be called before declare_static_dtype(...)");
131-
static_dtype_and_device_ = c10::make_optional(std::make_pair(dtype, device));
131+
static_dtype_ = dtype;
132+
static_device_ = device;
133+
return *this;
134+
}
135+
136+
TensorIteratorConfig& TensorIteratorConfig::declare_static_dtype(ScalarType dtype) {
137+
TORCH_CHECK(!check_all_same_dtype_, "check_all_same_dtype(false) must be called before declare_static_dtype(...)");
138+
static_dtype_ = dtype;
139+
return *this;
140+
}
141+
142+
TensorIteratorConfig& TensorIteratorConfig::declare_static_device(Device device) {
143+
static_device_ = device;
132144
return *this;
133145
}
134146

@@ -327,12 +339,20 @@ void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
327339
// the device it should be allocated on.
328340
if (!op.is_type_defined()) {
329341
TORCH_INTERNAL_ASSERT(op.is_output, "Found type undefined input tensor!");
330-
if (config.static_dtype_and_device_.has_value()) {
331-
op.target_dtype = config.static_dtype_and_device_->first;
332-
op.device = config.static_dtype_and_device_->second;
342+
343+
if (config.static_dtype_.has_value()) {
344+
op.target_dtype = config.static_dtype_.value();
333345
} else {
334-
TORCH_INTERNAL_ASSERT(config.check_all_same_device_);
335346
has_undefined_outputs = true;
347+
}
348+
349+
if (config.static_device_.has_value()) {
350+
op.device = config.static_device_.value();
351+
} else {
352+
TORCH_INTERNAL_ASSERT(config.check_all_same_device_);
353+
}
354+
355+
if (has_undefined_outputs || !op.device.has_value()) {
336356
continue;
337357
}
338358
}
@@ -418,12 +438,21 @@ void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
418438
// - checks that all tensors are on the same device, if requested
419439
// - checks that the common dtype can safely cast to each output, if requested
420440
// - creates temporaries for CPU operations, if needed and requested
441+
common_device_ = common_device;
421442
int max_cpu_scalars_on_non_cpu = config.allow_cpu_scalars_ ? 1 : 0;
422443
int current_cpu_scalars_on_non_cpu = 0;
423444
for (auto& op : operands_) {
424-
if (!op.is_type_defined()) {
445+
bool is_type_defined = op.is_type_defined();
446+
bool is_device_defined = op.is_device_defined();
447+
448+
if (!is_type_defined) {
425449
op.target_dtype = common_dtype_;
450+
}
451+
if (!is_device_defined) {
426452
op.device = common_device;
453+
}
454+
455+
if (!is_type_defined && !is_device_defined) {
427456
continue;
428457
}
429458

@@ -441,10 +470,10 @@ void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
441470
TORCH_CHECK(current_cpu_scalars_on_non_cpu < max_cpu_scalars_on_non_cpu,
442471
"Trying to pass too many CPU scalars to non-CPU kernel!");
443472
++current_cpu_scalars_on_non_cpu;
444-
} else if (op.device != common_device) {
473+
} else if (op.device.value() != common_device) {
445474
TORCH_CHECK(false,
446475
"Expected all tensors to be on the same device, but "
447-
"found at least two devices, ", common_device, " and ", op.device, "!");
476+
"found at least two devices, ", common_device, " and ", op.device.value(), "!");
448477
}
449478
}
450479

@@ -490,7 +519,6 @@ void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
490519
op.target_dtype = common_dtype_;
491520
}
492521
}
493-
common_device_ = common_device;
494522
}
495523
}
496524

@@ -864,7 +892,7 @@ void TensorIteratorBase::build_comparison_op(
864892
// want the output to be bool. Otherwise (e.g. 'torch.eq(a, b, out=c)') we
865893
// don't coerce the output.
866894
if (!out.defined()) {
867-
config.declare_static_dtype_and_device(kBool, a.device());
895+
config.declare_static_dtype(kBool);
868896
}
869897

870898
// Note [special-case bool outputs]
@@ -943,7 +971,8 @@ void TensorIteratorBase::build_unary_force_boolean_op(const TensorBase& out, con
943971
build(TensorIteratorConfig()
944972
.set_check_mem_overlap(true)
945973
.check_all_same_dtype(false)
946-
.declare_static_dtype_and_device(at::kBool, a.device())
974+
.declare_static_dtype(at::kBool)
975+
.declare_static_device(a.device())
947976
.add_owned_output(out)
948977
.add_owned_input(a));
949978
}

aten/src/ATen/TensorIterator.h

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -122,13 +122,14 @@ struct TORCH_API OperandInfo {
122122
/// but during type promotion target_dtype value can become different from tensor's dtype
123123
/// also, during type promotion target_dtype and device can be set for an undefined tensor so that tensor can be properly
124124
/// constructed later.
125-
Device device = kCPU;
125+
c10::optional<Device> device = c10::nullopt;
126126
ScalarType target_dtype = ScalarType::Undefined;
127127
// Caches dtype of the tensor, because scalar_type is an expensive operation
128128
// If dtype of the tensor is changed (e.g. as a result of type promotion or in allocate_outputs), this
129129
//value should be changed too.
130130
ScalarType current_dtype = ScalarType::Undefined;
131131

132+
bool is_device_defined() const { return device.has_value(); }
132133
bool is_type_defined() const { return target_dtype != ScalarType::Undefined; }
133134
TensorOptions options() const {
134135
return TensorOptions(target_dtype).device(device);
@@ -256,7 +257,7 @@ struct TORCH_API TensorIteratorBase : public impl::MetaBase {
256257
return common_dtype_;
257258
}
258259
ScalarType input_dtype(int arg=0) const { return operands_[num_outputs_ + arg].current_dtype; }
259-
Device device(int arg=0) const { return operands_[arg].device; }
260+
Device device(int arg=0) const { return operands_[arg].device.value(); }
260261
DeviceType device_type(int arg=0) const { return device(arg).type(); }
261262
int64_t element_size(int arg) const { return elementSize(dtype(arg)); }
262263
bool is_scalar(int arg) const;
@@ -725,6 +726,8 @@ class TORCH_API TensorIteratorConfig final {
725726

726727
// Bypass output dtype/device computation and fix the dtype/device as specified here.
727728
TensorIteratorConfig& declare_static_dtype_and_device(ScalarType dtype, Device device);
729+
TensorIteratorConfig& declare_static_dtype(ScalarType dtype);
730+
TensorIteratorConfig& declare_static_device(Device device);
728731
TensorIteratorConfig& declare_static_shape(IntArrayRef shape);
729732
TensorIteratorConfig& declare_static_shape(IntArrayRef shape, IntArrayRef squash_dims);
730733

@@ -742,7 +745,8 @@ class TORCH_API TensorIteratorConfig final {
742745
int num_inputs_ = 0;
743746

744747
c10::optional<DimVector> static_shape_ = c10::nullopt;
745-
c10::optional<std::pair<ScalarType, Device>> static_dtype_and_device_ = c10::nullopt;
748+
c10::optional<ScalarType> static_dtype_ = c10::nullopt;
749+
c10::optional<Device> static_device_ = c10::nullopt;
746750
bool check_mem_overlap_ = true;
747751
bool allow_cpu_scalars_ = false;
748752
bool is_reduction_ = false;

aten/src/ATen/native/BinaryOps.cpp

Lines changed: 0 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -204,13 +204,6 @@ void comparison_op_check(const Tensor& self, const Tensor& other, const Tensor&
204204
native::check_convert(self.item(), other.scalar_type());
205205
}
206206
}
207-
// In-place operation To avoid overflow during type promotion we will check that
208-
// both dtypes of self and other are same
209-
if (result.is_same(self)) {
210-
TORCH_CHECK(self.dtype() == other.dtype(),
211-
"Expected object of scalar type ", self.dtype(), " but got scalar type ",
212-
other.dtype(), " for argument 'other'");
213-
}
214207
}
215208

216209
#define CREATE_COMPARISON_SCALAR_TENSOR_META_FUNC(func) \
@@ -915,9 +908,6 @@ Tensor comparison_op(const Tensor& self, const Tensor& other, OutImpl& out_impl)
915908
// In-place variant: writes the result into self; the self/other same-dtype check (formerly guarding against overflow during type promotion) is removed by this change
916909
template <typename OutImpl>
917910
Tensor& comparison_op_(Tensor& self, const Tensor& other, OutImpl& out_impl) {
918-
TORCH_CHECK(self.dtype() == other.dtype(),
919-
"Expected object of scalar type ", self.dtype(), " but got scalar type ",
920-
other.dtype(), " for argument 'other'");
921911
return out_impl(self, self, other);
922912
}
923913

aten/src/ATen/native/Bucketization.cpp

Lines changed: 18 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -76,14 +76,26 @@ void searchsorted_cpu_contiguous(Tensor& result, const Tensor& input, const Tens
7676

7777
void dispatch(Tensor& result, const Tensor& input, const Tensor& boundaries, bool out_int32, bool right) {
7878
if (!out_int32) {
79-
AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "searchsorted_out_cpu", [&] {
80-
searchsorted_cpu_contiguous<scalar_t, int64_t>(result, input, boundaries, right);
81-
});
79+
AT_DISPATCH_ALL_TYPES_AND2(
80+
ScalarType::Half,
81+
ScalarType::BFloat16,
82+
input.scalar_type(),
83+
"searchsorted_out_cpu",
84+
[&] {
85+
searchsorted_cpu_contiguous<scalar_t, int64_t>(
86+
result, input, boundaries, right);
87+
});
8288
}
8389
else {
84-
AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "searchsorted_out_cpu", [&] {
85-
searchsorted_cpu_contiguous<scalar_t, int>(result, input, boundaries, right);
86-
});
90+
AT_DISPATCH_ALL_TYPES_AND2(
91+
ScalarType::Half,
92+
ScalarType::BFloat16,
93+
input.scalar_type(),
94+
"searchsorted_out_cpu",
95+
[&] {
96+
searchsorted_cpu_contiguous<scalar_t, int>(
97+
result, input, boundaries, right);
98+
});
8799
}
88100
}
89101

aten/src/ATen/native/Convolution.cpp

Lines changed: 50 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -432,6 +432,42 @@ bool check_cudnn_depthwise_workload(const at::Tensor& input, int stride) {
432432
}
433433
return false;
434434
}
435+
436+
// simplified version for cudnn 8.2 and above
437+
bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, int stride, const at::Tensor& weight) {
438+
// 1D conv
439+
if(input.size(2) == 1 && stride == 1){
440+
return true;
441+
}
442+
443+
// 2d conv
444+
// only square filters
445+
if (weight.size(2) != weight.size(3)) return false;
446+
int filter = weight.size(3);
447+
// only 1/3/5 filter
448+
if (filter != 1 && filter != 3 && filter != 5) return false;
449+
// we don't enforce square input but only check width to reduce heuristic space
450+
if (input.size(3) < 7) return false; // min width 7
451+
int w = input.size(3);
452+
// only 1/2 stride, use cudnn for all stride 1
453+
if (stride == 1) return true;
454+
if (stride != 2) return false;
455+
456+
int ch = input.size(1);
457+
int bs = input.size(0);
458+
// special case: batch size 1 shows good perf in lots of cases
459+
if (bs == 1) {
460+
if (filter == 1 && w <= 28) return true;
461+
if (filter == 3 || filter == 5) return true;
462+
} else {
463+
if (filter == 1 && bs <= 16 && ch >= 128 && w <= 7) return true;
464+
if (filter == 3 || filter == 5) {
465+
if ((ch >= 512) || (ch >= 256 && w >= 28)) return true;
466+
}
467+
}
468+
return false;
469+
}
470+
435471
// Use cudnn for FP16 depthwise convolutions
436472
auto ConvParams::use_cudnn_depthwise(
437473
const at::Tensor& input, const at::Tensor& weight) const -> bool {
@@ -440,6 +476,20 @@ auto ConvParams::use_cudnn_depthwise(
440476
}
441477
if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) {
442478
long cudnn_version = detail::getCUDAHooks().versionCuDNN();
479+
if (cudnn_version >= 8200) {
480+
bool kernel_cond = (use_cudnn(input, weight) &&
481+
input.scalar_type() == kHalf && // only for FP16
482+
weight.scalar_type() == kHalf &&
483+
is_depthwise(input, weight) &&
484+
input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks
485+
!is_dilated() && // no dilation supported
486+
(stride[0] == stride[1] || input.size(2) == 1) && // square or 1d
487+
input.size(1) >= 32); // min 32 channels supported
488+
if (kernel_cond) {
489+
return check_cudnn_depthwise_workload_with_filter(input, stride[1], weight);
490+
}
491+
}
492+
// keep (7600 <= cudnn < 8200) code unchanged
443493
bool kernel_cond = (cudnn_version >= 7600 &&
444494
use_cudnn(input, weight) &&
445495
input.scalar_type() == kHalf && // only for FP16

aten/src/ATen/native/cuda/Bucketization.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -92,12 +92,12 @@ void searchsorted_cuda_contiguous(Tensor& result, const Tensor& input, const Ten
9292

9393
void dispatch(Tensor& result, const Tensor& input, const Tensor& boundaries, bool out_int32, bool right) {
9494
if (!out_int32) {
95-
AT_DISPATCH_ALL_TYPES(input.scalar_type(), "searchsorted_out_cuda", [&] {
95+
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, input.scalar_type(), "searchsorted_out_cuda", [&] {
9696
searchsorted_cuda_contiguous<scalar_t, int64_t>(result, input, boundaries, right);
9797
});
9898
}
9999
else {
100-
AT_DISPATCH_ALL_TYPES(input.scalar_type(), "searchsorted_out_cuda", [&] {
100+
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, input.scalar_type(), "searchsorted_out_cuda", [&] {
101101
searchsorted_cuda_contiguous<scalar_t, int>(result, input, boundaries, right);
102102
});
103103
}

0 commit comments

Comments
 (0)