
Conversation

Contributor

@bddppq bddppq commented Apr 12, 2019

Stack:
    :white_circle:  #19633 Add is_mkldnn to at::Tensor  💚
    :black_circle:  #19204 Add aten mkldnn conv2d operator  💚
    :white_circle:  #19205 Add aten mkldnn ops: relu, max_pool2d and avg_pool2d  💚
    :white_circle:  #19206 Add aten mkldnn batch_norm operator  💚
    :white_circle:  #19207 Add aten mkldnn add operator  💚
    :white_circle:  #19209 Add aten mkldnn view operator  💚
    :white_circle:  #19210 Add aten mkldnn linear operator  💚
    :white_circle:  #19648 Adjust resnext run script  💚

Pull Request resolved: #19204

Differential Revision: D14857513

Differential Revision: D14857513
Differential Version: 79209267
Differential Revision: D14857513
Differential Version: 79267360
@bddppq bddppq added the module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration label Apr 12, 2019
bddppq added 3 commits April 12, 2019 17:54
Differential Revision: D14857513
Differential Version: 79283994
Differential Revision: D14857513
Differential Version: 79296469
Differential Revision: D14857513
Differential Version: 79299277
input.scalar_type() == kFloat && // only on CPU Float Tensors
!is_dilated() && // doesn't support dilation
!transposed && // or transposed tensors
input.ndimension() == 4); // must be in NCHW format

Regarding the code comment: the original 'NCHW format' assumption no longer holds.
BTW, conv with dilation and ndim = 3/4/5 (i.e. 1D, 2D, 3D) are all supported by mkldnn, but we can do that in a separate PR.

Collaborator

These checks are for conv on CPU tensors. Perhaps we should just keep what has been done on the original path and put all future optimizations in the new MKL-DNN tensor path?


Oh, that is for CPU tensor. Then forget what I said :-)

Contributor

what's the story with the other checks? We can't use CPU tensor / MKLDNN algo if it's dilated / transposed / or the input dimension is wrong, but we can if it's an MKLDNN tensor?

And where's the check if it's not conv2d -- what happens if you try to call one of the other algos with an MKLDNN tensor?

out_channels=M,
kernel_size=3,
stride=2,
padding=1,
Collaborator

How about exploring the space of kernel_size, stride, and padding a little bit? In C2 we used settings like the following via hypothesis attributes:

    @given(stride=st.integers(1, 3),
           pad=st.integers(0, 3),
           kernel=st.integers(3, 5),
           size=st.integers(8, 10),
           input_channels=st.integers(1, 3),
           output_channels=st.integers(1, 5),
           batch_size=st.integers(1, 3),
           use_bias=st.booleans(),
           training_mode=st.booleans(),
           group=st.integers(1, 2),

Collaborator

yeah, we should probably start introducing hypothesis for cases like this. @gchanan - thoughts?
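(For illustration only: a minimal sketch of what a hypothesis-based version of this test could look like, assuming the mkldnn path can be exercised by feeding an mkldnn input to a regular Conv2d module. The parameter ranges and tolerance are placeholders, not the test that was eventually added.)

    import torch
    from hypothesis import given, strategies as st

    @given(stride=st.integers(1, 3),
           pad=st.integers(0, 3),
           kernel=st.integers(3, 5),
           size=st.integers(8, 10),
           in_channels=st.integers(1, 3),
           out_channels=st.integers(1, 5),
           batch_size=st.integers(1, 3),
           use_bias=st.booleans())
    def test_conv2d_mkldnn_matches_dense(stride, pad, kernel, size,
                                         in_channels, out_channels,
                                         batch_size, use_bias):
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel,
                               stride=stride, padding=pad, bias=use_bias)
        x = torch.randn(batch_size, in_channels, size, size)
        # Run the dense CPU path and the mkldnn path on the same input and
        # require the results to agree within a small tolerance.
        with torch.no_grad():
            y_dense = conv(x)
            y_mkldnn = conv(x.to_mkldnn()).to_dense()
        assert torch.allclose(y_dense, y_mkldnn, atol=1e-5)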

Contributor

I don't have any objections to that, although it would be great to start with the existing tests so it doesn't live in its own little world.

if (!input_is_mkldnn) {
input = input.contiguous();
}
auto weight = weight_r;
Collaborator

Shall we provide the MKLDNN tensor with contiguous()? Would letting is_contiguous() return true do the job? Whenever we have an MKLDNN tensor, it should always be contiguous.

Collaborator

The question is how we can define sound semantics of contiguous for opaque tensors. It does not make sense to have is_contiguous return true when contiguous is not well defined.

Collaborator

You can check the TensorTypeId, of course; I don't see any problem with letting MkldnnCPUTensorId return true for tensor.is_contiguous().

Collaborator

I'm a little worried about making contiguous work on MKLDNN, as the semantics are not well-defined. Maybe add a wrapper function contiguous_or_mkldnn and call it only in this implementation?

Collaborator

@mingfeima mingfeima Apr 17, 2019

@dzhulgakov Indeed, the semantics indicate that the mkldnn blocked layout should be treated as non-contiguous, while at the same time the mkldnn function calls only need to know whether the buffer is a physically contiguous chunk of memory (which is always the case for an mkldnn tensor).

If we don't inherit contiguous() for the mkldnn tensor, another simple option is to move all of the input.contiguous(), weight.contiguous(), bias.contiguous() calls from _convolution() into each underlying device implementation, i.e. cudnn_convolution(), mkldnn_convolution(), etc.
In that case the upper level in _convolution() becomes unified (and looks cleaner), and inside mkldnn_convolution we are free to do something like:

// at::native::mkldnn_convolution(const Tensor& input_r, const Tensor& weight_r, ...)
auto input = is_input_mkldnn ? input_r : input_r.contiguous();
auto weight = is_weight_mkldnn ? weight_r : weight_r.contiguous();

This way we only do this kind of check inside the mkldnn function calls, so the mkldnn tensor does not have to inherit contiguous().

output = at::mkldnn_convolution(input, weight, bias,
params.padding, params.stride, params.dilation, params.groups);
}
#endif
Collaborator

Also, once we have contiguous(), we don't have to do the if-else here depending on input_is_mkldnn; a single call to at::mkldnn_convolution is enough.

const ConvParams& params, bool input_is_mkldnn) {
int64_t k = input.ndimension();
int64_t weight_dim = weight.ndimension();
std::vector<int64_t> weight_sizes(k);
Collaborator

nit: maybe keep weight_sizes as ArrayRef and have a separate vector for sizes in mkldnn case and then point ArrayRef to that vector?

Contributor Author

Does ArrayRef extend the lifetime of a vector? (If not, then we still cannot avoid creating a vector?)

Contributor

ArrayRef doesn't extend the lifetime of the vector.



module._buffers[key] = fn(buf)

if isinstance(module, torch.nn.Conv2d):
module.weight.data = torch._C._nn.mkldnn_reorder_conv2d_weight(
Collaborator

@mingfeima - so in general - is this reordering for inference only? if we do it in training, we'd need to make sure that all optimizers can handle different layout/dimensions and even auxiliary weights like for Adam work well.

Contributor

can we just do the reorder on the fly in the function? I'm unclear if the reorder is expensive -- it looks pretty simple, but not sure how it's actually implemented.

Collaborator

@dzhulgakov Pre-reordering weights is still useful for training even though perf gain might not be that obvious, depending on the workloads. It is possible to leave existing optimizers intact to support the opaque tensor as long as it supports all the ops and factory functions used by the optimizers.

Contributor Author

Reordering on the fly (repeatedly) makes conv about 20% slower (and the whole resnext about 15% slower).

Contributor

Is that slower than the non-mkldnn path, or slower than mkldnn with the optimization?

Contributor

What I'm trying to get at: if the "on the fly" packing is still faster than the dense path, then it makes sense to expose this op independent of pre-packing. But if it's not, we need to tie this into our greater weight pre-packing story.
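(For illustration only: a rough timing sketch of the three variants being discussed, namely the dense CPU conv, the mkldnn conv with a per-call weight reorder, and the mkldnn conv with a pre-packed weight. The shapes, iteration counts, and the arguments passed to mkldnn_reorder_conv2d_weight are assumptions, not measurements or code from this PR.)

    import timeit
    import torch

    conv = torch.nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1,
                           bias=False).eval()
    x = torch.randn(1, 64, 56, 56)

    with torch.no_grad():
        # Dense CPU path.
        dense = timeit.timeit(lambda: conv(x), number=100)
        # mkldnn path; the dense weight gets reordered on every call.
        on_the_fly = timeit.timeit(lambda: conv(x.to_mkldnn()).to_dense(),
                                   number=100)
        # Pre-pack the weight once, as the conversion script in this PR does,
        # then time the same call again.
        conv.weight.data = torch._C._nn.mkldnn_reorder_conv2d_weight(
            conv.weight.data.to_mkldnn(), conv.padding, conv.stride,
            conv.dilation, conv.groups)
        prepacked = timeit.timeit(lambda: conv(x.to_mkldnn()).to_dense(),
                                  number=100)

    print(dense, on_the_fly, prepacked)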

import torch


def to_mkldnn(module):
Contributor

can you just call _apply?

Contributor Author

_apply doesn't pass the module itself to the callback, but we need it to check whether the module is a conv (and if so, we need to do the extra conversion).

Collaborator

If you just want to match children modules - there's .apply
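(For illustration only: a minimal sketch of that .apply-based variant. It follows the conversion steps of the script quoted in this PR; the exact arguments to mkldnn_reorder_conv2d_weight are assumptions, since the snippet above is truncated.)

    import torch

    def to_mkldnn(module):
        def convert(m):
            # .apply() hands us every submodule (and the root module), so the
            # isinstance check can live inside the callback.
            for name, param in m._parameters.items():
                if param is not None and param.is_floating_point():
                    param.data = param.data.to_mkldnn()
            for name, buf in m._buffers.items():
                if buf is not None and buf.is_floating_point():
                    m._buffers[name] = buf.to_mkldnn()
            if isinstance(m, torch.nn.Conv2d):
                # Assumed argument order, mirroring the conversion script above.
                m.weight.data = torch._C._nn.mkldnn_reorder_conv2d_weight(
                    m.weight.data, m.padding, m.stride, m.dilation, m.groups)
        return module.apply(convert)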




if t.is_floating_point():
return t.to_mkldnn()

for param in module._parameters.values():
Collaborator

It assumes that all ops work on mkldnn, which might not be the case. Maybe we should have two versions: a conservative one that matches only some modules, and a full-on one that handles all params.

Contributor Author

Hmm, supporting partial conversion is hard I think; you would need to add to_mkldnn and to_dense at the boundaries.
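(For illustration only: one way those boundary conversions could look, using a hypothetical MkldnnWrapper helper that is not part of this PR.)

    import torch

    class MkldnnWrapper(torch.nn.Module):
        # Hypothetical: convert to mkldnn on the way into a converted
        # submodule and back to dense on the way out, so the rest of the
        # network can keep using plain CPU tensors.
        def __init__(self, wrapped):
            super(MkldnnWrapper, self).__init__()
            self.wrapped = wrapped

        def forward(self, x):
            return self.wrapped(x.to_mkldnn()).to_dense()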


bddppq added 2 commits April 16, 2019 15:11
Differential Revision: D14857513
Differential Version: 79689990
Differential Revision: D14857513
Differential Version: 79698398
Collaborator

@dzhulgakov dzhulgakov left a comment

Looks good to me. @gchanan - feel free to do the final pass

// mkldnn conv2d weights could have been re-ordered to 5d by
// mkldnn_reorder_conv2d_weight
if (weight.dim() == input.dim() + 1) {
AT_ASSERTM(
Collaborator

@mingfeima - just curious - does the mkldnn tensor for conv store the original semantic sizes somewhere internally? I'm just worried that we hard-code logic here that might change later as mkldnn potentially does more optimizations.

Collaborator

Nope, the mkldnn tensor stores only mkldnn::memory::dims.

There might be some risk here when generating output with new_with_itensor_mkldnn in case mkldnn::memory::dims is not aligned with at::sizes(), since the output aten tensor would get sizes equal to the mkldnn dims.

So far for CNN modules I haven't seen any problem, since the outputs are all 4D tensors.
For the RNN module it is going to be an issue, since the mkldnn dims and aten sizes are not aligned (the mkldnn rnn weight has 5d dims but the aten weight has a 2d size). Anyway, perhaps I can wrap this up inside ideep so that you don't have to worry about it.

Differential Revision: D14857513
Differential Version: 80386613
Differential Revision: D14857513
Differential Version: 80533695
@bddppq bddppq changed the base branch from master to export-D15053320 April 23, 2019 21:14
bddppq added 2 commits April 23, 2019 14:22
Differential Revision: D14857513
Differential Version: 80534910
Differential Revision: D14857513
Differential Version: 80541660
Differential Revision: D14857513
Differential Version: 80683619
Collaborator

@dzhulgakov dzhulgakov left a comment

I think it's good to go, yay!

// Mkldnn tensor has special non-public format for conv2d weights
// (dense_to_mkldnn only converts dense tensor to mkldnn tensor with
// public format). Ideep conv kernel will do implicit reorder if the
// weight is not already in this optimized format. By the time I'm
Collaborator

Thanks for writing this comment!

By the time -> At the time of writing this note,

m._buffers[key] = t_fn(buf)

if isinstance(m, torch.nn.Conv2d):
m.weight.data = torch._C._nn.mkldnn_reorder_conv2d_weight(
Collaborator

We will probably need to move this call elsewhere in order to define a "prepack for inference" interface in the future. Otherwise it's hard to serialize the mkldnn tensor, including its particular layout, to a file. But that can be done in a separate PR when we get to more API pieces.

Collaborator

We will probably need to move this call elsewhere in order to define a "prepack for inference" interface in the future. Otherwise it's hard to serialize the mkldnn tensor, including its particular layout, to a file.

The pre-packing is also useful for training since it avoids extra reorders as well. In order to serialize the weights properly, can we always call to_dense when the module is serialized?

Collaborator

For training more work would be needed, for example handling mkldnn tensors in the optimizer.

What I meant about inference is that we'd need a mechanism to bundle instructions to pack the weights into mkldnn format after deserialization (since, as you said, they'd be stored as regular to_dense tensors). We have a prototype mechanism for it (https://github.com/pytorch/pytorch/blob/master/torch/jit/quantized.py#L34) but it needs to be improved. Most likely, we'd need to swap out the Conv2d module for a special MkldnnConv2d module that handles packing after loading.
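(For illustration only: a very rough sketch of such a module swap. Everything here is an assumption, including the MkldnnConv2d class itself, keeping dense copies as buffers for serialization, and the Python-level availability and argument order of torch.mkldnn_convolution and mkldnn_reorder_conv2d_weight.)

    import torch

    class MkldnnConv2d(torch.nn.Module):
        # Hypothetical inference-only replacement for nn.Conv2d: keeps a dense
        # copy of the weights for serialization and re-packs them into the
        # mkldnn blocked format after construction or after loading.
        def __init__(self, conv):
            super(MkldnnConv2d, self).__init__()
            self.padding = conv.padding
            self.stride = conv.stride
            self.dilation = conv.dilation
            self.groups = conv.groups
            self.register_buffer('dense_weight', conv.weight.data.clone())
            self.register_buffer('dense_bias',
                                 conv.bias.data.clone() if conv.bias is not None
                                 else torch.zeros(conv.out_channels))
            self.prepack()

        def prepack(self):
            # Call again after load_state_dict() to rebuild the packed weight.
            self.weight = torch._C._nn.mkldnn_reorder_conv2d_weight(
                self.dense_weight.to_mkldnn(), self.padding, self.stride,
                self.dilation, self.groups)
            self.bias = self.dense_bias.to_mkldnn()

        def forward(self, x):
            return torch.mkldnn_convolution(
                x.to_mkldnn(), self.weight, self.bias, self.padding,
                self.stride, self.dilation, self.groups).to_dense()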

bddppq added 3 commits April 25, 2019 18:20
Differential Revision: D14857513
Differential Version: 80773729
Differential Revision: D14857513
Differential Version: 80785117
Differential Revision: D14857513
Differential Version: 80799724
zdevito pushed a commit to zdevito/ATen that referenced this pull request Apr 26, 2019
Summary: Pull Request resolved: pytorch/pytorch#19204

Reviewed By: dzhulgakov

Differential Revision: D14857513

fbshipit-source-id: 1172c9785e5a17a7d7360474551bdc7a511b3f2f
@facebook-github-bot
Contributor

This pull request has been merged in 3445020.

zhangguanheng66 pushed a commit to zhangguanheng66/pytorch that referenced this pull request May 6, 2019
Summary: Pull Request resolved: pytorch#19204

Reviewed By: dzhulgakov

Differential Revision: D14857513

fbshipit-source-id: 1172c9785e5a17a7d7360474551bdc7a511b3f2f
@ezyang ezyang deleted the export-D14857513 branch May 30, 2019 15:56
!is_dilated() && // doesn't support dilation
!transposed && // or transposed tensors
input.ndimension() == 4; // must be in NCHW format
return (input.is_mkldnn()) || // input is mkldnn Tensor


Labels

module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration


9 participants