The JIT sometimes does not repeat values in the output_size argument. #20215

@skrah

Description

Background

This issue came up during the port of adaptive_max_pool2d() to ATen. The relevant function signature in native_functions.yaml is:

func: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)

The issue

When output_size is given as a single int H at the Python level (requesting a square H x H output), adaptive_max_pool2d() sometimes receives an output_size argument with output_size.size() == 1 instead of 2.

Currently, adaptive_max_pool2d() has a workaround that also accepts output_size.size() == 1 and repeats H inside the function.
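The workaround amounts to expanding a length-1 output_size to both dimensions. A minimal Python sketch of that expansion logic (the function name here is illustrative, not the actual ATen code, which lives in C++):

```python
def expand_output_size(output_size):
    """Mimic the C++ workaround: accept either [H] or [H, W].

    A length-1 list is treated as a square H x H output; anything
    other than length 1 or 2 is rejected, since the kernel
    ultimately needs exactly two output dimensions.
    """
    if len(output_size) == 1:
        return [output_size[0], output_size[0]]
    if len(output_size) == 2:
        return list(output_size)
    raise ValueError("output_size must have 1 or 2 elements")

print(expand_output_size([7]))     # [7, 7]
print(expand_output_size([5, 9]))  # [5, 9]
```

If the JIT honored the int[2] annotation in the schema, output_size would always arrive with length 2 and this expansion would be unnecessary.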

How to reproduce

The cause is difficult to isolate since the tests are generated and quite complex. I think a module export/import is also involved.

  1. Apply this diff:
diff --git a/aten/src/ATen/native/AdaptiveMaxPooling2d.cpp b/aten/src/ATen/native/AdaptiveMaxPooling2d.cpp
index b3b77c5..ec16c5e 100644
--- a/aten/src/ATen/native/AdaptiveMaxPooling2d.cpp
+++ b/aten/src/ATen/native/AdaptiveMaxPooling2d.cpp
@@ -322,6 +322,7 @@ std::tuple<Tensor, Tensor> adaptive_max_pool2d_cpu(
 {
   Tensor output = at::empty({0}, input.options());
   Tensor indices = at::empty({0}, input.options().dtype(kLong));
+  assert(output_size.size() == 2);
   adaptive_max_pool2d_out_cpu_template(
     output,
     indices,
  2. Run the relevant test case:

python3 -m pytest -v test_jit.py::TestJitGeneratedModule::test_nn_AdaptiveMaxPool2d_single

Backtrace

(gdb) f 9
#9  0x00007fffcdeb2127 in torch::jit::(anonymous namespace)::<lambda(torch::jit::Stack&)>::operator()(torch::jit::Stack &) const (__closure=0x44ba5580, 
    stack=std::vector of length 2, capacity 2 = {...}) at /home/stefan/pytorch/torch/csrc/jit/generated/register_aten_ops_1.cpp:2948
2948              );
(gdb) f 8
#8  0x00007fffcde9aa4d in at::adaptive_max_pool2d (self=..., output_size=...) at /home/stefan/pytorch/build/aten/src/ATen/Functions.h:5728
5728        return detail::infer_type(self).adaptive_max_pool2d(self, output_size);
(gdb) p output_size.size()
$1 = 1

Labels

module: nn (Related to torch.nn), oncall: jit (Add this issue/PR to JIT oncall triage queue)
