Override max_pool2d as autograd function.#2236
Conversation
11187de to
e9e8e1b
Compare
torch_xla/csrc/aten_xla_type.cpp
Outdated
| torch::IntArrayRef padding, | ||
| torch::IntArrayRef dilation, bool ceil_mode) { | ||
| ctx->saved_data["kernel_size"] = kernel_size; | ||
| ctx->saved_data["stride"] = stride; |
There was a problem hiding this comment.
I dunno how the saved_data map works, but array-ref are ... refs, so unless the saved_data map has special handling to make copies, will those be valid in the backward?
There was a problem hiding this comment.
Arrayref is converted to IValue in the saved map. I'd assume this is okay since 1) it passed test 2) torchvision is using it the same way :D https://github.com/pytorch/vision/blob/e89c4c0198fd8cbd11344564162c68771524b2d7/torchvision/csrc/PSROIPool.h#L83
|
(CI test failure will be fixed by pytorch/pytorch#40265 [pending review] |
e9e8e1b to
916ef39
Compare
|
|
||
| ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); | ||
| ExpectCounterChanged("xla::max_pool3d_with_indices", | ||
| ExpectCounterChanged("xla::max_pool3d", |
There was a problem hiding this comment.
Remind me to make a pass on this file and move those test before returning from function 😄
No need to do it here.
916ef39 to
8b124a2
Compare
8b124a2 to
f400925
Compare
No description provided.