
Feature Request: Easier to extend base RNN implementation #711

@csarofeen

Description

Currently the base RNN classes/functions are hard to extend. If someone wanted to extend LSTM with new features in PyTorch, they would have to modify:
AutogradRNN (nn/_functions/rnn.py)
StackedRNN (nn/_functions/rnn.py)
RNNBase (nn/modules/rnn.py)
Furthermore, the default RNN implementation is restrictive, enforcing every stacked RNN layer to be exactly the same.
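For example, the built-in stacked RNN replicates one layer configuration across the whole stack; varying something as simple as the hidden size per layer means wiring separate modules together by hand. A rough sketch of the restriction (module choices and sizes are illustrative):

import torch.nn as nn

#Built-in stacked LSTM: every layer is forced to share the same hidden_size
rnn = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

#Per-layer differences currently require separate modules plus hand-written
#inter-layer plumbing (dropout, hidden state handling, etc.)
layer1 = nn.LSTM(input_size=10, hidden_size=20)
layer2 = nn.LSTM(input_size=20, hidden_size=40)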

I was thinking it may be worthwhile to instead have an RNN driver plus an abstract RNN base class.

The RNN driver would be similar in function to the Recurrent, StackedRNN, and AutogradRNN functions, but it would depend only on the abstract RNN class and be independent of any specific RNN definition.

The code would look something like...

import math
import torch
import torch.nn as nn

class RNNDriver(nn.Module):
    #constructor could either take an RNN or list of RNN layers
    def __init__(self, inputRNN, num_layers=1):
        super(RNNDriver, self).__init__()
        if not isinstance(inputRNN, list):
            self.rnns = [inputRNN]
            for i in range(num_layers - 1):
                self.rnns.append(inputRNN.clone())
        else:
            assert len(inputRNN) == num_layers, "RNN list length must be equal to num_layers"
            self.rnns = inputRNN
    #Parameters call to group parameters of all layers
    def parameters(self):
        memo = set()
        for rnn in self.rnns:
            for p in rnn.parameters(memo):
                yield p
                
    def forward(self, input, train=True, batch_first=False, dropout=0, bidirectional=False):
        ...

    def initHidden(self, bsz):
        for rnn in self.rnns:
            rnn.initHidden(bsz)

    def resetHidden(self, bsz):
        for rnn in self.rnns:
            rnn.resetHidden(bsz)

    def initInference(self, bsz):    
        for rnn in self.rnns:
            rnn.initInference(bsz)
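Usage of the driver might then look roughly like this (purely illustrative; LSTMLayer is a hypothetical subclass of the RNNBase class sketched below, and clone() is assumed to deep-copy a layer):

#Hypothetical usage of the proposed driver:
#one layer definition replicated num_layers times via clone()
model = RNNDriver(LSTMLayer(input_size=128, hidden_size=256), num_layers=3)

#or an explicit list of per-layer definitions, which the current
#RNNBase/StackedRNN path cannot express
model = RNNDriver([LSTMLayer(128, 256),
                   LSTMLayer(256, 512),
                   LSTMLayer(512, 512)], num_layers=3)

model.initHidden(bsz=32)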

class RNNBase(nn.Module):
    #Base initialization could be for a simple RNN layer or could be empty
    def __init__(self, input_size, hidden_size):
        super(RNNBase, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size

        self.w_ih = nn.Parameter(torch.Tensor(hidden_size, input_size))
        self.b_ih = nn.Parameter(torch.Tensor(hidden_size))
        self.w_hh = nn.Parameter(torch.Tensor(hidden_size, hidden_size))

        self.hidden = None
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def initHidden(self, bsz):
        #Create hidden variable(s)
        raise NotImplementedError

    def resetHidden(self, bsz):
        #Re-wrap hidden variables
        raise NotImplementedError

    def initInference(self, bsz):
        #Re-wrap hidden in a variable with volatile=True
        raise NotImplementedError

    def forward(self, input):
        #Implement RNN layer and return output
        raise NotImplementedError
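A concrete layer would then only have to fill in the abstract methods. As a rough sketch of what that could look like (TanhRNN is a hypothetical example; the Variable/volatile handling follows the current autograd API and the tanh cell is only for illustration):

import torch
from torch.autograd import Variable

class TanhRNN(RNNBase):
    #A minimal Elman-style layer built on the abstract RNNBase above
    def initHidden(self, bsz):
        weight = next(self.parameters()).data
        self.hidden = Variable(weight.new(bsz, self.hidden_size).zero_())

    def resetHidden(self, bsz):
        #Detach the hidden state from the previous graph by re-wrapping it
        self.hidden = Variable(self.hidden.data)

    def initInference(self, bsz):
        #Re-wrap hidden in a volatile Variable for inference
        self.hidden = Variable(self.hidden.data, volatile=True)

    def forward(self, input):
        #input is (bsz, input_size) for a single time step
        self.hidden = torch.tanh(input.mm(self.w_ih.t()) + self.b_ih +
                                 self.hidden.mm(self.w_hh.t()))
        return self.hidden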
