Closed
Labels: enhancement, oncall: jit, triaged
Description
🐛 Bug
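Currently, modules can't be appended to an nn.ModuleList inside the constructor of a torch.jit.ScriptModule when that attribute is listed in __constants__. Assigning the ModuleList to self converts it to a _ConstModuleList, which has no append method.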
To Reproduce
Steps to reproduce the behavior:
import torch
import torch.nn as nn

class MultiBox(torch.jit.ScriptModule):
    __constants__ = ['loc_layers']

    def __init__(self):
        super(MultiBox, self).__init__()
        self.loc_layers = nn.ModuleList()
        for i in range(4):
            # Fails: self.loc_layers is already a _ConstModuleList at this point
            self.loc_layers.append(nn.Conv2d(64, 4, kernel_size=1))

    @torch.jit.script_method
    def forward(self, x):
        return x

MultiBox()  # raises the AttributeError below

Results in this error:
  File "C:\Python36\lib\site-packages\torch\jit\__init__.py", line 1232, in __getattr__
    return Module.__getattr__(self, attr)
  File "C:\Python36\lib\site-packages\torch\nn\modules\module.py", line 539, in __getattr__
    type(self).__name__, name))
AttributeError: '_ConstModuleList' object has no attribute 'append'
Expected behavior
Appending modules to self.loc_layers inside the constructor should work.
Environment
- PyTorch Version: 1.1.0
Additional context
A workaround is to build the ModuleList in a local variable and assign it to self only once it is fully populated:
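import torch
import torch.nn as nn

class MultiBox2(torch.jit.ScriptModule):
    __constants__ = ['loc_layers']

    def __init__(self):
        super(MultiBox2, self).__init__()
        # Populate a plain local ModuleList first...
        loc_layers = nn.ModuleList()
        for i in range(4):
            loc_layers.append(nn.Conv2d(64, 4, kernel_size=1))
        # ...then assign it once, so the conversion sees the finished list
        self.loc_layers = loc_layers

    @torch.jit.script_method
    def forward(self, x):
        return x

MultiBox2()  # constructs without error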
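Both the error and the workaround are consistent with the conversion happening once, at assignment time. The pure-Python sketch below models that behavior under stated assumptions: FrozenModuleList, ScriptModuleModel and MultiBoxModel are hypothetical stand-ins (not PyTorch classes), strings stand in for the nn.Conv2d layers, and the real ScriptModule.__setattr__ does considerably more than this.

```python
class FrozenModuleList:
    """Hypothetical stand-in for _ConstModuleList: read-only, no append()."""

    def __init__(self, items):
        self._items = tuple(items)  # snapshot taken at construction

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        return self._items[idx]

    def __iter__(self):
        return iter(self._items)


class ScriptModuleModel:
    """Hypothetical, simplified model of converting constants on assignment."""

    __constants__ = []

    def __setattr__(self, name, value):
        # Lists named in __constants__ are frozen the moment they are assigned.
        if name in type(self).__constants__ and isinstance(value, list):
            value = FrozenModuleList(value)
        object.__setattr__(self, name, value)


class MultiBoxModel(ScriptModuleModel):
    __constants__ = ['loc_layers']

    def __init__(self, populate_first):
        if populate_first:
            # Workaround pattern: fill a local list, then assign it once.
            loc_layers = []
            for i in range(4):
                loc_layers.append(f"conv{i}")
            self.loc_layers = loc_layers
        else:
            # Original pattern: the empty list is frozen on assignment,
            # so the first append() raises AttributeError.
            self.loc_layers = []
            for i in range(4):
                self.loc_layers.append(f"conv{i}")


if __name__ == "__main__":
    try:
        MultiBoxModel(populate_first=False)
    except AttributeError as exc:
        print("original pattern fails:", exc)

    ok = MultiBoxModel(populate_first=True)
    print("workaround holds", len(ok.loc_layers), "layers")  # 4
```

Because the container is frozen when it is assigned, everything it should hold must already be in it before the assignment, which is what the workaround does.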