[jit] Respect order of Parameters in rnn.py #18198
eellison
left a comment
Sorry, what exactly was the issue previously, and what is the new behavior?
torch/csrc/jit/script/init.cpp
Outdated
ConstantParameterList(std::shared_ptr<Module> module)
    : module_(std::move(module)) {}
ConstantParameterList(std::shared_ptr<Module> module, py::list params)
    : module_(std::move(module)), params_(std::move(params)) {}
Maybe convert and store the params as std::string here? I imagine that's less overhead.
It'd be more, since it would duplicate the strings; the py::list is just a thin wrapper around a Python list with the same reference semantics.
torch/nn/modules/rnn.py
Outdated
return self._flat_weights

def _get_flat_weights_names(self):
    all_weights = [[weight for weight in weights] for weights in self._all_weights]
IMO having four weight(s) names on one line is kind of confusing.
torch/nn/modules/rnn.py
Outdated
def _get_flat_weights_names(self):
    all_weights = [[weight for weight in weights] for weights in self._all_weights]
    return [p for layerparams in all_weights for p in layerparams]
ditto with the two layerparams / ps
@eellison clarified reasons in PR message
eellison
left a comment
Looks good, a little hard to follow.
@_parameter_list
def get_flat_weights(self):
def _get_flat_weights_names(self):
    return [weight for weights in self._all_weights for weight in weights]
Do you mind rewriting? This is hard to read.
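For readers following the readability discussion, here is a minimal sketch of two equivalent ways to flatten a per-layer list of weight names. It assumes `_all_weights` is a list of lists of parameter-name strings; the sample names and both helper functions are hypothetical and are not taken from the PR.

```python
from itertools import chain

# Hypothetical stand-in for self._all_weights: one list of names per layer.
all_weights = [
    ["weight_ih_l0", "weight_hh_l0", "bias_ih_l0", "bias_hh_l0"],
    ["weight_ih_l1", "weight_hh_l1", "bias_ih_l1", "bias_hh_l1"],
]


def flat_names_comprehension(layers):
    # Nested comprehension with distinct loop names: the outer loop walks
    # the layers, the inner loop walks the names within one layer.
    return [name for layer_names in layers for name in layer_names]


def flat_names_chain(layers):
    # Same result, with the flattening step named by the standard library.
    return list(chain.from_iterable(layers))


flat = flat_names_chain(all_weights)
print(flat)
```

Both helpers preserve layer order and the order of names within each layer; which one reads better is a matter of taste.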
facebook-github-bot
left a comment
@driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Previously, to get a list of parameters, this code just put them in the reverse of the order in which they were defined, which is not always right. This PR allows parameter lists to define the order themselves. To do this, each parameter list needs a corresponding function that provides the names of its parameters.
Pull Request resolved: pytorch#18198
Differential Revision: D14966270
Pulled By: driazati
fbshipit-source-id: 59331aa59408660069785906304b2088c19534b2
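The ordering problem described in the summary can be illustrated with a small sketch. This is not the PR's implementation: `ToyModule`, `flat_weight_names`, `params_reversed`, and `params_by_names` are hypothetical names, and plain strings stand in for tensors so the example runs without PyTorch.

```python
from collections import OrderedDict


class ToyModule:
    """Hypothetical stand-in for a module; strings stand in for tensors."""

    def __init__(self):
        self._parameters = OrderedDict()
        # Definition order chosen so that simply reversing it is wrong.
        for name in ["bias_hh", "weight_hh", "bias_ih", "weight_ih"]:
            self._parameters[name] = "<" + name + ">"

    def flat_weight_names(self):
        # Companion function: the order the parameter list should have.
        return ["weight_ih", "weight_hh", "bias_ih", "bias_hh"]


def params_reversed(module):
    # Old behavior as the summary describes it: reverse definition order.
    return list(module._parameters.values())[::-1]


def params_by_names(module):
    # New behavior as the summary describes it: the names function
    # dictates the order, independent of definition order.
    return [module._parameters[name] for name in module.flat_weight_names()]


module = ToyModule()
print(params_reversed(module))
print(params_by_names(module))
```

In this toy case the reversed definition order interleaves weights and biases, while the name-driven lookup returns them in the order the names function declares.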