
[torch.onnx] Support torch.nn.functional.grid_sample #76159

Closed
crcrpar wants to merge 17 commits into pytorch:master from crcrpar:onnx-opset16-gridsample

Conversation

@crcrpar
Collaborator

@crcrpar crcrpar commented Apr 21, 2022

summary

  • Adds F.grid_sample support
  • Adds a test case

Fixes #27212

@facebook-github-bot
Contributor

facebook-github-bot commented Apr 21, 2022

🔗 Helpful links

💊 CI failures summary and remediations

As of commit 71f8806 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@facebook-github-bot added the oncall: jit label on Apr 21, 2022
@albanD added the triaged label on Apr 25, 2022
@BowenBao self-assigned this on Apr 26, 2022
@BowenBao added the module: onnx, release notes: onnx, and topic: improvements labels on Apr 26, 2022
@BowenBao
Collaborator

BowenBao commented Apr 26, 2022

@crcrpar thanks for contributing! Could you add a test case in test/onnx/test_pytorch_onnx_onnxruntime.py for GridSample? I believe onnxruntime has supported this operator since version 1.11, which is what we are running in CI.

You will need to update the version to 16 here to enable the CI run of the opset 16 tests with onnxruntime.

for i in $(seq 10 15); do
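For illustration, the change would presumably extend the loop's upper bound so opset 16 is included; this is a sketch, and the actual CI script may differ:

```shell
# Hypothetical sketch of the CI loop after the bump: extend the range
# so the opset 16 tests also run against onnxruntime.
for i in $(seq 10 16); do
  echo "running opset $i tests"
done
```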

@BowenBao
Collaborator

@crcrpar Looks like the CI ONNX check is failing on a few tests:

FAILED test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset16::test_grid_sample
FAILED test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset16::test_multi_scale_roi_align
FAILED test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset16::test_roi_align
FAILED test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset16::test_roi_align_aligned

The roi_align failures look to be due to the ONNX spec update in opset 16. If you'd like, you can wrap those test cases with `@skipIfUnsupportedMaxOpsetVersion(16)  # TODO: Opset 16 RoiAlign result mismatch.`.
The grid_sample failure is due to an issue with passing strings and booleans as model inputs. I'd recommend rewriting the test so that the configuration is set at model initialization time.
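A minimal sketch of that shape, assuming a hypothetical `GridSampleModule` name (the exact test module in this PR may differ): the string/bool configuration moves into `__init__`, so only tensors remain as forward inputs.

```python
# Hypothetical sketch: capture non-tensor configuration at construction
# time so the ONNX exporter only ever sees tensor inputs in forward().
import torch


class GridSampleModule(torch.nn.Module):
    def __init__(self, mode="bilinear", padding_mode="zeros", align_corners=False):
        super().__init__()
        # Strings and booleans are fixed here, not traced as model inputs.
        self.mode = mode
        self.padding_mode = padding_mode
        self.align_corners = align_corners

    def forward(self, x, grid):
        return torch.nn.functional.grid_sample(
            x,
            grid,
            mode=self.mode,
            padding_mode=self.padding_mode,
            align_corners=self.align_corners,
        )
```

At export time, only `x` and `grid` then need to be passed as model inputs.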

@thiagocrepaldi
Collaborator

Fixes #69674

@crcrpar force-pushed the onnx-opset16-gridsample branch from 0a0a42e to 5542914 on April 29, 2022 04:55
@BowenBao
Collaborator

Thanks @crcrpar! I left some comments regarding the CI failure and the local test failure you noted. Otherwise this is ready to merge.

ops = [{"op_name": "GridSample"}]
ops = {16: ops}

class MyModule(Module):
Collaborator


Possible to use a more descriptive name for the module? E.g. GridSampleModule, TestGridSampleModule etc.

def forward(self, x, grid, mode, padding_mode, align_corners):
    return torch.nn.functional.grid_sample(x, grid, mode, padding_mode, align_corners)

for mode, padding_mode, align_corners in itertools.product(
Collaborator


Possible to parameterize this test? E.g.

@skip_if_lt_x_gpu(2)
@parametrize(params, configs, subtest_name)
@parametrize("clip_norm_type", [2.0, None])
def test_transformer_parameterized(self, cpu_offload, backward_prefetch, sharding_strategy, clip_norm_type):
    init_modes = self._get_init_modes_for_test(cpu_offload)
    for fsdp_init_mode in init_modes:
        with self.subTest(fsdp_init_mode=fsdp_init_mode):
            self._test_identical_outputs(
                TransformerWithSharedParams,
                fsdp_init_mode=fsdp_init_mode,
                cpu_offload=cpu_offload,
                backward_prefetch=backward_prefetch,
                norm_type=clip_norm_type,
                sharding_strategy=sharding_strategy,
            )

Collaborator Author


Locally, parametrize didn't work, so I skipped it this time.

Collaborator Author


Also, I'm not sure we want a bunch of different test cases for one GridSample op.

Collaborator


That's cool. I wouldn't worry about having different tests for an op, as long as it's clear what they are testing for. For now I would also recommend subtests: https://docs.python.org/3/library/unittest.html#distinguishing-test-iterations-using-subtests
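For reference, a minimal subTest sketch with stdlib unittest (the class name and configurations here are illustrative, not taken from this PR):

```python
# Illustrative sketch of the subTest pattern: iterate over configurations
# inside one test method, with each configuration reported individually.
import itertools
import unittest


class GridSampleConfigTest(unittest.TestCase):
    def test_all_configs(self):
        modes = ["bilinear", "nearest", "bicubic"]
        padding_modes = ["zeros", "border", "reflection"]
        for mode, padding_mode in itertools.product(modes, padding_modes):
            # On failure, the offending (mode, padding_mode) pair is named
            # in the report, but the whole product counts as one test method.
            with self.subTest(mode=mode, padding_mode=padding_mode):
                self.assertIn(mode, modes)
                self.assertIn(padding_mode, padding_modes)
```

Run with `python -m unittest`; failing configurations are listed by their subTest parameters.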


Collaborator


Do you know if pytest will then skip the subtests or treat them as one test instead? It seems they provide support via a plugin: https://github.com/pytest-dev/pytest-subtests. Feel free to skip any changes in this PR!

Collaborator Author


IIRC it was more similar to the latter when I used pytest with subtests for other tests.
I haven't checked the code or looked for a plugin, though.

from torch.nn.functional import GRID_SAMPLE_INTERPOLATION_MODES, GRID_SAMPLE_PADDING_MODES


# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
Collaborator

@justinchuby justinchuby Apr 29, 2022


nit:

Collaborator


I would add this to the docstring instead.

Collaborator Author


Can't that be a follow-up? I didn't see any docstrings in the symbolic_opset files.

Collaborator


It can be! Totally non-blocking.

@justinchuby
Collaborator

Commented just as I was passing by! They are non-blocking

@HSJ-007

HSJ-007 commented Apr 30, 2022

Great, when can I use this?

@justinchuby
Collaborator

justinchuby commented Apr 30, 2022

Great, when can I use this?

Thanks for asking! Once the pull request is merged, you should be able to find the updated version in the PyTorch nightly. You can find installation guides for the nightly build on https://pytorch.org by selecting Preview (Nightly).

@HSJ-007

HSJ-007 commented Apr 30, 2022

Wow, while I don't understand what's in it, I really need this feature

@justinchuby
Collaborator

Wow, while I don't understand what's in it, I really need this feature

@HSJ-007 Could you share more on what you are trying to accomplish and how this feature can be useful?

@HSJ-007

HSJ-007 commented Apr 30, 2022

I'm finishing a project that's tied to my graduation, and I used an STN network.

# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@parse_args("v", "v", "i", "i", "b")
def grid_sampler(g, input, grid, mode_enum, padding_mode_enum, align_corners):
Collaborator

@justinchuby justinchuby May 2, 2022


@justinchuby requested a review from BowenBao on May 2, 2022 19:29
Collaborator

@BowenBao left a comment

LGTM

@BowenBao
Collaborator

BowenBao commented May 2, 2022

@pytorchbot merge this


Labels

  • cla signed
  • module: onnx (Related to torch.onnx)
  • oncall: jit (Add this issue/PR to JIT oncall triage queue)
  • open source
  • release notes: onnx (torch.onnx related changes that should show up in the release notes)
  • topic: improvements (topic category)
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

ONNX and grid_sample layer

8 participants