19 changes: 16 additions & 3 deletions aten/src/ATen/native/RNN.cpp
@@ -2,6 +2,8 @@

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>

#include <ATen/native/c10_utils.h>

@@ -278,12 +280,23 @@ static std::vector<QuantizedCellParamsDynamic> gather_quantized_params_dynamic(
static at::Tensor undefined;
std::vector<QuantizedCellParamsDynamic> result;
TORCH_CHECK(
params.size() % 4 == 0,
params.size() % 2 == 0,
"got an incorrect number of quantized RNN parameters");
for (size_t i = 0; i < params.size(); i += 4) {
result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3]);
// PackedLinearWeight is only defined when USE_FBGEMM is defined
#ifdef USE_FBGEMM
for (size_t i = 0; i < params.size(); i += 2) {
auto& packed_struct_ih =
cpp_custom_type_hack::cast<PackedLinearWeight>(params[i]);
auto& packed_struct_hh =
cpp_custom_type_hack::cast<PackedLinearWeight>(params[i + 1]);
auto bias_ih = packed_struct_ih.bias.value_or(undefined);
auto bias_hh = packed_struct_hh.bias.value_or(undefined);
result.emplace_back(params[i], params[i + 1], bias_ih, bias_hh);
}
return result;
#else // USE_FBGEMM
TORCH_INTERNAL_ASSERT(false, "Tried to use quantized RNN without FBGEMM!");
#endif // USE_FBGEMM
}

static std::vector<QuantizedCellParamsFP16> gather_quantized_params_fp16(
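Aside, not part of the diff: the modulus check changes from 4 to 2 because the bias is now folded into the packed weight at prepack time, so each layer/direction contributes only the pair (w_ih, w_hh) and the loop above recovers the biases from the PackedLinearWeight structs. A minimal Python sketch of the packing convention, assuming the FBGEMM-backed quantized ops used elsewhere in this PR; shapes and quantization parameters are illustrative only:

import torch

# linear_prepack folds the bias into the packed weight, so a quantized LSTM
# layer/direction is represented by two packed objects (w_ih, w_hh) rather
# than four tensors (w_ih, w_hh, b_ih, b_hh).
qweight = torch.quantize_linear(torch.randn(32, 8), 0.1, 0, torch.qint8)
bias = torch.randn(32)
packed = torch.ops.quantized.linear_prepack(qweight, bias)
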
30 changes: 28 additions & 2 deletions test/test_quantization.py
@@ -519,8 +519,34 @@ def test_quantized_rnn(self):

torch.testing.assert_allclose(output_int8, ref_out)
self.assertEqual(output_int8, ref_out)
for out, ref in zip(final_hiddens_int8, ref_hid):
torch.testing.assert_allclose(out, ref)
for out_val, ref_val in zip(final_hiddens_int8, ref_hid):
torch.testing.assert_allclose(out_val, ref_val)

class ScriptWrapper(torch.nn.Module):
def __init__(self, cell):
super(ScriptWrapper, self).__init__()
self.cell = cell

def forward(self, x, hiddens):
# type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
return self.cell(x, hiddens)

# TODO: TorchScript overloads don't work without this wrapper
cell_script = torch.jit.script(ScriptWrapper(cell_int8))
out_script, hid_script = cell_script(x, hiddens)
self.assertEqual(len(out_script), len(ref_out))
for out_val, ref_val in zip(out_script, ref_out):
torch.testing.assert_allclose(out_val, ref_val)

# Test save/load
b = io.BytesIO()
torch.jit.save(cell_script, b)
b.seek(0)
loaded = torch.jit.load(b)
out_loaded, hid_loaded = loaded(x, hiddens)
for loaded_val, ref_val in zip(out_loaded, ref_out):
torch.testing.assert_allclose(loaded_val, ref_val)


@unittest.skipIf(
not torch.fbgemm_is_cpu_supported(),
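Aside, not part of the PR's test: in addition to forward_tensor, the quantized module exposes a PackedSequence overload (forward_packed). A hedged usage sketch; a float nn.LSTM stands in for the dynamically quantized cell_int8 built earlier in the test, since the calling convention is the same:

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

cell = torch.nn.LSTM(input_size=8, hidden_size=8, num_layers=2)  # stand-in for cell_int8
seqs = torch.randn(5, 3, 8)           # (seq_len, batch, input_size)
lengths = torch.tensor([5, 4, 2])     # per-batch lengths, descending
packed_in = pack_padded_sequence(seqs, lengths)

packed_out, (h, c) = cell(packed_in)  # the quantized module dispatches to forward_packed
out, out_lengths = pad_packed_sequence(packed_out)
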
1 change: 1 addition & 0 deletions torch/csrc/jit/script/python_sugared_value.cpp
@@ -226,6 +226,7 @@ std::shared_ptr<SugaredValue> OverloadedMethodValue::call(
for (const std::string& method_name : method_names_) {
auto cls = module_->type()->expect<ClassType>();
const auto fn = cls->getMethod(method_name);
TORCH_INTERNAL_ASSERT(fn, "Expected class to have method ", method_name);
auto match = tryMatchSchema(
fn->getSchema(),
loc,
129 changes: 64 additions & 65 deletions torch/nn/quantized/dynamic/modules/rnn.py
@@ -5,7 +5,6 @@
from torch import Tensor # noqa: F401
from torch.nn import _VF
from torch._jit_internal import Tuple, Optional, List # noqa: F401
from torch._jit_internal import _parameter_list
from torch.nn.utils.rnn import PackedSequence
import numbers

@@ -19,10 +18,6 @@ class RNNBase(torch.nn.Module):

_FLOAT_MODULE = nn.RNNBase

__constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias',
'batch_first', 'dropout', 'bidirectional', '_packed_weights',
'_quantized_weights']

def __init__(self, mode, input_size, hidden_size,
num_layers=1, bias=True, batch_first=False,
dropout=0., bidirectional=False):
@@ -54,31 +49,24 @@ def __init__(self, mode, input_size, hidden_size,
else:
raise ValueError("Unrecognized RNN mode: " + mode)

self._all_weights = []

packed_weights = []
quantized_weights = []
self._all_weight_names = []
self._all_weight_values = []

for layer in range(num_layers):
for direction in range(num_directions):
layer_input_size = input_size if layer == 0 else hidden_size * num_directions

def process_weights(ihhh, layer, suffix, qweight, bias):
weight_name = 'weight_{}_l{}{}'.format(ihhh, layer, suffix)
bias_name = 'bias_{}_l{}{}'.format(ihhh, layer, suffix)

# for each layer, for each direction we need to quantize and pack
# weights and pack parameters in this order:
#
# w_ih, w_hh, b_ih, b_hh
# w_ih, w_hh
packed_weight = \
torch.ops.quantized.linear_prepack(qweight)
params = [packed_weight, bias]
pos_names = ['w', 'b']
torch.ops.quantized.linear_prepack(qweight, bias)
params = [packed_weight]
pos_names = ['w']
ret_name = ['{}_{}_l{}{}'.format(
name, ihhh, layer, suffix) for name in pos_names]
quantized_weights.append(qweight)
packed_weights.append(ret_name[0])
return params, ret_name

w_ih = torch._empty_affine_quantized(
@@ -99,11 +87,11 @@ def process_weights(ihhh, layer, suffix, qweight, bias):
'hh', layer, suffix, w_hh, b_hh)

for (ih, ih_name), (hh, hh_name) in zip(zip(ih_params, ih_param_names), zip(hh_params, hh_param_names)):
self.register_buffer(ih_name, torch.tensor(
ih) if not isinstance(ih, torch.Tensor) else ih)
self.register_buffer(hh_name, torch.tensor(
hh) if not isinstance(hh, torch.Tensor) else hh)
self._all_weights.extend([ih_name, hh_name])

self._all_weight_names.extend([ih_name, hh_name])
self._all_weight_values.extend([ih, hh])



def check_input(self, input, batch_sizes):
# type: (Tensor, Optional[Tensor]) -> None
@@ -148,30 +136,47 @@ def permute_hidden(self, hx, permutation):
return hx
return apply_permutation(hx, permutation)

@property
def all_weights(self):
return [getattr(self, weight) for weight in self._all_weights]

def _get_all_weights_names(self):
return [weight for weight in self._all_weights]

@_parameter_list(_get_all_weights_names)
def _get_all_weights(self):
return self.all_weights

def _get_packed_weights_names(self):
return self._packed_weights

@_parameter_list(_get_packed_weights_names)
def _get_packed_weights(self):
return [getattr(self, name) for name in self._packed_weights]

def _get_quantized_weights_names(self):
return self._quantized_weights

@_parameter_list(_get_quantized_weights_names)
def _get_quantized_weights(self):
return [getattr(self, name) for name in self._quantized_weights]
@torch.jit.export
def __getstate__(self):
vals = (
self.mode,
self.input_size,
self.hidden_size,
self.num_layers,
self.bias,
self.batch_first,
self.dropout,
self.bidirectional,
self._all_weight_names,
self.__overloads__,
self.training,
)

dynamic_vals = torch.jit.annotate(List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
[])

for i in range(len(self._all_weight_names)):
dynamic_vals.append(torch.ops.quantized.linear_unpack(self._all_weight_values[i]))
return vals, dynamic_vals

@torch.jit.export
def __setstate__(self, state):
vals, dynamic_vals = state
self.mode = vals[0]
self.input_size = vals[1]
self.hidden_size = vals[2]
self.num_layers = vals[3]
self.bias = vals[4]
self.batch_first = vals[5]
self.dropout = vals[6]
self.bidirectional = vals[7]
self._all_weight_names = vals[8]
self.__overloads__ = vals[9]
self.training = vals[10]

self._all_weight_values = []
for i in range(len(self._all_weight_names)):
self._all_weight_values.append(torch.ops.quantized.linear_prepack(*dynamic_vals[i]))

@classmethod
def from_float(cls, mod):
@@ -200,9 +205,8 @@ def from_float(cls, mod):
if qRNNBase.mode != 'LSTM':
raise RuntimeError('Only LSTM is supported for QuantizedRNN')

qRNNBase._all_weights = []
packed_weights = []
quantized_weights = []
qRNNBase._all_weight_names = []
qRNNBase._all_weight_values = []
for layer in range(qRNNBase.num_layers):
for direction in range(num_directions):
layer_input_size = qRNNBase.input_size if layer == 0 else qRNNBase.hidden_size * num_directions
@@ -216,36 +220,28 @@ def process_weights(ihhh, layer, suffix):
# for each layer, for each direction we need to quantize and pack
# weights and pack parameters in this order:
#
# w_ih, w_hh, b_ih, b_hh
# w_ih, w_hh
weight_observer(weight)
wt_scale, wt_zp = weight_observer.calculate_qparams()
qweight = torch.quantize_linear(
weight.float(), float(wt_scale), int(wt_zp), torch.qint8)
packed_weight = \
torch.ops.quantized.linear_prepack(qweight, bias)

params = [packed_weight, bias]
pos_names = ['w', 'b']
params = [packed_weight]
pos_names = ['w']
ret_name = ['{}_{}_l{}{}'.format(
name, ihhh, layer, suffix) for name in pos_names]
quantized_weights.append(qweight)
packed_weights.append(ret_name[0])
return params, ret_name

suffix = '_reverse' if direction == 1 else ''
ih_params, ih_param_names = process_weights('ih', layer, suffix)
hh_params, hh_param_names = process_weights('hh', layer, suffix)

for (ih, ih_name), (hh, hh_name) in zip(zip(ih_params, ih_param_names), zip(hh_params, hh_param_names)):
qRNNBase.register_buffer(ih_name, torch.tensor(
ih) if not isinstance(ih, torch.Tensor) else ih)
qRNNBase.register_buffer(hh_name, torch.tensor(
hh) if not isinstance(hh, torch.Tensor) else hh)
qRNNBase._all_weights.extend([ih_name, hh_name])
qRNNBase._all_weight_names.extend([ih_name, hh_name])
qRNNBase._all_weight_values.extend([ih, hh])

qRNNBase._packed_weights = packed_weights
# DO WE NEED _quantized_weights? @jianyuh: will remove _quantized_weight as now we support the fbgemm_linear_unpack function
qRNNBase._quantized_weights = quantized_weights

return qRNNBase

@@ -275,14 +271,15 @@ def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
self.check_forward_args(input, hx, batch_sizes)
assert batch_sizes is None

result = _VF.quantized_lstm(input, hx, self._get_all_weights(), self.bias, self.num_layers,
result = _VF.quantized_lstm(input, hx, self._all_weight_values, self.bias, self.num_layers,
float(self.dropout), self.training, self.bidirectional,
self.batch_first, dtype=torch.int8, use_dynamic=True)
output = result[0]
hidden = result[1:]

return output, hidden

@torch.jit.export
def forward_tensor(self, input, hx=None):
# type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
batch_sizes = None
@@ -295,6 +292,7 @@ def forward_tensor(self, input, hx=None):

return output, self.permute_hidden(hidden, unsorted_indices)

@torch.jit.export
def forward_packed(self, input, hx=None):
# type: (PackedSequence, Optional[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]] # noqa
input, batch_sizes, sorted_indices, unsorted_indices = input
@@ -315,7 +313,7 @@ def permute_hidden(self, hx, permutation):
return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)

def check_forward_args(self, input, hidden, batch_sizes):
# type : (Tensor, Tuple[Tensor, Tensor], Optional[Tensor])->None
# type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor])->None
self.check_input(input, batch_sizes)
expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

@@ -324,6 +322,7 @@ def check_forward_args(self, input, hidden, batch_sizes):
self.check_hidden_size(hidden[1], expected_hidden_size,
'Expected hidden[1] size {}, got {}')

@torch.jit.ignore
def forward(self, input, hx=None):
if isinstance(input, PackedSequence):
return self.forward_packed(input, hx)
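Aside, not part of the diff: the __getstate__/__setstate__ pair above exists because packed weights are opaque objects that cannot be serialized directly, so they are unpacked to plain tensors on save and re-packed on load. A minimal sketch of that round trip, assuming FBGEMM is available; the tensors stand in for entries of _all_weight_values:

import torch

qweight = torch.quantize_linear(torch.randn(32, 8), 0.1, 0, torch.qint8)
bias = torch.randn(32)
packed = torch.ops.quantized.linear_prepack(qweight, bias)

# "__getstate__": unpack each packed weight into a (qweight, optional bias)
# pair that can be saved alongside the module's plain attributes.
saved_w, saved_b = torch.ops.quantized.linear_unpack(packed)

# "__setstate__": rebuild the packed weight from the saved tensors on load.
restored = torch.ops.quantized.linear_prepack(saved_w, saved_b)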