Skip to content

Conversation

@MilesCranmer
Copy link
Contributor

@MilesCranmer MilesCranmer commented Apr 14, 2019

This is a simple yet useful addition to the torch.nn modules: an identity module. This is a first draft - please let me know what you think and I will edit my PR.

Summary: There is no identity module - nn.Sequential() can be used, however it is argument sensitive so can't be used interchangably with any other module. This adds nn.Identity(...) which can be swapped with any module because it has dummy arguments. It's also more understandable than seeing an empty Sequential inside a model.

See discussion on #9160. The current solution is to use nn.Sequential(). However this won't work as follows:

batch_norm = nn.BatchNorm2d
if dont_use_batch_norm:
    batch_norm = Identity

Then in your network, you have:

nn.Sequential(
    ...
    batch_norm(N, momentum=0.05),
    ...
)

If you try to simply set Identity = nn.Sequential, this will fail since nn.Sequential expects modules as arguments. Of course there are many ways to get around this, including:

  • Conditionally adding modules to an existing Sequential module
  • Not using Sequential but writing the usual forward function with an if statement
  • ...

However, I think that an identity module is the most pythonic strategy, assuming you want to use nn.Sequential.

Using the very simple class (this isn't the same as the one in my commit):

class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
    def forward(self, x):
        return x

we can get around using nn.Sequential, and batch_norm(N, momentum=0.05) will work. There are of course other situations this would be useful.

Thank you.
Best,
Miles

Summary: There is no identity module - nn.Sequential() can be used,
however it is argument sensitive so can't be used interchangably
with any other module. This adds nn.Identity(...) which
can be swapped with any module.
Copy link
Contributor

@ezyang ezyang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure why not.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@ezyang merged this pull request in 30292d9.

zhangguanheng66 pushed a commit to zhangguanheng66/pytorch that referenced this pull request May 6, 2019
Summary:
This is a simple yet useful addition to the torch.nn modules: an identity module. This is a first draft - please let me know what you think and I will edit my PR.

 There is no identity module - nn.Sequential() can be used, however it is argument sensitive so can't be used interchangably with any other module. This adds nn.Identity(...) which can be swapped with any module because it has dummy arguments. It's also more understandable than seeing an empty Sequential inside a model.

See discussion on pytorch#9160. The current solution is to use nn.Sequential(). However this won't work as follows:

```python
batch_norm = nn.BatchNorm2d
if dont_use_batch_norm:
    batch_norm = Identity
```

Then in your network, you have:

```python
nn.Sequential(
    ...
    batch_norm(N, momentum=0.05),
    ...
)
```

If you try to simply set `Identity = nn.Sequential`, this will fail since `nn.Sequential` expects modules as arguments. Of course there are many ways to get around this, including:

- Conditionally adding modules to an existing Sequential module
- Not using Sequential but writing the usual `forward` function with an if statement
- ...

**However, I think that an identity module is the most pythonic strategy,** assuming you want to use nn.Sequential.

Using the very simple class (this isn't the same as the one in my commit):

```python
class Identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
    def forward(self, x):
        return x
```

we can get around using nn.Sequential, and `batch_norm(N, momentum=0.05)` will work. There are of course other situations this would be useful.

Thank you.
Best,
Miles
Pull Request resolved: pytorch#19249

Differential Revision: D15012969

Pulled By: ezyang

fbshipit-source-id: 9f47e252137a1679e306fd4c169dca832eb82c0c
@nuniz nuniz mentioned this pull request Dec 29, 2023
4 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants