Skip to content

Conversation

@yf225
Copy link
Contributor

@yf225 yf225 commented Aug 6, 2019

This PR adds test harness for checking Python / C++ API parity for torch.nn.Module subclasses. Under the hood, we use JIT tracing to transfer nn.Module state from Python to C++, so that we can test initialization / forward / backward on Python / C++ modules with the same parameters and buffers.

@pytorchbot pytorchbot added module: nn Related to torch.nn module: tests Issues related to tests (not the torch.testing module) labels Aug 6, 2019
@yf225 yf225 requested a review from zou3519 August 6, 2019 14:29
@yf225 yf225 force-pushed the cpp_parity_test branch from a198908 to 4e8aa17 Compare August 6, 2019 16:38
dict(
module_name='Linear',
constructor_args=(10, 8),
cpp_constructor_args='(10, 8)',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to do this, but we should really turn this into a namedtuple at some point

Copy link
Contributor

@zou3519 zou3519 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a full review yet, but some initial comments/questions.

@yf225 yf225 changed the title Add Python/C++ API parity test harness [WIP] Add Python/C++ API parity test harness Aug 13, 2019
Copy link
Contributor

@zou3519 zou3519 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approach looks great, some comments

cpp_args = []
for module_file_name in module_file_names:
cpp_args.append(module_file_name)
for arg in args[1:]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How dynamic args is makes me a little uncomfortable, but... I guess it is okay

Copy link
Contributor Author

@yf225 yf225 Aug 13, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's a bit unfortunate because the C++ test function has to take an unpacked list of input arguments.. (I briefly thought about accepting std::vector<IValue> in all torch::nn modules' forward, but it seems to need a bigger use case than just making the test nicer :/)

@pytorchbot pytorchbot added caffe2 module: build Build system issues labels Aug 13, 2019
@yf225 yf225 force-pushed the cpp_parity_test branch 3 times, most recently from 496826c to 7e833f3 Compare August 25, 2019 01:20
@yf225
Copy link
Contributor Author

yf225 commented Aug 25, 2019

@pytorchbot rebase this please

@yf225 yf225 force-pushed the cpp_parity_test branch 4 times, most recently from be6c62f to d03cad2 Compare August 26, 2019 02:03
Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yf225 is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@yf225 merged this pull request in 1bf1970.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

caffe2 Merged module: build Build system issues module: nn Related to torch.nn module: tests Issues related to tests (not the torch.testing module)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants