@@ -128,7 +128,19 @@ TensorIteratorConfig& TensorIteratorConfig::add_borrowed_input(const TensorBase&
128128
129129TensorIteratorConfig& TensorIteratorConfig::declare_static_dtype_and_device (ScalarType dtype, Device device) {
130130 TORCH_CHECK (!check_all_same_dtype_, " check_all_same_dtype(false) must be called before declare_static_dtype(...)" );
131- static_dtype_and_device_ = c10::make_optional (std::make_pair (dtype, device));
131+ static_dtype_ = dtype;
132+ static_device_ = device;
133+ return *this ;
134+ }
135+
136+ TensorIteratorConfig& TensorIteratorConfig::declare_static_dtype (ScalarType dtype) {
137+ TORCH_CHECK (!check_all_same_dtype_, " check_all_same_dtype(false) must be called before declare_static_dtype(...)" );
138+ static_dtype_ = dtype;
139+ return *this ;
140+ }
141+
142+ TensorIteratorConfig& TensorIteratorConfig::declare_static_device (Device device) {
143+ static_device_ = device;
132144 return *this ;
133145}
134146
@@ -327,12 +339,20 @@ void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
327339 // the device it should be allocated on.
328340 if (!op.is_type_defined ()) {
329341 TORCH_INTERNAL_ASSERT (op.is_output , " Found type undefined input tensor!" );
330- if (config. static_dtype_and_device_ . has_value ()) {
331- op. target_dtype = config.static_dtype_and_device_ -> first ;
332- op.device = config.static_dtype_and_device_ -> second ;
342+
343+ if ( config.static_dtype_ . has_value ()) {
344+ op.target_dtype = config.static_dtype_ . value () ;
333345 } else {
334- TORCH_INTERNAL_ASSERT (config.check_all_same_device_ );
335346 has_undefined_outputs = true ;
347+ }
348+
349+ if (config.static_device_ .has_value ()) {
350+ op.device = config.static_device_ .value ();
351+ } else {
352+ TORCH_INTERNAL_ASSERT (config.check_all_same_device_ );
353+ }
354+
355+ if (has_undefined_outputs || !op.device .has_value ()) {
336356 continue ;
337357 }
338358 }
@@ -418,12 +438,21 @@ void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
418438 // - checks that all tensors are on the same device, if requested
419439 // - checks that the common dtype can safely cast to each output, if requested
420440 // - creates temporaries for CPU operations, if needed and requested
441+ common_device_ = common_device;
421442 int max_cpu_scalars_on_non_cpu = config.allow_cpu_scalars_ ? 1 : 0 ;
422443 int current_cpu_scalars_on_non_cpu = 0 ;
423444 for (auto & op : operands_) {
424- if (!op.is_type_defined ()) {
445+ bool is_type_defined = op.is_type_defined ();
446+ bool is_device_defined = op.is_device_defined ();
447+
448+ if (!is_type_defined) {
425449 op.target_dtype = common_dtype_;
450+ }
451+ if (!is_device_defined) {
426452 op.device = common_device;
453+ }
454+
455+ if (!is_type_defined && !is_device_defined) {
427456 continue ;
428457 }
429458
@@ -441,10 +470,10 @@ void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
441470 TORCH_CHECK (current_cpu_scalars_on_non_cpu < max_cpu_scalars_on_non_cpu,
442471 " Trying to pass too many CPU scalars to non-CPU kernel!" );
443472 ++current_cpu_scalars_on_non_cpu;
444- } else if (op.device != common_device) {
473+ } else if (op.device . value () != common_device) {
445474 TORCH_CHECK (false ,
446475 " Expected all tensors to be on the same device, but "
447- " found at least two devices, " , common_device, " and " , op.device , " !" );
476+ " found at least two devices, " , common_device, " and " , op.device . value () , " !" );
448477 }
449478 }
450479
@@ -490,7 +519,6 @@ void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
490519 op.target_dtype = common_dtype_;
491520 }
492521 }
493- common_device_ = common_device;
494522 }
495523}
496524
@@ -864,7 +892,7 @@ void TensorIteratorBase::build_comparison_op(
864892 // want the output to be bool. Otherwise (e.g. 'torch.eq(a, b, out=c)') we
865893 // don't coerce the output.
866894 if (!out.defined ()) {
867- config.declare_static_dtype_and_device (kBool , a. device () );
895+ config.declare_static_dtype (kBool );
868896 }
869897
870898 // Note [special-case bool outputs]
@@ -943,7 +971,8 @@ void TensorIteratorBase::build_unary_force_boolean_op(const TensorBase& out, con
943971 build (TensorIteratorConfig ()
944972 .set_check_mem_overlap (true )
945973 .check_all_same_dtype (false )
946- .declare_static_dtype_and_device (at::kBool , a.device ())
974+ .declare_static_dtype (at::kBool )
975+ .declare_static_device (a.device ())
947976 .add_owned_output (out)
948977 .add_owned_input (a));
949978}
0 commit comments