
Commit 8c8ecc2
Update on "Use TensorOptions for legacy constructors"
This code has traveled a long and winding road (doo doo). First, some back story: originally, all of this code was written in terms of Backend, ScalarType and Device, and things were, well, reasonable. But then Backend became DispatchKey, and then it became TensorOptions, but someone (ahem, me) was too lazy to migrate the separate ScalarType and Device into the TensorOptions, because it would involve editing a lot of code. Well, I have FINALLY made good on that promise, by absorbing ScalarType and Device into the TensorOptions.

We rely heavily on the optional fields in TensorOptions: the idea is that if a field is not set, we are in a context where we aren't expecting any particular scalar type / device; otherwise, we require exactly the scalar type / device that was set.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

[ghstack-poisoned]
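(As a reading aid, here is a minimal standalone sketch of that optional-fields convention, not code from this patch: the helper name is hypothetical, but device_opt()/value_or() are the accessors the diff below actually uses.)

#include <c10/core/TensorOptions.h>

// Hypothetical helper illustrating the convention: an unset field means
// "no particular expectation", a set field is a hard requirement.
c10::Device resolve_device(const c10::TensorOptions& options,
                           c10::Device inferred_device) {
  // Unset: fall back to whatever was inferred from the input data.
  // Set: honor the explicitly requested device.
  return options.device_opt().value_or(inferred_device);
}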
1 parent f56e275 commit 8c8ecc2

torch/csrc/utils/tensor_new.cpp

Lines changed: 22 additions & 14 deletions
@@ -203,6 +203,7 @@ Tensor internal_new_from_data(
     // It is possible for type_inference to be true and dtype to be present
     // in options; in this case the inferred type will take precedence.
     bool type_inference,
+    bool device_inference = false,
     bool pin_memory = false) {
   if (THPUtils_checkString(data)) {
     throw TypeError("new(): invalid data type '%s'", Py_TYPE(data)->tp_name);
@@ -220,7 +221,7 @@ Tensor internal_new_from_data(
   // infer the scalar type and device type; it's not expected to infer the layout since these constructors
   // are defined per-layout-type (e.g. tensor vs sparse_coo_tensor).
   auto inferred_scalar_type = type_inference ? var.scalar_type() : scalar_type;
-  auto device = options.device_opt().value_or(var.device());
+  auto device = device_inference ? var.device() : options.device_opt().value_or(var.device());
   pybind11::gil_scoped_release no_gil;
   maybe_initialize_cuda(device);
   return var.to(device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/copy_variables);
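(A self-contained sketch of the selection rule this hunk introduces; std::optional and the Device enum are stand-ins for the c10 types, and pick_device is a hypothetical name for the inline conditional above.)

#include <cassert>
#include <optional>

enum class Device { CPU, CUDA };

Device pick_device(bool device_inference,
                   std::optional<Device> requested,  // options.device_opt()
                   Device data_device) {             // var.device()
  // device_inference wins outright: keep the input data's own device.
  if (device_inference) return data_device;
  // Otherwise an explicit request wins, falling back to the data's device.
  return requested.value_or(data_device);
}

int main() {
  assert(pick_device(/*device_inference=*/true,  Device::CUDA, Device::CPU) == Device::CPU);
  assert(pick_device(/*device_inference=*/false, Device::CUDA, Device::CPU) == Device::CUDA);
  assert(pick_device(/*device_inference=*/false, std::nullopt, Device::CPU) == Device::CPU);
}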
@@ -607,12 +608,12 @@ Tensor sparse_csr_tensor_ctor(PyObject* args, PyObject* kwargs) {
     Tensor crow_indices = internal_new_from_data(values.options().dtype(kInt),
                                                  r.pyobject(CROW_INDICES_ARG),
                                                  /*copy_variables=*/false, /*copy_numpy=*/true,
-                                                 /*type_inference=*/true);
+                                                 /*type_inference=*/true, /*device_inference=*/true);
     // See Note [Ensuring sparse values and indices match devices]
     Tensor col_indices = internal_new_from_data(values.options().dtype(kInt),
                                                 r.pyobject(COL_INDICES_ARG),
                                                 /*copy_variables=*/false, /*copy_numpy=*/true,
-                                                /*type_inference=*/true);
+                                                /*type_inference=*/true, /*device_inference=*/true);

     return at::sparse_csr_tensor(crow_indices, col_indices, values, r.intlist(SIZE_ARRAY_ARG),
                                  values.options().layout(at::kSparseCsr)).set_requires_grad(r.toBool(REQ_GRAD_ARG));
@@ -629,12 +630,12 @@ Tensor sparse_csr_tensor_ctor(PyObject* args, PyObject* kwargs) {
     Tensor crow_indices = internal_new_from_data(values.options().dtype(kInt),
                                                  r.pyobject(CROW_INDICES_ARG),
                                                  /*copy_variables=*/false, /*copy_numpy=*/true,
-                                                 /*type_inference=*/true);
+                                                 /*type_inference=*/true, /*device_inference=*/true);
     // See Note [Ensuring sparse values and indices match devices]
     Tensor col_indices = internal_new_from_data(values.options().dtype(kInt),
                                                 r.pyobject(COL_INDICES_ARG),
                                                 /*copy_variables=*/false, /*copy_numpy=*/true,
-                                                /*type_inference=*/true);
+                                                /*type_inference=*/true, /*device_inference=*/true);
     return at::sparse_csr_tensor(crow_indices, col_indices, values,
                                  values.options().layout(at::kSparseCsr)).set_requires_grad(r.toBool(REQ_GRAD_ARG));
 }
@@ -668,12 +669,12 @@ Tensor _sparse_csr_tensor_unsafe_ctor(PyObject* args, PyObject* kwargs) {
   // See Note [Ensuring sparse values and indices match devices]
   Tensor crow_indices = internal_new_from_data(values.options().dtype(kInt), r.pyobject(ARG_CROW_INDICES),
                                                /*copy_variables=*/false, /*copy_numpy=*/true,
-                                               /*type_inference=*/true);
+                                               /*type_inference=*/true, /*device_inference=*/true);

   // See Note [Ensuring sparse values and indices match devices]
   Tensor col_indices = internal_new_from_data(values.options().dtype(kInt), r.pyobject(ARG_COL_INDICES),
                                               /*copy_variables=*/false, /*copy_numpy=*/true,
-                                              /*type_inference=*/true);
+                                              /*type_inference=*/true, /*device_inference=*/true);

   return at::_sparse_csr_tensor_unsafe(crow_indices, col_indices, values, r.intlist(ARG_SIZE), values.options().layout(at::kSparseCsr)).set_requires_grad(r.toBool(ARG_REQUIRES_GRAD));
 }
@@ -692,6 +693,12 @@ Tensor _sparse_csr_tensor_unsafe_ctor(PyObject* args, PyObject* kwargs) {
 // should accept even ordinary index sequences (and just make sure we write them
 // into the correct device). values is the ONLY way we know that the index
 // tensor should go to CUDA, so we have to get the information in somehow.
+//
+// However, there is an inverse problem: if the input indices are a tensor,
+// we SHOULD NOT do a device-to-device copy to make them line up with values,
+// as this is a performance footgun; instead, we should return it as is and
+// let the final constructor error. device_inference=true toggles this
+// behavior.

 Tensor sparse_coo_tensor_ctor(PyObject* args, PyObject* kwargs) {
   static PythonArgParser parser({
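(A hedged, self-contained illustration of the policy that note describes; Tensorish and both helper names are hypothetical stand-ins, the real path being internal_new_from_data with device_inference=true followed by the sparse constructor's own device check.)

#include <stdexcept>

enum class Device { CPU, CUDA };

struct Tensorish {
  Device device;
  bool came_from_tensor;  // true: user passed a tensor; false: a plain index sequence
};

// Fresh data from an ordinary sequence is cheap to write straight to the
// values' device, so do that; an existing tensor keeps its device, with no
// silent device-to-device copy.
Tensorish resolve_indices(Tensorish indices, Device values_device) {
  if (!indices.came_from_tensor) {
    indices.device = values_device;
  }
  return indices;
}

// Any mismatch surfaces here, in the final constructor, as the note prescribes.
void check_devices(const Tensorish& indices, Device values_device) {
  if (indices.device != values_device) {
    throw std::invalid_argument("indices and values must be on the same device");
  }
}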
@@ -713,7 +720,7 @@ Tensor sparse_coo_tensor_ctor(PyObject* args, PyObject* kwargs) {
     // See Note [Ensuring sparse values and indices match devices]
     Tensor indices = internal_new_from_data(values.options().dtype(kLong), r.pyobject(0),
                                             /*copy_variables=*/false, /*copy_numpy=*/true,
-                                            /*type_inference=*/false);
+                                            /*type_inference=*/false, /*device_inference=*/true);
     return at::sparse_coo_tensor(indices, values, values.options().layout(at::kSparse)).set_requires_grad(r.toBool(4));
   } else if (r.idx == 1) {
     bool type_inference = r.isNone(3);
@@ -725,7 +732,7 @@ Tensor sparse_coo_tensor_ctor(PyObject* args, PyObject* kwargs) {
     // See Note [Ensuring sparse values and indices match devices]
     Tensor indices = internal_new_from_data(values.options().dtype(kLong), r.pyobject(0),
                                             /*copy_variables=*/false, /*copy_numpy=*/true,
-                                            /*type_inference=*/false);
+                                            /*type_inference=*/false, /*device_inference=*/true);
     return at::sparse_coo_tensor(indices, values, r.intlist(2), values.options().layout(at::kSparse)).set_requires_grad(r.toBool(5));
   } else if (r.idx == 2) {
     auto options = parsed_options(r, 1, 2);
@@ -760,7 +767,7 @@ Tensor _sparse_coo_tensor_unsafe_ctor(PyObject* args, PyObject* kwargs) {
   // See Note [Ensuring sparse values and indices match devices]
   Tensor indices = internal_new_from_data(values.options().dtype(kLong), r.pyobject(ARG_INDICES),
                                           /*copy_variables=*/false, /*copy_numpy=*/true,
-                                          /*type_inference=*/false);
+                                          /*type_inference=*/false, /*device_inference=*/true);
   return at::_sparse_coo_tensor_unsafe(indices, values, r.intlist(ARG_SIZE), values.options().layout(at::kSparse)).set_requires_grad(r.toBool(ARG_REQUIRES_GRAD));
 }

@@ -780,7 +787,7 @@ void _validate_sparse_coo_tensor_args(PyObject* args, PyObject* kwargs) {
   // See Note [Ensuring sparse values and indices match devices]
   Tensor indices = internal_new_from_data(
       values.options().dtype(kLong), r.pyobject(0),
-      /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/false);
+      /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/false, /*device_inference=*/true);
   at::native::_validate_sparse_coo_tensor_args(indices, values, r.intlist(2));
 }

@@ -798,10 +805,10 @@ void _validate_sparse_csr_tensor_args(PyObject* args, PyObject* kwargs) {
   // See Note [Ensuring sparse values and indices match devices]
   Tensor crow_indices = internal_new_from_data(
       values.options().dtype(kInt), r.pyobject(0),
-      /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true);
+      /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true, /*device_inference=*/true);
   Tensor col_indices = internal_new_from_data(
       values.options().dtype(kInt), r.pyobject(1),
-      /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true);
+      /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true, /*device_inference=*/true);

   at::native::_validate_sparse_csr_tensor_args(crow_indices, col_indices, values, r.intlist(3));
 }
@@ -833,7 +840,8 @@ Tensor tensor_ctor(PyObject* args, PyObject* kwargs) {
                                            /*copy_variables=*/true,
                                            /*copy_numpy=*/true,
                                            /*type_inference=*/type_inference,
-                                           pin_memory);
+                                           /*device_inference=*/false,
+                                           /*pin_memory=*/pin_memory);
   auto names = r.toDimnameListOptional(5);
   if (names) {
     at::namedinference::propagate_names(new_tensor, *names, /*validate_names=*/true);
