-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Make various improvements to C++ API parity test harness #25828
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
zou3519
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm. Had some minor nits (feel free to ignore them all) and some questions (mostly things I'm curious about)
| else: | ||
| raise RuntimeError("Unexpected input type: {}".format(type(example_inputs))) | ||
|
|
||
| # We set all inputs to torch.nn module to requires grad, so that the backward test can always be run. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(no action items) You could also just turn requires_grad on if dtype is not integral, but I'm not sure that is sufficient and we can worry about that later when parity tests are added for the Embedding layers
| module_qualified_name='torch::nn::{}'.format(module_name), | ||
| module_option=cpp_module_option) | ||
| module_option=cpp_module_option, | ||
| extra_stmts=''.join(extra_stmts)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now for some reading comprehension: what happens if the C++ module has an extra option that the python module doesn't have?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am planning to change the internal data structure of all C++ module / optimizer options to a map, after which we will be able to check for extra C++ module options by removing all the known options from the map and see if there is anything left.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This PR makes the following improvements to C++ API parity test harness:
options_argssince we can get the list of options from the Python module constructor args.intortuplein Python module constructor args toExpandingArrayin C++ module options.(1, {2, 3}, 4)into['1', '{2, 3}', '4']forcpp_default_constructor_args._test_torch_nn_module_ctor_args.We will be able to merge #24160 and #24860 after these improvements.