
[jit] Function attributes skipping other attributes #28559

@driazati

import torch
import torch.nn as nn

class N(nn.Module):
    __constants__ = ['norm']

    def __init__(self, norm=None):
        super(N, self).__init__()
        self.activation = torch.nn.functional.relu  # Commenting out this line makes it work
        self.norm = norm

    def forward(self, src):
        output = src
        if self.norm is not None:
            output = self.norm(output)
        return output

class M(nn.Module):
    def __init__(self):
        super().__init__()
        encoder_norm = nn.ReLU()
        self.encoder = N(encoder_norm)

    def forward(self, x):
        return self.encoder(x)

torch.jit.script(M())

This fails with:

RuntimeError: 
Module 'N' has no attribute 'norm' :
at ../test.py:79:11
    def forward(self, src):
        output = src
        if self.norm is not None:
           ~~~~~~~~~ <--- HERE
            output = self.norm(output)
        return output
'N.forward' is being compiled since it was called from 'M.forward'
at ../test.py:90:28
    def forward(self, x):
        return self.encoder(x)
                            ~ <--- HERE

Removing the `self.activation = ...` line makes scripting succeed.
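The variant that the report says scripts successfully is simply the same repro with the `self.activation = ...` assignment dropped. Below is a minimal sketch of that variant. `NFixed` and `MFixed` are hypothetical names. When the activation is actually needed, calling `F.relu` directly inside `forward` (rather than storing it on the module) is one possible workaround. That it avoids the bug in the affected release is an assumption; the report does not verify it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NFixed(nn.Module):
    # Hypothetical name: same as N from the report, minus the function attribute.
    __constants__ = ['norm']

    def __init__(self, norm=None):
        super().__init__()
        # No `self.activation = F.relu` here: storing a function on the
        # module is the line the report identifies as the trigger.
        self.norm = norm

    def forward(self, src):
        output = src
        if self.norm is not None:
            output = self.norm(output)
        # Assumed workaround: call the functional directly inside forward
        # instead of keeping it as a module attribute.
        return F.relu(output)


class MFixed(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = NFixed(nn.ReLU())

    def forward(self, x):
        return self.encoder(x)


scripted = torch.jit.script(MFixed())
x = torch.tensor([-1.0, 0.0, 2.0])
print(scripted(x))  # tensor([0., 0., 2.])
```

ReLU is idempotent, so applying `F.relu` after the `nn.ReLU()` norm does not change the result. The scripted module therefore returns the same values that the original `M` returns in eager mode.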

cc @suo

Labels: "oncall: jit", "triaged"