@@ -203,6 +203,7 @@ Tensor internal_new_from_data(
203203 // It is possible for type_inference to be true and dtype to be present
204204 // in options; in this case the inferred type will take precedence.
205205 bool type_inference,
206+ bool device_inference = false ,
206207 bool pin_memory = false ) {
207208 if (THPUtils_checkString (data)) {
208209 throw TypeError (" new(): invalid data type '%s'" , Py_TYPE (data)->tp_name );
@@ -220,7 +221,7 @@ Tensor internal_new_from_data(
220221 // infer the scalar type and device type; it's not expected to infer the layout since these constructors
221222 // are defined per-layout-type (e.g. tensor vs sparse_coo_tensor).
222223 auto inferred_scalar_type = type_inference ? var.scalar_type () : scalar_type;
223- auto device = options.device_opt ().value_or (var.device ());
224+ auto device = device_inference ? var. device () : options.device_opt ().value_or (var.device ());
224225 pybind11::gil_scoped_release no_gil;
225226 maybe_initialize_cuda (device);
226227 return var.to (device, inferred_scalar_type, /* non_blocking=*/ false , /* copy=*/ copy_variables);
@@ -607,12 +608,12 @@ Tensor sparse_csr_tensor_ctor(PyObject* args, PyObject* kwargs) {
607608 Tensor crow_indices = internal_new_from_data (values.options ().dtype (kInt ),
608609 r.pyobject (CROW_INDICES_ARG),
609610 /* copy_variables=*/ false , /* copy_numpy=*/ true ,
610- /* type_inference=*/ true );
611+ /* type_inference=*/ true , /* device_inference= */ true );
611612 // See Note [Ensuring sparse values and indices match devices]
612613 Tensor col_indices = internal_new_from_data (values.options ().dtype (kInt ),
613614 r.pyobject (COL_INDICES_ARG),
614615 /* copy_variables=*/ false , /* copy_numpy=*/ true ,
615- /* type_inference=*/ true );
616+ /* type_inference=*/ true , /* device_inference= */ true );
616617
617618 return at::sparse_csr_tensor (crow_indices, col_indices, values, r.intlist (SIZE_ARRAY_ARG),
618619 values.options ().layout (at::kSparseCsr )).set_requires_grad (r.toBool (REQ_GRAD_ARG));
@@ -629,12 +630,12 @@ Tensor sparse_csr_tensor_ctor(PyObject* args, PyObject* kwargs) {
629630 Tensor crow_indices = internal_new_from_data (values.options ().dtype (kInt ),
630631 r.pyobject (CROW_INDICES_ARG),
631632 /* copy_variables=*/ false , /* copy_numpy=*/ true ,
632- /* type_inference=*/ true );
633+ /* type_inference=*/ true , /* device_inference= */ true );
633634 // See Note [Ensuring sparse values and indices match devices]
634635 Tensor col_indices = internal_new_from_data (values.options ().dtype (kInt ),
635636 r.pyobject (COL_INDICES_ARG),
636637 /* copy_variables=*/ false , /* copy_numpy=*/ true ,
637- /* type_inference=*/ true );
638+ /* type_inference=*/ true , /* device_inference= */ true );
638639 return at::sparse_csr_tensor (crow_indices, col_indices, values,
639640 values.options ().layout (at::kSparseCsr )).set_requires_grad (r.toBool (REQ_GRAD_ARG));
640641 }
@@ -668,12 +669,12 @@ Tensor _sparse_csr_tensor_unsafe_ctor(PyObject* args, PyObject* kwargs) {
668669 // See Note [Ensuring sparse values and indices match devices]
669670 Tensor crow_indices = internal_new_from_data (values.options ().dtype (kInt ), r.pyobject (ARG_CROW_INDICES),
670671 /* copy_variables=*/ false , /* copy_numpy=*/ true ,
671- /* type_inference=*/ true );
672+ /* type_inference=*/ true , /* device_inference= */ true );
672673
673674 // See Note [Ensuring sparse values and indices match devices]
674675 Tensor col_indices = internal_new_from_data (values.options ().dtype (kInt ), r.pyobject (ARG_COL_INDICES),
675676 /* copy_variables=*/ false , /* copy_numpy=*/ true ,
676- /* type_inference=*/ true );
677+ /* type_inference=*/ true , /* device_inference= */ true );
677678
678679 return at::_sparse_csr_tensor_unsafe (crow_indices, col_indices, values, r.intlist (ARG_SIZE), values.options ().layout (at::kSparseCsr )).set_requires_grad (r.toBool (ARG_REQUIRES_GRAD));
679680}
@@ -692,6 +693,12 @@ Tensor _sparse_csr_tensor_unsafe_ctor(PyObject* args, PyObject* kwargs) {
692693// should accept even ordinary index sequences (and just make sure we write them
693694// into the correct device). values is the ONLY way we know that the index
694695// tensor should go to CUDA, so we have to get the information in somehow.
696+ //
697+ // However, there is an inverse problem: if the input indices are a tensor,
698+ // we SHOULD NOT do a device-to-device copy to make them line up with values,
699+ // as this is a performance footgun; instead, we should return it as is and
700+ // let the final constructor error. device_inference=true toggles this
701+ // behavior.
695702
696703Tensor sparse_coo_tensor_ctor (PyObject* args, PyObject* kwargs) {
697704 static PythonArgParser parser ({
@@ -713,7 +720,7 @@ Tensor sparse_coo_tensor_ctor(PyObject* args, PyObject* kwargs) {
713720 // See Note [Ensuring sparse values and indices match devices]
714721 Tensor indices = internal_new_from_data (values.options ().dtype (kLong ), r.pyobject (0 ),
715722 /* copy_variables=*/ false , /* copy_numpy=*/ true ,
716- /* type_inference=*/ false );
723+ /* type_inference=*/ false , /* device_inference= */ true );
717724 return at::sparse_coo_tensor (indices, values, values.options ().layout (at::kSparse )).set_requires_grad (r.toBool (4 ));
718725 } else if (r.idx == 1 ) {
719726 bool type_inference = r.isNone (3 );
@@ -725,7 +732,7 @@ Tensor sparse_coo_tensor_ctor(PyObject* args, PyObject* kwargs) {
725732 // See Note [Ensuring sparse values and indices match devices]
726733 Tensor indices = internal_new_from_data (values.options ().dtype (kLong ), r.pyobject (0 ),
727734 /* copy_variables=*/ false , /* copy_numpy=*/ true ,
728- /* type_inference=*/ false );
735+ /* type_inference=*/ false , /* device_inference= */ true );
729736 return at::sparse_coo_tensor (indices, values, r.intlist (2 ), values.options ().layout (at::kSparse )).set_requires_grad (r.toBool (5 ));
730737 } else if (r.idx == 2 ) {
731738 auto options = parsed_options (r, 1 , 2 );
@@ -760,7 +767,7 @@ Tensor _sparse_coo_tensor_unsafe_ctor(PyObject* args, PyObject* kwargs) {
760767 // See Note [Ensuring sparse values and indices match devices]
761768 Tensor indices = internal_new_from_data (values.options ().dtype (kLong ), r.pyobject (ARG_INDICES),
762769 /* copy_variables=*/ false , /* copy_numpy=*/ true ,
763- /* type_inference=*/ false );
770+ /* type_inference=*/ false , /* device_inference= */ true );
764771 return at::_sparse_coo_tensor_unsafe (indices, values, r.intlist (ARG_SIZE), values.options ().layout (at::kSparse )).set_requires_grad (r.toBool (ARG_REQUIRES_GRAD));
765772}
766773
@@ -780,7 +787,7 @@ void _validate_sparse_coo_tensor_args(PyObject* args, PyObject* kwargs) {
780787 // See Note [Ensuring sparse values and indices match devices]
781788 Tensor indices = internal_new_from_data (
782789 values.options ().dtype (kLong ), r.pyobject (0 ),
783- /* copy_variables=*/ false , /* copy_numpy=*/ true , /* type_inference=*/ false );
790+ /* copy_variables=*/ false , /* copy_numpy=*/ true , /* type_inference=*/ false , /* device_inference= */ true );
784791 at::native::_validate_sparse_coo_tensor_args (indices, values, r.intlist (2 ));
785792}
786793
@@ -798,10 +805,10 @@ void _validate_sparse_csr_tensor_args(PyObject* args, PyObject* kwargs) {
798805 // See Note [Ensuring sparse values and indices match devices]
799806 Tensor crow_indices = internal_new_from_data (
800807 values.options ().dtype (kInt ), r.pyobject (0 ),
801- /* copy_variables=*/ false , /* copy_numpy=*/ true , /* type_inference=*/ true );
808+ /* copy_variables=*/ false , /* copy_numpy=*/ true , /* type_inference=*/ true , /* device_inference= */ true );
802809 Tensor col_indices = internal_new_from_data (
803810 values.options ().dtype (kInt ), r.pyobject (1 ),
804- /* copy_variables=*/ false , /* copy_numpy=*/ true , /* type_inference=*/ true );
811+ /* copy_variables=*/ false , /* copy_numpy=*/ true , /* type_inference=*/ true , /* device_inference= */ true );
805812
806813 at::native::_validate_sparse_csr_tensor_args (crow_indices, col_indices, values, r.intlist (3 ));
807814}
@@ -833,7 +840,8 @@ Tensor tensor_ctor(PyObject* args, PyObject* kwargs) {
833840 /* copy_variables=*/ true ,
834841 /* copy_numpy=*/ true ,
835842 /* type_inference=*/ type_inference,
836- pin_memory);
843+ /* device_inference=*/ false ,
844+ /* pin_memory=*/ pin_memory);
837845 auto names = r.toDimnameListOptional (5 );
838846 if (names) {
839847 at::namedinference::propagate_names (new_tensor, *names, /* validate_names=*/ true );
0 commit comments