
Conversation

@mruberry
Collaborator

Currently, when _apply() is called on RNNBase (or one of its children, like LSTM), the _flat_weights attribute may or may not be updated. In particular, when using .to() to send a module like LSTM to XLA, a third-party device type, the tensors in _flat_weights are not updated and remain on CPU. This causes the LSTM forward to fail, since the forward call receives a mix of XLA and CPU tensors (see the repro sketch below).
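
For context, a minimal repro sketch of this failure mode; it assumes torch_xla is installed, and the model sizes are arbitrary (not taken from this PR):

```python
# Hedged repro sketch: requires the torch_xla package; sizes are arbitrary.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()                      # an XLA device, e.g. a TPU core
lstm = nn.LSTM(input_size=8, hidden_size=16)
lstm = lstm.to(device)                        # parameters are moved to XLA...
x = torch.randn(5, 3, 8, device=device)
out, _ = lstm(x)                              # ...but before this fix _flat_weights
                                              # still held CPU tensors, so forward()
                                              # saw a mix of XLA and CPU tensors.
```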

This occurs because tensors on third-party device types, like XLA, may not have a shallow-copy type compatible with native tensors. When that is the case and _apply is called, Module parameters are replaced rather than updated in place (see the sketch below). RNNBase did not resync _flat_weights with its parameters in this case, so the references in _flat_weights no longer reflected the module's current parameters.
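
For reference, a simplified sketch of the in-place-vs-replace decision in Module._apply (grad and buffer handling omitted; this paraphrases the behavior and is not the verbatim implementation):

```python
import torch

def _apply_sketch(module, fn):
    # Recurse into children first, as Module._apply does.
    for child in module.children():
        _apply_sketch(child, fn)

    for key, param in module._parameters.items():
        if param is None:
            continue
        with torch.no_grad():
            param_applied = fn(param)
        if torch._has_compatible_shallow_copy_type(param, param_applied):
            # Native tensor types: the existing Parameter is updated in place, so
            # outstanding references (e.g. _flat_weights) see the new data.
            param.data = param_applied
        else:
            # Incompatible types (e.g. XLA): the Parameter object is replaced, so
            # cached references like RNNBase._flat_weights keep pointing at the
            # old CPU tensors unless they are explicitly resynced.
            module._parameters[key] = torch.nn.Parameter(
                param_applied, param.requires_grad)
    return module
```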

This small change forces a resync of _flat_weights with the actual parameters on each _apply, as sketched below. This lets .to('xla') work for LSTMs, for example. A test will be added to PyTorch/XLA (which runs in our CI) to validate this behavior after the change appears in PyTorch.
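
The fix amounts to rebuilding _flat_weights from the module's current parameters every time _apply runs. A minimal sketch of the override (attribute names paraphrased; the real RNNBase also defines flatten_parameters() and the list of flat weight names):

```python
# Sketch of RNNBase._apply after this change (names paraphrased, not the exact diff).
def _apply(self, fn):
    ret = super(RNNBase, self)._apply(fn)
    # _apply may have replaced the Parameter objects outright (e.g. on .to('xla')),
    # so rebuild _flat_weights from the module's current attributes before
    # re-flattening the weights for cuDNN.
    self._flat_weights = [getattr(self, name) for name in self._flat_weights_names]
    self.flatten_parameters()
    return ret
```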

@mruberry mruberry requested a review from apaszke as a code owner October 24, 2019 00:05
@mruberry mruberry requested review from ngimel and removed request for apaszke October 24, 2019 00:50
Collaborator

@ngimel ngimel left a comment


lgtm

Contributor

@facebook-github-bot facebook-github-bot left a comment


@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@mruberry merged this pull request in 0c48092.
