Skip to content

Conversation

@NarineK
Copy link
Contributor

@NarineK NarineK commented Feb 19, 2019

Summary: Added circular padding in addition to zero padding to Conv1D, Conv2D and Conv3D based on the solution suggested in: #3858

Differential Revision: D14126416

Copy link
Contributor

@ezyang ezyang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a really good start. I think the big things for me are:

  1. Getting better test coverage for padding > 1 case. I don't really care if you have a reference implementation, but we should exercise it.
  2. Fixing the error checking in circular_pad. My suggestion is to fold it into the existing pad function.

Copy link
Contributor

@apaszke apaszke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we just have F.pad_circular to avoid having a blow up for every single dimensionality?

@ezyang
Copy link
Contributor

ezyang commented Feb 19, 2019

@apaszke I don't know how to write a generalized circular pad that works for arbitrary dimensionality, without writing the kernel out (which may eventually be a good idea, but I wouldn't block this PR on it.) Any ideas? :)

@apaszke
Copy link
Contributor

apaszke commented Feb 19, 2019

I haven't actually thought about that, but it seems like the implementations for 1d and 3d look the same modulo the number of colons (which are equivalent to slice(None)), so we should be able to do this programmatically. Also, it would be good to have better examples in the docstrings, because they don't fully explain the semantics I think. Is it correct to say that in the 2D case we add the last row at the top, the first row at the bottom, the left column on the right, and the right column on the left? What happens with the indices that are not covered by those elements?

@NarineK
Copy link
Contributor Author

NarineK commented Feb 20, 2019

Thank you @apaszke and @ezyang ! I can play with the slice(None) to find a more elegant implementation for the dimensionality.
With respect to What happens with the indices that are not covered by those elements? The other indices do not change or effect anything. Here is an example from #3858 for padding=1
Before circular padding
[[0 1 2]
[3 4 5]
[6 7 8]]

After circular padding
[[8 6 7 8 6]
[2 0 1 2 0]
[5 3 4 5 3]
[8 6 7 8 6]
[2 0 1 2 0]]

@NarineK
Copy link
Contributor Author

NarineK commented Feb 21, 2019

@ezyang and @apaszke , I've just pushed the implementation of pad_circular which works for arbitrary dimensions.
6832149#diff-c66288b9ce36978f377a1a20d32ec53dR2993
The common_nn.py tests passed but test_jit does not like input[slice(None)] in general.
I will address other comments in a separate commit.

@NarineK
Copy link
Contributor Author

NarineK commented Feb 22, 2019

After talking to jit team we came to the conclusion that currently there is no support for slice(None). They opened the following feature to work on: #17389

I'll revert back my changes and update docstring examples.

cc: @ezyang, @apaszke

@ezyang
Copy link
Contributor

ezyang commented Feb 25, 2019

@pytorchbot rebase this please

@pytorchbot
Copy link
Collaborator

There's nothing to do! This branch is already up to date with master (d76b939).

(To learn more about this bot, see Bot commands.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also have a separate method to split left and right paddings.

@NarineK NarineK force-pushed the export-D14126416 branch from 12f0925 to 72f19e2 Compare March 1, 2019 06:55
@NarineK
Copy link
Contributor Author

NarineK commented Mar 1, 2019

@pytorchbot rebase this please

@pytorchbot
Copy link
Collaborator

Sorry, I can't merge this because there are conflicts. To merge this yourself, run the commands below:

git fetch origin master
git fetch [email protected]:NarineK/pytorch.git export-D14126416
git checkout FETCH_HEAD
git merge origin/master
git push [email protected]:NarineK/pytorch.git HEAD:export-D14126416

(To learn more about this bot, see Bot commands.)

@NarineK
Copy link
Contributor Author

NarineK commented Mar 9, 2019

Do you guys have any other comments on this PR ?
cc: @ezyang , @soumith

Summary:
Pull Request resolved: pytorch#17240

Added circular padding in addition to zero padding to Conv1D, Conv2D and Conv3D based on the solution suggested in: pytorch#3858

Differential Revision: D14126416

fbshipit-source-id: 72cdb69b8c22e33b0ba9ae515c4d119f69ea45d9


@weak_script
def pad_circular(input, padding):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: why didn't we add this functionality to F.pad?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @fmassa! That was a suggestion to have padding functionally at one place.
cc: @soumith , @ezyang

return F.conv1d(F.pad(input, expanded_padding, mode='circular'),
self.weight, self.bias, self.stride,
_single(0), self.dilation, self.groups)
return F.conv1d(input, self.weight, self.bias, self.stride,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to have the functional interface of conv also support the padding mode?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the documentation it appears padding_mode is supported, but when I try to use it tells me it is an unexpected argument. I will do the manual padding, as in the module, but it would be convenient if the operation is supported.
In my case, my filter is a tensor, not a parameter, and I would prefer not to use a module.

Thanks!

@ezyang
Copy link
Contributor

ezyang commented Jul 9, 2019 via email

@jzazo
Copy link

jzazo commented Jul 10, 2019

I am running 1.1.0. Thanks!

@thomasahle
Copy link

If I try to run
F.pad(torch.arange(4), (0,1), mode='circular')
or
F.pad(torch.arange(9).reshape(3,3), (0,1,1,0), mode='circular')
I get the error
NotImplementedError: Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now.
(Torch v 1.11)
Am I doing this wrong?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants