Adding function to convert Module to channels last #28991
Conversation
soumith left a comment:
LGTM from my request. cc: @gchanan as the other reviewer who requested changes.
CircleCI build failures summary (as of commit ea3ed3d): 3 upstream failures were recognized by patterns. These builds matched failure patterns but were probably caused by upstream breakages; the probable reasons each build failed can be explored interactively on the Dr. CI website.
```python
def convert(t):
    if convert_to_format is not None and t.dim() == 4:
        return t.to(device, dtype if t.is_floating_point() else None,
                    non_blocking, memory_format=convert_to_format)
```
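For context, a hedged sketch of how a closure like this is typically wired up around the hunk above. The free-function framing and the use of the private `torch._C._nn._parse_to` helper are my reconstruction for illustration, not part of the diff:

```python
import torch
import torch.nn as nn

def module_to(module: nn.Module, *args, **kwargs) -> nn.Module:
    # Hypothetical free-function rendering of nn.Module.to, for illustration only.
    # _parse_to splits the flexible .to(...) arguments into their components.
    device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

    def convert(t):
        # Only 4D tensors (e.g. NCHW conv weights) are restrided to the requested
        # memory format; other tensors keep their layout. The dtype conversion is
        # applied to floating point tensors only.
        if convert_to_format is not None and t.dim() == 4:
            return t.to(device, dtype if t.is_floating_point() else None,
                        non_blocking, memory_format=convert_to_format)
        return t.to(device, dtype if t.is_floating_point() else None, non_blocking)

    # _apply runs convert over every parameter and buffer, recursively.
    return module._apply(convert)

# Example:
conv = nn.Conv2d(3, 8, kernel_size=3)
module_to(conv, memory_format=torch.channels_last)
print(conv.weight.is_contiguous(memory_format=torch.channels_last))  # True
```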
This call doesn't match the documentation, which says the only case for `to` with `memory_format` is the one-argument case.
Pardon, but I read "This can be called as" to mean: here are examples of calls, but they are not limited to these options.
That's not my reading of it, although I can see why you read it that way.
In particular, before memory_format was introduced, it corresponded exactly to the function signatures in pytorch/torch/csrc/autograd/utils/python_arg_parsing.h (lines 15 to 17 at 66f2bba):
| "to(Device device=None, ScalarType dtype=None, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", | |
| "to(ScalarType dtype, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", | |
| "to(Tensor tensor, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", |
(remove copy because it's not supported and remove memory_format because we are considering the case before memory_format was introduced).
So, the introduction of memory_format changed this from "these are the supported signatures" to "these are some examples of supported signatures". I think the former is more useful and we should change it back.
And actually, the example you added (`.. function:: to(memory_format=torch.channels_last)`) doesn't work because the parsing code hasn't been updated, right?
IMO we should do the following:

- Add a memory_format-only overload to the Python parsing.
- List the valid calls (see the usage sketch after this list), i.e.:

      .. function:: to(device=None, dtype=None, non_blocking=False, memory_format=None)
      .. function:: to(dtype, non_blocking=False, memory_format=None)
      .. function:: to(tensor, non_blocking=False, memory_format=None)
      .. function:: to(memory_format)

(or similar).
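For illustration, a hedged sketch of what those call forms look like from Python once a memory_format-only overload exists; the specific calls here are my examples, not from the PR:

```python
import torch
import torch.nn as nn

m = nn.Conv2d(3, 8, kernel_size=3)

m.to(torch.device("cpu"))                   # to(device=None, dtype=None, non_blocking=False, memory_format=None)
m.to(torch.float64)                         # to(dtype, non_blocking=False, memory_format=None)
m.to(torch.empty(0, dtype=torch.float16))   # to(tensor, ...): take dtype/device from an existing tensor
m.to(memory_format=torch.channels_last)     # to(memory_format): the memory_format-only form
```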
I had the parsing updated: https://github.com/pytorch/pytorch/pull/28991/files#diff-7a05cfd8eb442889dffd6c3d2e4d0ddcR24
Will update the inline and HTML docs in a follow-up PR.
```
    the floating point parameters and buffers in this module
tensor (torch.Tensor): Tensor whose dtype and device are the desired
    dtype and device for all parameters and buffers in this module
memory_format (:class:`torch.memory_format`): the desired memory
```
Are you planning to have a reference section for memory_format that you can point to? This description isn't full enough for the long term (e.g. why only 4D parameters/buffers? It's not clear).
Yes, as soon as we land the new defaults for .clone, .to, and *_like ops, I will work on updating the docs.
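For readers who want to try the API being documented here, a hedged usage sketch (assuming the semantics discussed in this thread: only 4D parameters and buffers are restrided):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())

# Restride all 4D parameters/buffers to channels_last (NHWC).
model = model.to(memory_format=torch.channels_last)

print(model[0].weight.is_contiguous(memory_format=torch.channels_last))  # True

# Inputs generally need the same layout to actually hit NHWC kernels.
x = torch.randn(8, 3, 32, 32).to(memory_format=torch.channels_last)
y = model(x)
print(y.is_contiguous(memory_format=torch.channels_last))  # typically True as well
```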
@VitalyFedyunin merged this pull request in 66f2bba.
ghstack-source-id: f650f3b
Pull Request resolved: pytorch/pytorch#28991
Test Plan: Imported from OSS
Differential Revision: D18430810
Pulled By: VitalyFedyunin
fbshipit-source-id: 0693d4e31fc6f9831722c29fc83517f16ddfc028
This PR adds an API that, I think, makes the assumption that any 4D parameter in the module needs conversion to NHWC layout if the user wants to use NVIDIA's NHWC kernels. From what I can see this assumption is limiting in many ways. It may be better if such an API took a different form.
@ppwwyyxx, I totally agree.
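To make the concern concrete, a hypothetical sketch (not from this PR): any 4D parameter or buffer gets restrided by the blanket dimensionality check, whether or not it is actually a convolution weight:

```python
import torch
import torch.nn as nn

class WithAux(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        # A 4D buffer that has nothing to do with NHWC convolutions.
        self.register_buffer("lookup", torch.zeros(2, 3, 4, 5))

m = WithAux().to(memory_format=torch.channels_last)
print(m.conv.weight.is_contiguous(memory_format=torch.channels_last))  # True, as intended
print(m.lookup.is_contiguous(memory_format=torch.channels_last))       # True as well, restrided by the 4D rule
```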
Stack from ghstack:
- #31131 Add memory_format support to `zeros`, `ones`, `full` (still need tests)

Differential Revision: D18430810
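Related to the stacked change above: memory_format is already accepted by `torch.empty` and the `*_like` factories; a hedged sketch follows (direct support in `zeros`, `ones`, `full` is what #31131 proposes and may differ from this):

```python
import torch

# empty accepts memory_format directly.
a = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last)

# The *_like factories accept it too (default is torch.preserve_format).
b = torch.zeros_like(a, memory_format=torch.channels_last)
c = torch.ones_like(a, memory_format=torch.contiguous_format)

print(a.is_contiguous(memory_format=torch.channels_last))  # True
print(b.is_contiguous(memory_format=torch.channels_last))  # True
print(c.is_contiguous())                                   # True (standard contiguous strides)
```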