@@ -727,12 +727,12 @@ Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight
727727Tensor tensordot (const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2) {
728728 TORCH_CHECK (dims1.size () == dims2.size (), " both dimension lists should have same length" );
729729 TORCH_CHECK (input1.scalar_type () == input2.scalar_type (), " both inputs should have same dtype" );
730- int64_t csize = 1 ; // total size of the contracted dimensions
730+ SymInt csize = 1 ; // total size of the contracted dimensions
731731 Tensor t1 = input1;
732732 Tensor t2 = input2;
733733 for (const auto i : c10::irange (dims1.size ())) {
734- int s1 = input1.size (dims1[i]);
735- int s2 = input2.size (dims2[i]);
734+ SymInt s1 = input1.sym_size (dims1[i]);
735+ SymInt s2 = input2.sym_size (dims2[i]);
736736 if (s2 == 1 ) { // broadcasted dimensions can be summed right away
737737 t1 = t1.sum (dims1[i], true );
738738 } else if (s1 == 1 ) {
@@ -746,19 +746,20 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1,
746746
747747 auto cdims1 = at::dim_list_to_bitset (dims1, input1.dim ());
748748 auto cdims2 = at::dim_list_to_bitset (dims2, input2.dim ());
749- std::vector<int64_t > p1, p2, rsizes; // p1, p2: input permutations, rsizes: sizes of the result
749+ std::vector<int64_t > p1, p2; // p1, p2: input permutations
750+ std::vector<SymInt> rsizes; // rsizes: sizes of the result
750751 p1.reserve (input1.dim ());
751752 p2.reserve (input2.dim ());
752753 rsizes.reserve (input1.dim () + input2.dim () - (int64_t ) dims1.size ());
753- int64_t size1 = 1 ; // number of non-contracted elements in input1
754- int64_t size2 = 1 ; // number of non-contracted elements in input2
754+ SymInt size1 = 1 ; // number of non-contracted elements in input1
755+ SymInt size2 = 1 ; // number of non-contracted elements in input2
755756
756757 // fill the permutations and compute sizes
757758 for (const auto i : c10::irange (input1.dim ())) {
758759 if (! cdims1[i]) {
759760 p1.emplace_back (i);
760- size1 *= t1.size (i);
761- rsizes.emplace_back (t1.size (i));
761+ size1 *= t1.sym_size (i);
762+ rsizes.emplace_back (t1.sym_size (i));
762763 }
763764 }
764765 for (const auto x : dims1) {
@@ -770,15 +771,15 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1,
770771 for (const auto i : c10::irange (input2.dim ())) {
771772 if (! cdims2[i]) {
772773 p2.emplace_back (i);
773- size2 *= t2.size (i);
774- rsizes.emplace_back (t2.size (i));
774+ size2 *= t2.sym_size (i);
775+ rsizes.emplace_back (t2.sym_size (i));
775776 }
776777 }
777778 // permut and reshape for matrix multiplication
778- t1 = t1.permute (p1).reshape ({size1, csize});
779- t2 = t2.permute (p2).reshape ({csize, size2});
779+ t1 = t1.permute (p1).reshape_symint ({size1, csize});
780+ t2 = t2.permute (p2).reshape_symint ({csize, size2});
780781 // multiply and reshape to target size
781- return at::mm (t1, t2).reshape (rsizes);
782+ return at::mm (t1, t2).reshape_symint (rsizes);
782783}
783784
784785Tensor &tensordot_out (const Tensor& input1, const Tensor& input2, IntArrayRef dims1, IntArrayRef dims2, Tensor& result) {
0 commit comments