Conversation

@jerryzh168 (Contributor) commented Mar 20, 2019

Stack:
    ○  #18765 [clang-format] For some files that are touched by the QTensor diff  💚
    ●  #18230 [pt1][quant] QTensor  💚

Implementing a minimal QTensor API to unblock other workstreams in quantization.

Changes:

  • Added Quantizer which represents different quantization schemes
  • Added qint8 as a data type for QTensor
  • Added a new ScalarType QInt8
  • Added QTensorImpl for QTensor
  • Added the following user-facing APIs (see the usage sketch after this list)
    • quantize_linear(scale, zero_point)
    • dequantize()
    • q_scale()
    • q_zero_point()
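A minimal usage sketch of these APIs (an illustration based on the signatures described in this PR; the literal scale/zero_point values are for demonstration only, and the return types of q_scale()/q_zero_point() are left to auto since they aren't pinned down here):

```cpp
#include <ATen/ATen.h>

int main() {
  at::Tensor t = at::rand({2, 2});
  // Affine quantization: stored = round(real / scale) + zero_point.
  at::Tensor q = t.quantize_linear(/*scale=*/0.1, /*zero_point=*/10);
  // Read back the quantization parameters attached to the QTensor.
  auto scale = q.q_scale();
  auto zero_point = q.q_zero_point();
  (void)scale;
  (void)zero_point;
  // Dequantize: real = scale * (stored - zero_point).
  at::Tensor r = q.dequantize();
  return 0;
}
```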

Differential Revision: D14524641

@bddppq bddppq changed the title [wip][pt1][quant] QTensor [pt1][quant] QTensor Mar 20, 2019
@bddppq (Contributor) commented Mar 21, 2019

Need to add docs for quantize and dequantize:

03:34:21 ======================================================================
03:34:21 FAIL: test_doc (test_torch.TestTorch)
03:34:21 ----------------------------------------------------------------------
03:34:21 Traceback (most recent call last):
03:34:21   File "/var/lib/jenkins/workspace/test/test_torch.py", line 229, in test_doc
03:34:21     'sparse_resize_and_clear_',
03:34:21   File "/var/lib/jenkins/workspace/test/test_torch.py", line 190, in test_namespace
03:34:21     self.assertTrue(has_doc, '{} is missing documentation'.format(full_name))
03:34:21 AssertionError: False is not true : Tensor.dequantize is missing documentation
03:34:21 

@gchanan (Contributor) left a comment

This looks like a nice start! I didn't look into the quantization details; I was mainly looking at the ATen/c10 structure. The main things I noticed are:

  1. We should not have a quantized device. As I mention below, I believe we need to change how we do TensorTypeId / Backend lookup to take the dtype into account. You might want to split this out into a separate PR, because it's useful anyway and doesn't need to be held up by the rest.
  2. The AT_FORALL macros are getting too complicated (not your fault). I'll take this as an action item to see what I can do.
  3. There's some code around exposing quantizer information to Python, which doesn't seem like a good idea yet.

@jerryzh168 (Contributor Author)
@bddppq Is this expected?
01:18:24 [==========] Running 4 tests from 1 test case.
01:18:24 [----------] Global test environment set-up.
01:18:24 [----------] 4 tests from MathROCBLASTest
01:18:24 [ RUN ] MathROCBLASTest.GemmNoTransNoTrans
02:31:11 Build timed out (after 90 minutes). Marking the build as failed.

#include <memory>

// TODO: move to c10 namespace after we
// unified caffe2::Tensor and at::Tensor
Contributor:
I think this TODO is stale

Contributor Author:
Why? Do we no longer plan to merge them?

facebook-github-bot pushed a commit that referenced this pull request Mar 30, 2019
Summary:
Problem:
```cpp
// This function expects a `Variable` as input
inline PyObject* wrap(at::Tensor tensor) {
  return THPVariable_Wrap(Variable(std::move(tensor)));
}

inline PyObject* wrap(at::Scalar scalar) {
  // This function calls `wrap(at::Tensor tensor)` (the function above), but since
  // `scalar_to_tensor(...)` returns a `Tensor` and not a `Variable`, the call to
  // `wrap(at::Tensor tensor)` will fail with "Tensor that was converted to Variable
  // was not actually a Variable", which is not what we want.
  return wrap(scalar_to_tensor(scalar));
}
```

The right fix is to call `make_variable(...)` with the tensor returned from `scalar_to_tensor(scalar)`.
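A sketch of that fix, following the description above (only the comment is added; `make_variable` with its default `requires_grad=false` is used exactly as the commit message suggests):

```cpp
inline PyObject* wrap(at::Scalar scalar) {
  // Promote the plain Tensor returned by scalar_to_tensor to a Variable
  // before wrapping, so the "was not actually a Variable" error no longer fires.
  return THPVariable_Wrap(make_variable(scalar_to_tensor(scalar)));
}
```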

This unblocks #18230, as it is the only patch that hits this code path now. All other native functions that return Scalar (such as `item()` or `_local_scalar_dense()`) either have custom-defined implementations that don't go through this path, or are not exposed to Python at all.
Pull Request resolved: #18632

Differential Revision: D14689293

Pulled By: yf225

fbshipit-source-id: be7ba5d3de83a69533a2997de97ad92989ff78ee
@yf225 (Contributor) commented Apr 1, 2019

@jerryzh168 #18632 is merged, feel free to pull master into your branch again :)

@dzhulgakov (Collaborator) left a comment

Looks good, modulo some non-blocking comments (please address those!)

struct CAFFE2_API PerChannelSymmetricQuantizer : public SymmetricQuantizer {
  PerChannelSymmetricQuantizer() {}
  PerChannelSymmetricQuantizer(std::vector<float> scales, std::vector<int64_t> axis)
      : SymmetricQuantizer(kPerChannelSymmetric), scales_(scales), axis_(axis) {
    AT_ASSERT(axis_.size() == 1);
Collaborator:
Use AT_CHECK here and add a descriptive error message.
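For example, a sketch of the suggested check (the message text is illustrative, not from the PR):

```cpp
AT_CHECK(
    axis_.size() == 1,
    "PerChannelSymmetricQuantizer expects a single quantization axis, got ",
    axis_.size());
```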

Contributor Author:
AT_CHECK or AT_ASSERTM? I saw a comment:
// TODO: merge AT_CHECK with AT_ASSERTM. CHECK in fbcode means strict failure if
// not met.

@jerryzh168 (Contributor Author)
@dzhulgakov For the qint type conversion, I added a conversion to uint8_t, since it is the underlying type; I think that is better than converting to uint32_t.
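A sketch of what such a conversion amounts to, given that qint8 wraps a uint8_t (see the struct excerpt later in this thread; the helper name here is hypothetical):

```cpp
#include <cstdint>

struct qint8 {
  uint8_t val_;  // the underlying storage type
};

// Hypothetical helper: converting to the underlying uint8_t preserves the
// stored value exactly, without implying the widening that a uint32_t
// conversion would suggest.
inline uint8_t to_underlying(qint8 q) {
  return q.val_;
}
```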

@ZolotukhinM left a comment

Thanks for the work! A couple of minor nitpicks inline.

// setters/getters for QTensorImpl fields; otherwise, you should use
// the low level setters/getters that were implemented using this.
// This may be called repeatedly, so make sure it's pretty cheap.
CAFFE2_API QTensorImpl* get_qtensorimpl(const QTensor& self);


Would it make sense to make this a private method of QTensor (and add all users as friends)?

Contributor Author:

We don't have a separate QTensor class; it's the same as Tensor.


Ah, please ignore this then :)

Contributor Author:

I'm not sure it's a good idea to make QTensorImpl a friend class of Tensor, which might contain an intrusive_ptr to QTensorImpl.


// define the scalar.to<int64_t>() specializations
-template<typename T>
+template <typename T>


It would be nice if you could commit clang-format changes separately. This is a big important patch, and it would be nice to not dilute it with whitespace changes.

Contributor Author:

Yeah, I didn't expect there to be so many whitespace changes. Is there a way to undo them?

Contributor Author:

I think it should be fine in this case; the added files matter more than the changed files for this diff.


You could clang-format the entire file in a separate PR and rebase this PR on top, or you could `git reset HEAD^` and then use `git add -p` and `git commit` to select only the meaningful parts (I hope the clang-format hook would not fire on parts you don't touch).

Contributor Author:

Thanks, just updated

jerryzh168 added a commit to jerryzh168/pytorch that referenced this pull request Apr 2, 2019
Summary:
Pull Request resolved: pytorch#18230

Implementing a minimal QTensor API to unblock other workstreams in quantization.

Changes:
- Added Quantizer which represents different quantization schemes
- Added qint8 as a data type for QTensor
- Added a new ScalarType QInt8
- Added QTensorImpl for QTensor
- Added following user facing APIs
  - quantize_linear(scale, zero_point)
  - dequantize()
  - q_scale()
  - q_zero_point()

Reviewed By: dzhulgakov

Differential Revision: D14524641

fbshipit-source-id: 4a4e4bb6dd485cffdd3b0e064cadb18b746373b3
@jerryzh168 jerryzh168 changed the base branch from master to export-D14733442 April 2, 2019 21:10
@facebook-github-bot (Contributor)
This pull request has been merged in dfcd7b0.

- func: quantize_linear(Tensor self, float scale, int zero_point) -> Tensor
matches_jit_signature: True
variants: function, method
requires_tensor: True
Collaborator:

@jerryzh168 @gchanan - why do we need/have `requires_tensor: True` here?

Contributor Author:

This is going to be removed in the next PR.

Contributor Author:

I added it when I was trying to make the Python API work, since it was complaining about something related to Variable.

- func: dequantize(Tensor self) -> Tensor
matches_jit_signature: True
variants: function, method
requires_tensor: True
Collaborator:

Why don't we put

dispatch:
  QuantizedQInt8: dequantize

here?

Contributor Author:

This is in the next PR.

@yinghai (Contributor) commented Apr 5, 2019

How do we express channel-wise quantization? By supplying a different quantizer?

@dzhulgakov (Collaborator)
@yinghai - yes

 * data types in the future.
 */
struct alignas(1) qint8 {
  uint8_t val_;
Collaborator:

Does the uint8_t assume the quantization is asymmetric? For symmetric quantization, the value could be either unsigned or signed. How do we know? Will you put a flag in the symmetric quantizer for it?

Collaborator:

An alternative is to share the same class for symmetric and asymmetric quantizers and decide whether the quantization is symmetric via the zero point value, i.e. 0 for unsigned symmetric and 128 for signed symmetric.
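A small worked example of that idea, assuming the standard affine formula real = scale * (stored - zero_point) over uint8 storage:

```cpp
#include <cstdint>
#include <cstdio>

// Affine dequantization over uint8 storage.
float dequantize_val(uint8_t stored, float scale, int32_t zero_point) {
  return scale * (static_cast<int32_t>(stored) - zero_point);
}

int main() {
  // zero_point = 0: reals span [0, 25.5] -- the "unsigned" interpretation.
  std::printf("%g %g\n", dequantize_val(0, 0.1f, 0), dequantize_val(255, 0.1f, 0));
  // zero_point = 128: the same storage spans [-12.8, 12.7], symmetric
  // around zero -- the "signed symmetric" interpretation.
  std::printf("%g %g %g\n",
              dequantize_val(0, 0.1f, 128),
              dequantize_val(128, 0.1f, 128),
              dequantize_val(255, 0.1f, 128));
  return 0;
}
```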

Contributor Author:

I think we can have new data types for symmetric quantization.

Contributor Author:

We can add support for these later if there is a need; right now we only have per-tensor affine quantization with unsigned int8 as the data type.

@yinghai yinghai deleted the export-D14524641 branch April 16, 2019 16:44