@@ -111,7 +111,6 @@ std::vector<Dimname> unify_from_right(
111111 return result;
112112}
113113
114-
115114namespace namedinference {
116115
117116static std::bitset<dim_bitset_size>
@@ -128,51 +127,41 @@ static void assert_names_equal(DimnameList a, DimnameList b) {
128127 " . Please rename the out tensor's dims with `Tensor.rename`." );
129128}
130129
131- void propagate_names (TensorImpl* result, optional<DimnameList> names) {
132- if (!impl::has_names (result) && !names.has_value ()) {
133- return ;
134- }
135- if (!impl::has_names (result)) {
136- impl::internal_set_names_inplace (result, names);
137- return ;
138- }
139- assert_names_equal (
140- impl::get_names (result),
141- names.value_or (default_names (result->dim ())));
142- }
143-
144- void propagate_names (
145- Tensor& result,
146- optional<std::vector<Dimname>>&& maybe_names,
130+ Tensor& propagate_names_if_nonempty (Tensor& result,
131+ DimnameList maybe_names,
147132 bool validate_names) {
148- propagate_names (result.unsafeGetTensorImpl (), std::move (maybe_names), validate_names);
133+ propagate_names_if_nonempty (result.unsafeGetTensorImpl (), maybe_names, validate_names);
134+ return result;
149135}
150136
151- void propagate_names (
152- TensorImpl* result,
153- optional<std::vector<Dimname>>&& maybe_names,
137+ TensorImpl* propagate_names_if_nonempty (TensorImpl* result,
138+ DimnameList maybe_names,
154139 bool validate_names) {
155- if (!maybe_names) {
156- propagate_names (result, nullopt );
157- return ;
158- }
159- propagate_names (result, std::move (maybe_names.value ()), validate_names);
160- }
161-
162- void propagate_names (TensorImpl* result, std::vector<Dimname>&& names, bool validate_names) {
163- if (!impl::has_names (result)) {
164- impl::internal_set_names_inplace (result, std::move (names), validate_names);
165- return ;
140+ if (maybe_names.empty ()) {
141+ return result;
166142 }
167- assert_names_equal ( impl::get_names ( result), names );
143+ return propagate_names ( result, maybe_names, validate_names );
168144}
169145
170- void propagate_names (Tensor& result, optional<DimnameList> names) {
171- propagate_names (result.unsafeGetTensorImpl (), names);
146+ Tensor& propagate_names (Tensor& result, DimnameList names, bool validate_names) {
147+ propagate_names (result.unsafeGetTensorImpl (), names, validate_names);
148+ return result;
172149}
173150
174- void propagate_names (Tensor& result, std::vector<Dimname>&& names, bool validate_names) {
175- propagate_names (result.unsafeGetTensorImpl (), std::move (names), validate_names);
151+ TensorImpl* propagate_names (TensorImpl* result, DimnameList names, bool validate_names) {
152+ if (result->dim () > 0 ) {
153+ TORCH_INTERNAL_ASSERT (
154+ !names.empty (),
155+ " propagate_names: passed in empty names to propagate to result with" ,
156+ " shape " , result->sizes (), " . Empty names means that name inference did" ,
157+ " not occur; use `propagate_names_if_nonempty` instead of `propagate_names`." );
158+ }
159+ if (!impl::has_names (result)) {
160+ impl::internal_set_names_inplace (result, names, validate_names);
161+ } else {
162+ assert_names_equal (impl::get_names (result), names);
163+ }
164+ return result;
176165}
177166
178167void propagate_names_except (Tensor& result, const Tensor& src, IntArrayRef excluded_idxs) {
@@ -188,7 +177,7 @@ void propagate_names_except(Tensor& result, const Tensor& src, IntArrayRef exclu
188177 if (excluded_idxs.size () == 1 ) {
189178 std::vector<Dimname> outnames = src_names.vec ();
190179 outnames.erase (outnames.begin () + maybe_wrap_dim (excluded_idxs[0 ], src_dim));
191- propagate_names (result, std::move ( outnames), /* validate_names= */ false );
180+ propagate_names (result, outnames);
192181 return ;
193182 }
194183
@@ -200,7 +189,7 @@ void propagate_names_except(Tensor& result, const Tensor& src, IntArrayRef exclu
200189 outnames.push_back (src_names[dim]);
201190 }
202191 }
203- propagate_names (result, std::move ( outnames), /* validate_names= */ false );
192+ propagate_names (result, outnames);
204193}
205194
206195void propagate_names_for_reduction (Tensor& result, const Tensor& src, IntArrayRef reduced_dims, bool keepdim) {
@@ -223,12 +212,15 @@ void propagate_names(TensorImpl* result, TensorImpl* src) {
223212 if (result == src) {
224213 return ;
225214 }
226- propagate_names (result, impl::get_opt_names (src));
215+ if (!impl::has_names (result) && !impl::has_names (src)) {
216+ return ;
217+ }
218+ propagate_names (result, impl::get_names (src));
227219}
228220
229- optional< std::vector<Dimname> > compute_squeeze_outnames (const Tensor& tensor) {
221+ std::vector<Dimname> compute_squeeze_outnames (const Tensor& tensor) {
230222 if (!tensor.has_names ()) {
231- return nullopt ;
223+ return {} ;
232224 }
233225 std::vector<Dimname> outnames;
234226 auto tensor_names = tensor.names ();
@@ -371,7 +363,7 @@ void propagate_names_for_addmv(
371363 }
372364 auto mv_outnames = compute_matmul_outnames (impl::get_names (mat), impl::get_names (vec));
373365 auto add_outnames = unify_from_right (mv_outnames, impl::get_names (bias));
374- propagate_names (result, std::move ( add_outnames), /* validate_names= */ false );
366+ propagate_names (result, add_outnames);
375367}
376368
377369void propagate_names_for_addmm (
@@ -385,7 +377,7 @@ void propagate_names_for_addmm(
385377 }
386378 auto mm_outnames = compute_matmul_outnames (impl::get_names (m1), impl::get_names (m2));
387379 auto add_outnames = unify_from_right (mm_outnames, impl::get_names (bias));
388- propagate_names (result, std::move ( add_outnames), /* validate_names= */ false );
380+ propagate_names (result, add_outnames);
389381}
390382
391383void check_names_for_dot (
@@ -415,24 +407,24 @@ void propagate_names_for_expand(Tensor& result, const Tensor& self) {
415407 self.opt_names ()->begin (),
416408 self.opt_names ()->end (),
417409 outnames.begin () + result_dim - self.dim ());
418- propagate_names (result, std::move ( outnames), /* validate_names= */ false );
410+ propagate_names (result, outnames);
419411}
420412
421- optional< std::vector<Dimname> > compute_broadcast_outnames (
413+ std::vector<Dimname> compute_broadcast_outnames (
422414 const Tensor& self,
423415 const Tensor& other) {
424416 if (!self.has_names () && !other.has_names ()) {
425- return nullopt ;
417+ return {} ;
426418 }
427419 return unify_from_right (self.names (), other.names ());
428420}
429421
430- optional< std::vector<Dimname> > broadcast_to_outnames (
422+ std::vector<Dimname> broadcast_to_outnames (
431423 const Tensor& tensor,
432424 const Tensor& reference_tensor,
433425 const char * op_name) {
434426 if (!tensor.has_names () && !reference_tensor.has_names ()) {
435- return nullopt ;
427+ return {} ;
436428 }
437429 auto reference_names = reference_tensor.names ();
438430 auto tensor_names = tensor.names ();
@@ -445,9 +437,9 @@ optional<std::vector<Dimname>> broadcast_to_outnames(
445437 return unify_from_right (reference_names, tensor_names);
446438}
447439
448- optional< std::vector<Dimname> > compute_cat_outnames (TensorList tensors) {
440+ std::vector<Dimname> compute_cat_outnames (TensorList tensors) {
449441 if (!at::has_names (tensors)) {
450- return nullopt ;
442+ return {} ;
451443 }
452444 std::vector<Dimname> result;
453445 for (const auto & tensor : tensors) {
@@ -461,20 +453,20 @@ optional<std::vector<Dimname>> compute_cat_outnames(TensorList tensors) {
461453 return result;
462454}
463455
464- optional< std::vector<Dimname> > compute_matmul_outnames (
456+ std::vector<Dimname> compute_matmul_outnames (
465457 const Tensor& self,
466458 const Tensor& other) {
467459 if (!self.has_names () && !other.has_names ()) {
468- return nullopt ;
460+ return {} ;
469461 }
470462 return compute_matmul_outnames (self.names (), other.names ());
471463}
472464
473- optional< std::vector<Dimname> > compute_cdist_outnames (
465+ std::vector<Dimname> compute_cdist_outnames (
474466 const Tensor& self,
475467 const Tensor& other) {
476468 if (!self.has_names () && !other.has_names ()) {
477- return nullopt ;
469+ return {} ;
478470 }
479471 const auto self_names = self.names ();
480472 const auto other_names = other.names ();
@@ -496,24 +488,24 @@ optional<std::vector<Dimname>> compute_cdist_outnames(
496488 return result.toDimnameVec ();
497489}
498490
499- optional< std::vector<Dimname> > compute_bmm_outnames (
491+ std::vector<Dimname> compute_bmm_outnames (
500492 Tensor& result,
501493 const Tensor& self,
502494 const Tensor& other) {
503495 if (!result.has_names () && !self.has_names () && !other.has_names ()) {
504- return nullopt ;
496+ return {} ;
505497 }
506498 return compute_matmul_outnames (self.names (), other.names ());
507499}
508500
509- optional< std::vector<Dimname> > compute_baddbmm_outnames (
501+ std::vector<Dimname> compute_baddbmm_outnames (
510502 TensorImpl* result,
511503 TensorImpl* batch1,
512504 TensorImpl* batch2,
513505 TensorImpl* bias) {
514506 if (!impl::has_names (result) && !impl::has_names (batch1) &&
515507 !impl::has_names (batch2) && !impl::has_names (bias)) {
516- return nullopt ;
508+ return {} ;
517509 }
518510 auto bmm_names = compute_matmul_outnames (
519511 impl::get_names (batch1), impl::get_names (batch2));
0 commit comments