
Conversation

Contributor

@bddppq commented Apr 12, 2019

Stack:
    :white_circle:  #19633 Add is_mkldnn to at::Tensor  💚
    :white_circle:  #19204 Add aten mkldnn conv2d operator  💚
    :white_circle:  #19205 Add aten mkldnn ops: relu, max_pool2d and avg_pool2d  💚
    :black_circle:  #19206 Add aten mkldnn batch_norm operator  💚
    :white_circle:  #19207 Add aten mkldnn add operator  💚
    :white_circle:  #19209 Add aten mkldnn view operator  💚
    :white_circle:  #19210 Add aten mkldnn linear operator  💚
    :white_circle:  #19648 Adjust resnext run script  💚

Pull Request resolved: #19206

Differential Revision: D14887205

Differential Revision: D14887205
Differential Version: 79209395
bddppq added 2 commits April 12, 2019 21:15
Differential Revision: D14887205
Differential Version: 79296741
Differential Revision: D14887205
Differential Version: 79299276

ideep::tensor y;

if (train) {
Collaborator

@XiaobingSuper Apr 15, 2019


@bddppq, I think there should be a check for whether the running stats are used, for both training and inference, i.e. whether running_mean and running_var are defined or not. Another suggestion: mkldnn only supports batchnorm2d and batchnorm3d, so you should add some checks when calling into mkldnn; perhaps you can see the code as reference. Thanks!

Collaborator


Would mkldnn throw with a nice message that it's not supported? In this case the input is already an mkldnn tensor, so it's better to fail if something is not supported.

Collaborator


@dzhulgakov I'm afraid not. :-( How about we add an assertion here to guarantee 2d or 3d batch norm?
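The guard being discussed could look something like the sketch below. It is a minimal, self-contained illustration, not the actual ATen code: the function name `check_mkldnn_batch_norm_dim` is hypothetical, and a plain exception stands in for ATen's check macros. It encodes the constraint that mkldnn batch norm handles batchnorm2d (4-d NCHW input) and batchnorm3d (5-d NCDHW input) only.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical guard: mkldnn batch_norm supports only batchnorm2d and
// batchnorm3d, i.e. 4-d (NCHW) or 5-d (NCDHW) inputs. Reject anything else
// up front with a clear message instead of letting mkldnn fail opaquely.
void check_mkldnn_batch_norm_dim(int64_t ndim) {
  if (ndim != 4 && ndim != 5) {
    throw std::invalid_argument(
        "mkldnn_batch_norm: expected a 4d (batchnorm2d) or 5d (batchnorm3d) "
        "input, got a " + std::to_string(ndim) + "d tensor");
  }
}
```

In the real operator the same condition would be expressed with ATen's checking macros against `input.dim()`, so the error surfaces as a regular PyTorch error rather than an uncaught exception.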

bddppq added 3 commits April 16, 2019 15:11
Differential Revision: D14887205
Differential Version: 79689940
Differential Revision: D14887205
Differential Version: 79698401
Differential Revision: D14887205
Differential Version: 79729545


bddppq added 2 commits April 22, 2019 13:23
Differential Revision: D14887205
Differential Version: 80386617
Differential Revision: D14887205
Differential Version: 80533698
bddppq added 3 commits April 23, 2019 14:22
Differential Revision: D14887205
Differential Version: 80534908
Differential Revision: D14887205
Differential Version: 80541662
Differential Revision: D14887205
Differential Version: 80560075
bddppq added 2 commits April 23, 2019 21:42
Differential Revision: D14887205
Differential Version: 80580146
Differential Revision: D14887205
Differential Version: 80683617

ideep::tensor y;

if (train) {
Collaborator


This training support is incomplete: derivatives.yaml calls native_batch_norm_backward, and that one doesn't know what to do with mkldnn tensors. Thus I'd say either assert that train=false or implement native_batch_norm_backward for consistency.

Contributor Author


Makes sense, will add an assert.
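The assert agreed on above could be sketched as follows. This is an illustrative stand-in, not the merged code: the function name `check_mkldnn_batch_norm_train` is hypothetical, and a plain exception substitutes for ATen's check macros. The point is to reject train=true eagerly, since the backward path (native_batch_norm_backward) does not handle mkldnn tensors.

```cpp
#include <stdexcept>

// Hypothetical guard mirroring the review discussion: until
// native_batch_norm_backward understands mkldnn tensors, training mode
// must be rejected rather than silently producing an untrainable graph.
void check_mkldnn_batch_norm_train(bool train) {
  if (train) {
    throw std::runtime_error(
        "mkldnn_batch_norm: training is not supported for mkldnn tensors "
        "(native_batch_norm_backward is not implemented for them); "
        "call with train=false");
  }
}
```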

bddppq added 4 commits April 25, 2019 18:20
Differential Revision: D14887205
Differential Version: 80773735
Differential Revision: D14887205
Differential Version: 80780049
Differential Revision: D14887205
Differential Version: 80785120
Differential Revision: D14887205
Differential Version: 80799729
zdevito pushed a commit to zdevito/ATen that referenced this pull request Apr 26, 2019
Summary: Pull Request resolved: pytorch/pytorch#19206

Reviewed By: dzhulgakov

Differential Revision: D14887205

fbshipit-source-id: ea00c9e3205c449d08ab29535309164f951aab95
@facebook-github-bot
Contributor

This pull request has been merged in fb53c18.

zhangguanheng66 pushed a commit to zhangguanheng66/pytorch that referenced this pull request May 6, 2019
Summary: Pull Request resolved: pytorch#19206

Reviewed By: dzhulgakov

Differential Revision: D14887205

fbshipit-source-id: ea00c9e3205c449d08ab29535309164f951aab95
@ezyang ezyang deleted the export-D14887205 branch May 30, 2019 15:56

Labels

module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration
