Currently the base RNN classes/functions are hard to extend. If someone wanted to add new features to the LSTM in PyTorch, they would have to modify:
AutogradRNN (nn/_functions/rnn.py)
StackedRNN (nn/_functions/rnn.py)
RNNBase (nn/modules/rnn.py)
Furthermore, the default RNN implementation is restrictive: it forces every stacked RNN layer to be exactly the same.
I was thinking it may be worthwhile to instead have an RNN driver together with an abstract RNN base class.
The RNN driver would be similar in function to the Recurrent, StackedRNN, and AutogradRNN functions, but it would depend only on the abstract RNN class and be independent of any specific RNN definition.
The code would look something like...
import copy
import math

import torch
import torch.nn as nn


class RNNDriver(nn.Module):
    # The constructor can take either a single RNN layer or a list of RNN layers
    def __init__(self, inputRNN, num_layers=1):
        super(RNNDriver, self).__init__()
        if not isinstance(inputRNN, list):
            self.rnns = [inputRNN]
            for i in range(num_layers - 1):
                # give each additional layer its own copy of the parameters
                self.rnns.append(copy.deepcopy(inputRNN))
        else:
            assert len(inputRNN) == num_layers, "RNN list length must be equal to num_layers"
            self.rnns = inputRNN

    # parameters() groups the parameters of all layers,
    # yielding shared parameters only once
    def parameters(self):
        seen = set()
        for rnn in self.rnns:
            for p in rnn.parameters():
                if p not in seen:
                    seen.add(p)
                    yield p

    def forward(self, input, train=True, batch_first=False, dropout=0, bidirectional=False):
        ...

    def initHidden(self, bsz):
        for rnn in self.rnns:
            rnn.initHidden(bsz)

    def resetHidden(self, bsz):
        for rnn in self.rnns:
            rnn.resetHidden(bsz)

    def initInference(self, bsz):
        for rnn in self.rnns:
            rnn.initInference(bsz)
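For illustration only, the driver's forward pass might take a shape like the minimal sketch below. It assumes each layer's forward maps a (seq_len, batch, features) sequence to an output sequence of the same length and manages its own hidden state internally; batch_first, bidirectionality, and the train flag are omitted for brevity:

    # inside RNNDriver (hypothetical sketch, not part of the proposal above)
    def forward(self, input, dropout=0):
        output = input
        for i, rnn in enumerate(self.rnns):
            output = rnn(output)
            # apply dropout between stacked layers, but not after the last one
            if dropout > 0 and i < len(self.rnns) - 1:
                output = nn.functional.dropout(output, p=dropout, training=self.training)
        return output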
class RNNBase(nn.Module):
    # The base initialization could set up a simple RNN layer or could be empty
    def __init__(self, input_size, hidden_size):
        super(RNNBase, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.w_ih = nn.Parameter(torch.Tensor(hidden_size, input_size))
        self.b_ih = nn.Parameter(torch.Tensor(hidden_size))
        self.w_hh = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.hidden = None
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def initHidden(self, bsz):
        # Create the hidden variable(s)
        raise NotImplementedError

    def resetHidden(self, bsz):
        # Re-wrap the hidden variables
        raise NotImplementedError

    def initInference(self, bsz):
        # Re-wrap hidden in a variable with volatile=True
        raise NotImplementedError

    def forward(self, input):
        # Implement the RNN layer and return its output
        raise NotImplementedError
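To make the proposed interface concrete, here is a rough sketch of a plain tanh RNN layer built on this base class and handed to the driver. The class name TanhRNN, the time-major (seq_len, batch, feature) layout, and the use of Variable/volatile from the pre-0.4 autograd API are assumptions for illustration, not part of the proposal:

    from torch.autograd import Variable

    class TanhRNN(RNNBase):
        def initHidden(self, bsz):
            # Fresh zero-valued hidden state for a new batch
            self.hidden = Variable(self.w_hh.data.new(bsz, self.hidden_size).zero_())

        def resetHidden(self, bsz):
            # Detach the hidden state from the previous graph by re-wrapping its data
            if self.hidden is None:
                self.initHidden(bsz)
            else:
                self.hidden = Variable(self.hidden.data)

        def initInference(self, bsz):
            # Volatile hidden state so no graph is built at inference time
            self.hidden = Variable(self.hidden.data, volatile=True)

        def forward(self, input):
            # input: (seq_len, batch, input_size) -> output: (seq_len, batch, hidden_size)
            outputs = []
            h = self.hidden
            for t in range(input.size(0)):
                h = (input[t].mm(self.w_ih.t()) + self.b_ih + h.mm(self.w_hh.t())).tanh()
                outputs.append(h)
            self.hidden = h
            return torch.stack(outputs, 0)

A stack of such layers could then be driven without touching any of the driver code, e.g. with per-layer sizes chosen freely via the list form of the constructor:

    layers = [TanhRNN(10, 20), TanhRNN(20, 20), TanhRNN(20, 20)]
    driver = RNNDriver(layers, num_layers=3)
    driver.initHidden(bsz=5)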