Commit b331752

David Dang authored and facebook-github-bot committed

[Quant] Implemented 4 bit embedding op support; added corresponding test case (#69768)

Summary: Pull Request resolved: #69768

Support for the 4 bit embedding operator has been added. The support is analogous to the preexisting support for byte/8 bit embedding. A corresponding test case was added to test_quantized_embedding_op.py.

Test Plan: In the pytorch main dir, execute

```
python test/test_quantization.py TestStaticQuantizedModule.test_embedding_api
```

to run the series of tests, including the newly added test_embedding_4bit function.

Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D33152673

fbshipit-source-id: bdcc2eb2e37de38fda3461ff3ebf1d2fb5e58071

1 parent 94abf12 · commit b331752

5 files changed: +77 -59

aten/src/ATen/native/quantized/cpu/embedding_packed_params.h

Lines changed: 2 additions & 1 deletion

```diff
@@ -19,7 +19,8 @@ struct EmbeddingPackedParamsBase : public torch::jit::CustomClassHolder {
       bool pruned_weights,
       const c10::optional<at::Tensor>& per_sample_weights_,
       const c10::optional<at::Tensor>& compressed_indices_mapping,
-      bool include_last_offset) = 0;
+      bool include_last_offset,
+      bool is_embedding_op) = 0;
 
   virtual at::Tensor unpack() = 0;
 
```

aten/src/ATen/native/quantized/cpu/fbgemm_utils.h

Lines changed: 2 additions & 1 deletion

```diff
@@ -389,5 +389,6 @@ struct TORCH_API PackedEmbeddingBagWeight : public EmbeddingPackedParamsBase {
       bool pruned_weights,
       const c10::optional<at::Tensor>& per_sample_weights_,
       const c10::optional<at::Tensor>& compressed_indices_mapping,
-      bool include_last_offset) override;
+      bool include_last_offset,
+      bool is_embedding_op) override;
 };
```

aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp

Lines changed: 44 additions & 24 deletions

```diff
@@ -180,7 +180,8 @@ at::Tensor& embedding_bag_nbit_impl(
     bool pruned_weights,
     const c10::optional<at::Tensor>& per_sample_weights_,
     const c10::optional<at::Tensor>& compressed_indices_mapping,
-    bool include_last_offset) {
+    bool include_last_offset,
+    bool is_embedding_op) {
   TORCH_CHECK(weight.dim() == 2);
   TORCH_CHECK(offsets.dim() == 1);
 
@@ -226,11 +227,14 @@ at::Tensor& embedding_bag_nbit_impl(
     offsets_include_last_val[M] = indices.numel();
     offsets_data = offsets_include_last_val.data();
   }
-
-  const std::vector<int64_t> shape = {output_size, D};
+  std::vector<int64_t> shape;
+  if(indices.dim() == 2 && is_embedding_op) {
+    const auto indices_sizes = indices.sizes();
+    shape = {indices_sizes[0], indices_sizes[1], D};
+  } else {
+    shape = {output_size, D};
+  }
   at::native::resize_(output, shape, c10::nullopt);
-
-
 #ifdef USE_FBGEMM
   const auto indices_data = indices.data_ptr<IndexType>();
   const auto weight_data = weight.data_ptr<uint8_t>();
@@ -506,7 +510,6 @@ at::Tensor& embedding_bag_byte_helper(
         "embedding_bag_byte operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences.");
 
     offsets = c10::MaybeOwned<at::Tensor>::owned(at::arange(0, indices.numel(), indices.sizes()[1], indices.scalar_type()));
-
   } else {
     TORCH_CHECK(
         offsets_in.has_value(),
@@ -590,7 +593,8 @@ at::Tensor& _embedding_bag_nbit_helper(
     bool pruned_weights,
    const c10::optional<at::Tensor>& per_sample_weights_,
    const c10::optional<at::Tensor>& compressed_indices_mapping,
-    bool include_last_offset) {
+    bool include_last_offset,
+    bool is_embedding_op) {
   c10::MaybeOwned<at::Tensor> offsets;
   TORCH_CHECK(
       bit_width == 4 || bit_width == 2,
@@ -603,7 +607,7 @@ at::Tensor& _embedding_bag_nbit_helper(
 
   // For embedding_bag operator with 2D indices, we need to set the offsets
   // explicitly here.
-  if (indices.dim() == 2) {
+  if (indices.dim() == 2 && !is_embedding_op) {
     TORCH_CHECK(
         !offsets_in.has_value(),
         "embedding_bag_4bit/embedding_bag_2bit operator: input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences.");
@@ -644,7 +648,8 @@ at::Tensor& _embedding_bag_nbit_helper(
         pruned_weights,
         per_sample_weights_,
         compressed_indices_mapping,
-        include_last_offset);
+        include_last_offset,
+        is_embedding_op);
   } else if (
       indices.scalar_type() == at::kInt && offsets->scalar_type() == at::kLong) {
     return embedding_bag_nbit_impl<int, int64_t>(
@@ -656,7 +661,8 @@ at::Tensor& _embedding_bag_nbit_helper(
         pruned_weights,
         per_sample_weights_,
         compressed_indices_mapping,
-        include_last_offset);
+        include_last_offset,
+        is_embedding_op);
   } else if (
       indices.scalar_type() == at::kLong && offsets->scalar_type() == at::kInt) {
     return embedding_bag_nbit_impl<int64_t, int>(
@@ -668,7 +674,8 @@ at::Tensor& _embedding_bag_nbit_helper(
         pruned_weights,
         per_sample_weights_,
         compressed_indices_mapping,
-        include_last_offset);
+        include_last_offset,
+        is_embedding_op);
   }
   return embedding_bag_nbit_impl<int64_t, int64_t>(
       output,
@@ -679,7 +686,8 @@ at::Tensor& _embedding_bag_nbit_helper(
       pruned_weights,
      per_sample_weights_,
      compressed_indices_mapping,
-      include_last_offset);
+      include_last_offset,
+      is_embedding_op);
 }
 } // namespace
 
@@ -710,7 +718,8 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit(
     bool pruned_weights,
     const c10::optional<at::Tensor>& per_sample_weights_,
     const c10::optional<at::Tensor>& compressed_indices_mapping,
-    bool include_last_offset) {
+    bool include_last_offset,
+    bool is_embedding_op) {
   if (per_sample_weights_.has_value()) {
     TORCH_CHECK(
         (per_sample_weights_.value().scalar_type() == at::kFloat ||
@@ -732,7 +741,8 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit(
           ? per_sample_weights_.value().to(at::kFloat)
          : per_sample_weights_,
       compressed_indices_mapping,
-      include_last_offset);
+      include_last_offset,
+      is_embedding_op);
 }
 
 namespace at {
@@ -792,7 +802,8 @@ Tensor& embedding_bag_4bit_rowwise_offsets_out(
           ? per_sample_weights_.value().to(at::kFloat)
          : per_sample_weights_,
       compressed_indices_mapping,
-      include_last_offset);
+      include_last_offset,
+      false);
 }
 
 Tensor& embedding_bag_2bit_rowwise_offsets_out(
@@ -826,7 +837,8 @@ Tensor& embedding_bag_2bit_rowwise_offsets_out(
           ? per_sample_weights_.value().to(at::kFloat)
          : per_sample_weights_,
       compressed_indices_mapping,
-      include_last_offset);
+      include_last_offset,
+      false);
 }
 
 namespace {
@@ -874,7 +886,6 @@ Tensor embedding_bag_4bit_rowwise_offsets(
     const c10::optional<Tensor>& per_sample_weights_,
     const c10::optional<Tensor>& compressed_indices_mapping,
     bool include_last_offset) {
-
   auto output = create_empty_from(weight, at::kFloat);
   embedding_bag_4bit_rowwise_offsets_out(
       output,
@@ -886,8 +897,7 @@ Tensor embedding_bag_4bit_rowwise_offsets(
       pruned_weights,
       per_sample_weights_,
       compressed_indices_mapping,
-      include_last_offset
-  );
+      include_last_offset);
   return output;
 }
 
@@ -901,7 +911,6 @@ Tensor embedding_bag_2bit_rowwise_offsets(
     const c10::optional<Tensor>& per_sample_weights_,
     const c10::optional<Tensor>& compressed_indices_mapping,
     bool include_last_offset) {
-
   auto output = create_empty_from(weight, at::kFloat);
   embedding_bag_2bit_rowwise_offsets_out(
       output,
@@ -913,8 +922,7 @@ Tensor embedding_bag_2bit_rowwise_offsets(
       pruned_weights,
       per_sample_weights_,
       compressed_indices_mapping,
-      include_last_offset
-  );
+      include_last_offset);
   return output;
 }
 
@@ -947,7 +955,8 @@ class QEmbeddingBag final {
           pruned_weights,
           per_sample_weights_,
          compressed_indices_mapping,
-          include_last_offset);
+          include_last_offset,
+          false);
     } else {
       TORCH_INTERNAL_ASSERT(
           "Currently only support 8-bit embedding_bag quantization");
@@ -975,7 +984,15 @@ class QEmbedding final {
           c10::nullopt,
           false /* include_last_offset */,
           true /* is_embedding_op */);
-
+    } else if (bit_rate == 4) {
+      return packed_weight->embeddingbag_4bit(
+          indices,
+          offsets,
+          pruned_weights,
+          c10::nullopt,
+          c10::nullopt,
+          false,
+          true);
     } else {
       TORCH_INTERNAL_ASSERT(
           "Currently only support 8-bit embedding quantization");
@@ -995,6 +1012,9 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("quantized::embedding_byte"),
       TORCH_FN(QEmbedding<8>::run));
+  m.impl(
+      TORCH_SELECTIVE_NAME("quantized::embedding_4bit"),
+      TORCH_FN(QEmbedding<4>::run));
 
   // Functions that work on at::Tensor packed weight.
   m.impl(
```
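The behavioral core of this file's change is the output-shape selection in embedding_bag_nbit_impl: when is_embedding_op is true and the indices are 2D, the output keeps the indices' batch shape instead of the pooled bag shape, and _embedding_bag_nbit_helper correspondingly skips the implicit-offsets path (indices.dim() == 2 && !is_embedding_op). A minimal Python sketch of the equivalent shape logic (the helper name and keyword arguments are illustrative, mirroring the C++ locals):

```python
def nbit_output_shape(indices_shape, output_size, D, is_embedding_op):
    # Mirrors the shape selection added to embedding_bag_nbit_impl: a plain
    # embedding lookup on 2D indices yields [batch, seq_len, D], while
    # embedding_bag pools rows of indices into bags of shape [output_size, D].
    if len(indices_shape) == 2 and is_embedding_op:
        return [indices_shape[0], indices_shape[1], D]
    return [output_size, D]

# A (3, 7) batch of indices with 16-dim embeddings:
assert nbit_output_shape((3, 7), output_size=3, D=16, is_embedding_op=False) == [3, 16]
assert nbit_output_shape((3, 7), output_size=21, D=16, is_embedding_op=True) == [3, 7, 16]
```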

aten/src/ATen/native/quantized/library.cpp

Lines changed: 1 addition & 0 deletions

```diff
@@ -140,6 +140,7 @@ TORCH_LIBRARY(quantized, m) {
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_byte(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, bool pruned_weights=False) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_4bit(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase weight, Tensor indices, bool pruned_weights=False) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::celu(Tensor self, float output_scale, int output_zero_point, Scalar alpha=1) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::group_norm(Tensor input, int num_groups, Tensor? weight, Tensor? bias, float eps, float output_scale, int output_zero_point) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::hardswish(Tensor input, float output_scale, int output_zero_point) -> Tensor"));
```
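With the schema above and the CPU kernel registered in qembeddingbag.cpp, the new op is reachable from Python through torch.ops. A minimal usage sketch, assuming per-channel torch.quint4x2 quantization exactly as the new test does (the observer import path varies across releases):

```python
import torch
# In releases of this era the observer is also importable from torch.quantization.
from torch.ao.quantization import PerChannelMinMaxObserver

weights = torch.randn(10, 12) + 1  # 10 embeddings, dim 12 (a multiple of 4, as 4 bit packing expects)

# Row-wise float qparams, matching the qscheme used by the new test_embedding test.
obs = PerChannelMinMaxObserver(dtype=torch.quint4x2,
                               qscheme=torch.per_channel_affine_float_qparams,
                               ch_axis=0)
obs(weights)
scales, zero_points = obs.calculate_qparams()
qweight = torch.quantize_per_channel(weights, scales, zero_points,
                                     axis=0, dtype=torch.quint4x2)

# Pack the 4 bit rows, then look up embeddings with the newly registered op.
packed = torch.ops.quantized.embedding_bag_prepack(qweight)
indices = torch.tensor([0, 3, 9], dtype=torch.int64)
out = torch.ops.quantized.embedding_4bit(packed, indices, pruned_weights=False)

# out approximates the float lookup to within 4 bit quantization error.
ref = torch.embedding(weights, indices)
```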

test/quantization/core/test_quantized_op.py

Lines changed: 28 additions & 33 deletions

```diff
@@ -3292,7 +3292,6 @@ def test_qlinear(self, batch_size, input_channels, output_channels, use_bias,
            use_channelwise=st.booleans())
     @override_qengines
     def test_qlinear_unpack(self, W, use_channelwise):
-
         W, (W_scale, W_zp, torch_type) = W
         if use_channelwise:
             output_channels = W.shape[0]
@@ -3328,7 +3327,6 @@ def test_qlinear_unpack(self, W, use_channelwise):
         np.testing.assert_equal(
             W_q.q_zero_point(), W_q_origin.q_zero_point())
 
-
 @unittest.skipIf(IS_MACOS, "Known test failure on Mac.")
 @unittest.skipIf(not BUILD_WITH_CAFFE2, "Test needs Caffe2")
 class TestQuantizedEmbeddingOps(TestCase):
@@ -3578,8 +3576,6 @@ def get_reference_result(
                                   include_last_offset=include_last_offset)
         torch.testing.assert_close(reference_result, result, atol=atol, rtol=rtol)
 
-
-
     """ Tests the correctness of the embedding_bag_8bit quantized operator """
     @given(num_embeddings=st.integers(10, 100),
            embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
@@ -3659,38 +3655,38 @@ def test_embedding_bag_2bit(self, num_embeddings,
                                       sparsity=sparsity,
                                       atol=1.0, rtol=1e-1)
 
-    """ Tests the correctness of the quantized embedding lookup operator """
+    """ Tests the correctness of the quantized 8 bit/4 bit embedding lookup operator """
     @given(num_embeddings=st.integers(10, 100),
            embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0))
-    def test_embedding_byte(self, num_embeddings, embedding_dim):
-        quant_op = torch.ops.quantized.embedding_byte
-        prepack_op = torch.ops.quantized.embedding_bag_prepack
-
-        weights = torch.from_numpy((np.random.random_sample((
-            num_embeddings, embedding_dim)) + 1).astype(np.float32))
-
-        obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
-        obs(weights)
-        # Get the scale and zero point for the weight tensor
-        qparams = obs.calculate_qparams()
-
-        # Quantize the weights to 8bits
-        qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8)
-        max_segments = 5
-        max_segment_length = 20
-        num_lengths = np.random.randint(1, max_segments + 1)
-        lengths = np.random.randint(1, max_segment_length + 1,
-                                    size=num_lengths).astype(np.int32)
-        num_indices = np.sum(lengths)
-        indices = torch.from_numpy(np.random.randint(
-            low=0, high=num_embeddings, size=num_indices, dtype=np.int64))
-
-        packed_weight = prepack_op(qweight)
-        qresult = quant_op(packed_weight, indices, pruned_weights=False)
-
-        ref = torch.embedding(weights, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False)
-        torch.testing.assert_close(ref, qresult, atol=0.005, rtol=1e-3)
+    def test_embedding(self, num_embeddings, embedding_dim):
+        dtypes = [torch.quint8, torch.quint4x2]
+        quant_ops = [torch.ops.quantized.embedding_byte, torch.ops.quantized.embedding_4bit]
+        prepack_op = torch.ops.quantized.embedding_bag_prepack
+        for quant_op, dtype in zip(quant_ops, dtypes):
+            weights = torch.from_numpy((np.random.random_sample((
+                num_embeddings, embedding_dim)) + 1).astype(np.float32))
+
+            obs = PerChannelMinMaxObserver(dtype=dtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
+            obs(weights)
+            # Get the scale and zero point for the weight tensor
+            qparams = obs.calculate_qparams()
 
+            # Quantize the weights to the target bit width
+            qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=dtype)
+            max_segments = 5
+            max_segment_length = 20
+            num_lengths = np.random.randint(1, max_segments + 1)
+            lengths = np.random.randint(1, max_segment_length + 1,
+                                        size=num_lengths).astype(np.int32)
+            num_indices = np.sum(lengths)
+            indices = torch.from_numpy(np.random.randint(
+                low=0, high=num_embeddings, size=num_indices, dtype=np.int64))
+
+            packed_weight = prepack_op(qweight)
+            qresult = quant_op(packed_weight, indices, pruned_weights=False)
+
+            ref = torch.embedding(weights, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False)
+            torch.testing.assert_close(ref, qresult, atol=0.005, rtol=1e-3)
 
     def test_embedding_2d_indices(self):
         """
```

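For intuition on the test's tolerances: with per_channel_affine_float_qparams, each weight row is stored as 4 bit codes in [0, 15] plus a floating-point scale and zero point, and dequantization computes (q - zero_point) * scale. A hedged sketch of the row-wise round trip — an illustration of the arithmetic only, not the observer's exact algorithm or fbgemm's packed layout:

```python
import numpy as np

def fake_quant_row_4bit(row):
    # 4 bits give 16 levels; spread the row's value range across them.
    scale = max((row.max() - row.min()) / 15.0, 1e-8)
    zero_point = -row.min() / scale          # float zero point maps row.min() to code 0
    q = np.clip(np.round(row / scale + zero_point), 0, 15)
    return (q - zero_point) * scale          # dequantized approximation

row = np.random.rand(16).astype(np.float32) + 1
approx = fake_quant_row_4bit(row)
# Rounding error is at most half a quantization step per element.
assert np.abs(approx - row).max() <= 0.5 * (row.max() - row.min()) / 15.0 + 1e-5
```

With weights spanning roughly one unit, a 4 bit step is about 1/15, versus 1/255 for the byte variant, so the 4 bit path is inherently coarser.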