skip_input for RNN #894
Conversation
for x_layer, y_layer in zip(rnn.all_weights, weights_val):
    for x, y in zip(x_layer, y_layer):
        x.data.copy_(y.data)
        if x is not None and y is not None:
grad_output = torch.randn(batch, seq_length, hidden_size * num_directions)
grad_output = torch.randn(seq_length, batch, hidden_size * num_directions)
if skip_input:
    input_val = torch.randn(seq_length, batch, hidden_size)
torch/backends/cudnn/rnn.py (outdated)
for param_from, param_to in zip(layer_params_from, layer_params_to):
    assert param_from.type() == param_to.type()
    param_to.copy_(param_from)
    if param_from is not None and param_to is not None:
torch/nn/_functions/rnn.py (outdated)
gh = F.linear(hidden, w_hh, b_hh)
i_r, i_i, i_n = gi.chunk(3, 1)
i_r, i_i, i_n = [x.squeeze(1) for x in gi.chunk(3, 1)]
h_r, h_i, h_n = gh.chunk(3, 1)
torch/nn/_functions/rnn.py (outdated)
grad_weight)
if self.skip_input:
    grad_weight = [tuple(w for w in layer_grad_weight if w is not None)
                   for layer_grad_weight in grad_weight]
torch/nn/_functions/rnn.py (outdated)
hx, cx = hidden
gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
x_h = input.unsqueeze(1).expand(input.size(0), 4, input.size(1)) if w_ih is None else F.linear(input, w_ih, b_ih)
torch/nn/_functions/rnn.py (outdated)
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
x_h = input.unsqueeze(1).expand(input.size(0), 4, input.size(1)) if w_ih is None else F.linear(input, w_ih, b_ih)
gates = x_h + F.linear(hx, w_hh, b_hh)
ingate, forgetgate, cellgate, outgate = [x.squeeze(1) for x in gates.chunk(4, 1)]
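To make the cell-level change above concrete, here is a minimal standalone sketch (my own simplification, not the PR's exact code: `lstm_cell_skip` is a hypothetical name, and it uses `repeat` instead of the unsqueeze/expand/squeeze dance in the diff). When `w_ih` is None the raw input is fed to all four gates directly, which is why skip_input requires input_size == hidden_size.

```python
import torch
import torch.nn.functional as F

def lstm_cell_skip(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    # One LSTM step; if w_ih is None the input projection is skipped and the
    # raw input (already hidden_size wide) feeds every gate.
    hx, cx = hidden
    if w_ih is None:
        x_h = input.repeat(1, 4)            # (N, 4H): same input for all four gates
    else:
        x_h = F.linear(input, w_ih, b_ih)   # (N, 4H): usual input projection
    gates = x_h + F.linear(hx, w_hh, b_hh)
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)
    return hy, cy
```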
torch/nn/_functions/rnn.py (outdated)
gi = input.unsqueeze(1).expand(input.size(0), 3, input.size(1)) if w_ih is None else F.linear(input, w_ih, b_ih)
gh = F.linear(hidden, w_hh, b_hh)
i_r, i_i, i_n = gi.chunk(3, 1)
i_r, i_i, i_n = [x.squeeze(1) for x in gi.chunk(3, 1)]
# Conflicts:
#   torch/nn/_functions/rnn.py
#   torch/tensor.py
Back on trying to fix this, making some progress! I'm uncertain how to deal with the current issue: when skip_input is set to true, cuDNN v6 still adds the bias on the input layer, which isn't the correct behaviour. @apaszke @ngimel, what do you think is the best solution for this? (Refer here for more info on this issue.)
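For anyone following along, here is a tiny standalone check of what goes wrong, and why zeroing the bias (which the later "patch cuDNN for true skip input behaviour" commit appears to do via the `fill_(0)` call shown in a diff below) fixes it. This is my own single-gate simplification, not code from the PR: with skip_input, the first layer's gate pre-activation should be `x + W_hh·h + b_hh`, but cuDNN v6 still adds `b_ih`.

```python
import torch

torch.manual_seed(0)
N, H = 2, 4                        # batch size, hidden size (skip_input needs input_size == hidden_size)
x = torch.randn(N, H)              # first-layer input, fed straight into the gates
h = torch.randn(N, H)
w_hh = torch.randn(H, H)
b_hh = torch.randn(H)
b_ih = torch.randn(H)              # the input bias cuDNN v6 still adds in skip-input mode

intended = x + h @ w_hh.t() + b_hh            # what skip_input should compute
cudnn_v6 = x + b_ih + h @ w_hh.t() + b_hh     # what cuDNN v6 actually computes
b_ih.zero_()                                  # the workaround: zero the first layer's input bias
patched = x + b_ih + h @ w_hh.t() + b_hh

print(torch.allclose(intended, patched))      # True: with b_ih zeroed the extra term vanishes
```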
* Fixes for skip rnn
* Fixes for RNN cells, patch cuDNN for true skip input behaviour
@apaszke, tests are passing but this needs a review! Let me know of any feedback :)
apaszke left a comment:
Looks good for the most part!
grad_output = make_noncontig(grad_output)
grad_hy = make_noncontig(grad_hy)
input_var = make_noncontig(input_val)
input_val = make_noncontig(input_val)
torch/backends/cudnn/rnn.py (outdated)
assert param_from.type() == param_to.type()
param_to.copy_(param_from)
assert not ((param_from is None or param_from.dim() == 0) ^ (param_to is None or param_to.dim() == 0))
if not ((param_from is None or param_from.dim() == 0) and (param_to is None or param_to.dim() == 0)):
torch/backends/cudnn/rnn.py (outdated)
if fn.skip_input:
    params = get_parameters(fn, handle, w)
    for layer_index in range(fn.num_directions):
        params[layer_index][2].fill_(0)
torch/backends/cudnn/rnn.py (outdated)
if fn.skip_input:
    for layer_index in range(fn.num_directions):
        params[layer_index][0] = None
        params[layer_index][2] = None
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    if input.is_cuda:
        igates = F.linear(input, w_ih)
        igates = input.expand(4, input.size(0), input.size(1)).transpose(0, 1) if w_ih is None else F.linear(input,
return state(gi, gh, hidden) if b_ih is None else state(gi, gh, hidden, b_ih, b_hh)

gi = F.linear(input, w_ih, b_ih)
gi = input.expand(3, input.size(0), input.size(1)).transpose(0, 1).contiguous() if w_ih is None else \
torch/nn/_functions/rnn.py (outdated)
i_r, i_i, i_n = gi.chunk(3, 1)
h_r, h_i, h_n = gh.chunk(3, 1)
i_r, i_i, i_n = torch.unbind(gi.view(input.size(0), 3, -1), 1)
h_r, h_i, h_n = torch.unbind(gh.view(input.size(0), 3, -1), 1)
Made some changes as requested, but still have to figure out the …
Hey @apaszke, any thoughts/feedback? :) EDIT: Replaced the unbind calls with chunk.
I've modified the line to use chunk rather than unbind as previously implemented!
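For reference, a quick standalone check (not from the PR) that the chunk-plus-squeeze form gives the same three gate tensors as the earlier unbind version, for the skip-input case where gi has shape (batch, 3, hidden):

```python
import torch

torch.manual_seed(0)
N, H = 2, 5
gi = torch.randn(N, 3, H)   # skip-input case: raw input broadcast over the three GRU gates

# earlier revision: unbind along the gate dimension
i_r_a, i_i_a, i_n_a = torch.unbind(gi.view(N, 3, -1), 1)

# current revision: chunk along the gate dimension, then drop the singleton dim
i_r_b, i_i_b, i_n_b = [x.squeeze(1) for x in gi.chunk(3, 1)]

print(torch.equal(i_r_a, i_r_b), torch.equal(i_i_a, i_i_b), torch.equal(i_n_a, i_n_b))  # True True True
```

In the non-skip case, where gi is (batch, 3 * hidden), each chunk is already 2-D and the squeeze(1) is a no-op, so the same line covers both paths.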
Can I get a status on the PR? (This is blocking some deep speech stuff.) :)
After speaking to peeps, I'm going to close this PR; a correct implementation of skip input will have to wait until cuDNN addresses this. Thanks @justinchiu :)
* Fix placement of block sync with halo loop
* hdiff test
* Cherrypicked the changes from pytorch#71146
Should be ready to go! Added skip_input for RNNs (look here for information). Let me know of any feedback etc!
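Since the PR was eventually closed, skip_input never landed in torch.nn; but as I read the diffs, the flag's effect would have been to drop the first layer's input-hidden weight and bias and feed the input straight into the gates (hence the input_size == hidden_size requirement in the tests). A small stock-PyTorch snippet just to show which parameters those are:

```python
import torch.nn as nn

rnn = nn.LSTM(input_size=16, hidden_size=16, num_layers=2)

# Parameters of a stock LSTM; per this PR, skip_input=True would have dropped
# weight_ih_l0 and bias_ih_l0 (the first layer's input projection) and used
# the input itself as the gate contribution instead.
for name, p in rnn.named_parameters():
    print(name, tuple(p.shape))
```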