-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Add Python/C++ torch.nn API parity test harness #23852
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
test/common_nn.py
Outdated
| dict( | ||
| module_name='Linear', | ||
| constructor_args=(10, 8), | ||
| cpp_constructor_args='(10, 8)', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't need to do this, but we should really turn this into a namedtuple at some point
zou3519
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not a full review yet, but some initial comments/questions.
zou3519
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approach looks great, some comments
test/test_cpp_api_parity.py
Outdated
| cpp_args = [] | ||
| for module_file_name in module_file_names: | ||
| cpp_args.append(module_file_name) | ||
| for arg in args[1:]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How dynamic args is makes me a little uncomfortable, but... I guess it is okay
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah it's a bit unfortunate because the C++ test function has to take an unpacked list of input arguments.. (I briefly thought about accepting std::vector<IValue> in all torch::nn modules' forward, but it seems to need a bigger use case than just making the test nicer :/)
496826c to
7e833f3
Compare
|
@pytorchbot rebase this please |
be6c62f to
d03cad2
Compare
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yf225 is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This PR adds test harness for checking Python / C++ API parity for
torch.nn.Modulesubclasses. Under the hood, we use JIT tracing to transfernn.Modulestate from Python to C++, so that we can test initialization / forward / backward on Python / C++ modules with the same parameters and buffers.