@@ -180,7 +180,8 @@ at::Tensor& embedding_bag_nbit_impl(
     bool pruned_weights,
     const c10::optional<at::Tensor>& per_sample_weights_,
     const c10::optional<at::Tensor>& compressed_indices_mapping,
-    bool include_last_offset) {
+    bool include_last_offset,
+    bool is_embedding_op) {
   TORCH_CHECK(weight.dim() == 2);
   TORCH_CHECK(offsets.dim() == 1);
 
@@ -226,11 +227,14 @@ at::Tensor& embedding_bag_nbit_impl(
     offsets_include_last_val[M] = indices.numel();
     offsets_data = offsets_include_last_val.data();
   }
-
-  const std::vector<int64_t> shape = {output_size, D};
+  std::vector<int64_t> shape;
+  if (indices.dim() == 2 && is_embedding_op) {
+    const auto indices_sizes = indices.sizes();
+    shape = {indices_sizes[0], indices_sizes[1], D};
+  } else {
+    shape = {output_size, D};
+  }
   at::native::resize_(output, shape, c10::nullopt);
-
-
 #ifdef USE_FBGEMM
   const auto indices_data = indices.data_ptr<IndexType>();
   const auto weight_data = weight.data_ptr<uint8_t>();
@@ -506,7 +510,6 @@ at::Tensor& embedding_bag_byte_helper(
         "embedding_bag_byte operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences.");
 
     offsets = c10::MaybeOwned<at::Tensor>::owned(at::arange(0, indices.numel(), indices.sizes()[1], indices.scalar_type()));
-
   } else {
     TORCH_CHECK(
         offsets_in.has_value(),
@@ -590,7 +593,8 @@ at::Tensor& _embedding_bag_nbit_helper(
     bool pruned_weights,
     const c10::optional<at::Tensor>& per_sample_weights_,
    const c10::optional<at::Tensor>& compressed_indices_mapping,
-    bool include_last_offset) {
+    bool include_last_offset,
+    bool is_embedding_op) {
   c10::MaybeOwned<at::Tensor> offsets;
   TORCH_CHECK(
       bit_width == 4 || bit_width == 2,
@@ -603,7 +607,7 @@ at::Tensor& _embedding_bag_nbit_helper(
 
   // For embedding_bag operator with 2D indices, we need to set the offsets
   // explicitly here.
-  if (indices.dim() == 2) {
+  if (indices.dim() == 2 && !is_embedding_op) {
     TORCH_CHECK(
         !offsets_in.has_value(),
         "embedding_bag_4bit/embedding_bag_2bit operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences.");
@@ -644,7 +648,8 @@ at::Tensor& _embedding_bag_nbit_helper(
         pruned_weights,
         per_sample_weights_,
         compressed_indices_mapping,
-        include_last_offset);
+        include_last_offset,
+        is_embedding_op);
   } else if (
       indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kLong) {
     return embedding_bag_nbit_impl<int, int64_t>(
@@ -656,7 +661,8 @@ at::Tensor& _embedding_bag_nbit_helper(
        pruned_weights,
         per_sample_weights_,
         compressed_indices_mapping,
-        include_last_offset);
+        include_last_offset,
+        is_embedding_op);
   } else if (
       indices.scalar_type() == at::kLong && offsets->scalar_type() == at::kInt) {
     return embedding_bag_nbit_impl<int64_t, int>(
@@ -668,7 +674,8 @@ at::Tensor& _embedding_bag_nbit_helper(
         pruned_weights,
         per_sample_weights_,
         compressed_indices_mapping,
-        include_last_offset);
+        include_last_offset,
+        is_embedding_op);
   }
   return embedding_bag_nbit_impl<int64_t, int64_t>(
       output,
@@ -679,7 +686,8 @@ at::Tensor& _embedding_bag_nbit_helper(
       pruned_weights,
       per_sample_weights_,
       compressed_indices_mapping,
-      include_last_offset);
+      include_last_offset,
+      is_embedding_op);
 }
 } // namespace
 
@@ -710,7 +718,8 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit(
     bool pruned_weights,
     const c10::optional<at::Tensor>& per_sample_weights_,
     const c10::optional<at::Tensor>& compressed_indices_mapping,
-    bool include_last_offset) {
+    bool include_last_offset,
+    bool is_embedding_op) {
   if (per_sample_weights_.has_value()) {
     TORCH_CHECK(
         (per_sample_weights_.value().scalar_type() == at::kFloat ||
@@ -732,7 +741,8 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit(
           ? per_sample_weights_.value().to(at::kFloat)
           : per_sample_weights_,
       compressed_indices_mapping,
-      include_last_offset);
+      include_last_offset,
+      is_embedding_op);
 }
 
 namespace at {
@@ -792,7 +802,8 @@ Tensor& embedding_bag_4bit_rowwise_offsets_out(
           ? per_sample_weights_.value().to(at::kFloat)
           : per_sample_weights_,
       compressed_indices_mapping,
-      include_last_offset);
+      include_last_offset,
+      false);
 }
 
 Tensor& embedding_bag_2bit_rowwise_offsets_out(
@@ -826,7 +837,8 @@ Tensor& embedding_bag_2bit_rowwise_offsets_out(
           ? per_sample_weights_.value().to(at::kFloat)
           : per_sample_weights_,
       compressed_indices_mapping,
-      include_last_offset);
+      include_last_offset,
+      false);
 }
 
 namespace {
@@ -874,7 +886,6 @@ Tensor embedding_bag_4bit_rowwise_offsets(
     const c10::optional<Tensor>& per_sample_weights_,
     const c10::optional<Tensor>& compressed_indices_mapping,
     bool include_last_offset) {
-
   auto output = create_empty_from(weight, at::kFloat);
   embedding_bag_4bit_rowwise_offsets_out(
       output,
@@ -886,8 +897,7 @@ Tensor embedding_bag_4bit_rowwise_offsets(
       pruned_weights,
       per_sample_weights_,
       compressed_indices_mapping,
-      include_last_offset
-  );
+      include_last_offset);
   return output;
 }
 
@@ -901,7 +911,6 @@ Tensor embedding_bag_2bit_rowwise_offsets(
     const c10::optional<Tensor>& per_sample_weights_,
     const c10::optional<Tensor>& compressed_indices_mapping,
     bool include_last_offset) {
-
   auto output = create_empty_from(weight, at::kFloat);
   embedding_bag_2bit_rowwise_offsets_out(
       output,
@@ -913,8 +922,7 @@ Tensor embedding_bag_2bit_rowwise_offsets(
       pruned_weights,
       per_sample_weights_,
      compressed_indices_mapping,
-      include_last_offset
-  );
+      include_last_offset);
   return output;
 }
 
@@ -947,7 +955,8 @@ class QEmbeddingBag final {
          pruned_weights,
          per_sample_weights_,
          compressed_indices_mapping,
-          include_last_offset);
+          include_last_offset,
+          false);
    } else {
      TORCH_INTERNAL_ASSERT(
          "Currently only support 8-bit embedding_bag quantization");
@@ -975,7 +984,15 @@ class QEmbedding final {
          c10::nullopt,
          false /* include_last_offset */,
          true /* is_embedding_op */);
-
+    } else if (bit_rate == 4) {
+      return packed_weight->embeddingbag_4bit(
+          indices,
+          offsets,
+          pruned_weights,
+          c10::nullopt,
+          c10::nullopt,
+          false,
+          true);
    } else {
      TORCH_INTERNAL_ASSERT(
          "Currently only support 8-bit embedding quantization");
@@ -995,6 +1012,9 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("quantized::embedding_byte"),
       TORCH_FN(QEmbedding<8>::run));
+  m.impl(
+      TORCH_SELECTIVE_NAME("quantized::embedding_4bit"),
+      TORCH_FN(QEmbedding<4>::run));
 
   // Functions that work on at::Tensor packed weight.
   m.impl(
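The registration above makes `quantized::embedding_4bit` callable from Python via `torch.ops`. Below is a minimal, hedged usage sketch (not part of the diff): it assumes the op keeps the same schema as `quantized::embedding_byte` (packed weight, indices, pruned_weights) and that a per-row `quint4x2` weight can be packed with `quantized::embedding_bag_prepack`; the scale/zero-point values are placeholders.

```python
import torch

# Assumed setup: per-row (axis=0) 4-bit quantization of a float weight matrix.
num_embeddings, embedding_dim = 10, 16
weight = torch.randn(num_embeddings, embedding_dim)
scales = torch.ones(num_embeddings)
zero_points = torch.zeros(num_embeddings, dtype=torch.int64)
qweight = torch.quantize_per_channel(
    weight, scales, zero_points, axis=0, dtype=torch.quint4x2)

# Assumption: embedding_bag_prepack accepts the quint4x2 tensor and returns the
# packed-params object that QEmbedding<4>::run expects.
packed = torch.ops.quantized.embedding_bag_prepack(qweight)

indices = torch.tensor([0, 3, 7, 2], dtype=torch.long)

# Exercises the new is_embedding_op=true path: one output row per index,
# so the result should be [4, 16] here (and [B, T, 16] for 2D indices).
out = torch.ops.quantized.embedding_4bit(packed, indices, False)
print(out.shape)
```

Note that no offsets are passed: as with the existing 8-bit `QEmbedding<8>` path, the wrapper synthesizes one offset per index internally, and the `is_embedding_op` flag added in this diff keeps the output shaped like the indices instead of collapsing to one row per bag.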