Pruning Functionality #24076
Conversation
8ca0a3a → a6cf1f4
AFAIK @soumith might be the best one to do an initial review of this PR; please feel free to add others as well. Thanks!
I've allocated time this Friday to review this.
@soumith any feedback?
@mickypaganini I am more than half-way through review. The diff is big, so it took time. Will finish as soon as possible.
soumith left a comment
This looks like a great first iteration of a PR. Thanks for working on it.
Docs need to be auto-generated by creating an appropriate rst file in docs/source and adding autoclass / automethod entries.
I made several comments inline.
Please put changes to the PR as new commits (don't force-push). As guidance, you'd want a future large PR to be several small commits, rather than a single giant commit, as that makes things significantly easier to review.
test/test_nn.py (Outdated)
why did you put them in brackets?
reverted (it was a legacy change from when I imported everything from prune here)
test/test_nn.py (Outdated)
While I'm here: I think the whole scheme of conv.weight holding a pruned weight, conv.weight_mask holding a mask, and conv.weight_orig holding the original weight might not work well for the optimizer scheme.
I'll dig into it as I review the PR further.
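For readers following the thread, here is a minimal sketch of the module/name scheme under discussion, written against the entry points of the merged torch.nn.utils.prune module (names in the revision being reviewed may differ slightly):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(3, 8, kernel_size=3)

# Prune 30% of the connections in conv.weight at random.
prune.random_unstructured(conv, name="weight", amount=0.3)

# The original tensor moves to `weight_orig`, the binary mask is stored in
# `weight_mask`, and `weight` is recomputed as weight_orig * weight_mask
# by a forward pre-hook on every forward pass.
print(hasattr(conv, "weight_orig"))     # True
print(hasattr(conv, "weight_mask"))     # True
print(conv.weight is conv.weight_orig)  # False: `weight` is now a plain attribute
```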
test/test_nn.py (Outdated)
Yea, this is doing the weight_norm trick.
This is going to break in silent and bad ways unless users prune the weights of the network before initializing the optimizers. Actually, even then it won't work, as the optimizer will report the parameters() of a Conv2d to still be .weight and .bias, which won't be leaf nodes (and that's not valid).
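A hedged illustration of the ordering concern raised here, based on how the merged implementation behaves (the surviving leaf parameter is `weight_orig`, while `weight_mask` is registered as a buffer): constructing the optimizer after pruning keeps it pointed at the actual leaf.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Linear(10, 5)

# Prune first, then build the optimizer, so the optimizer holds the surviving
# leaf parameter (`weight_orig`); `weight` itself is recomputed from it on
# every forward pass and is no longer a leaf tensor.
prune.l1_unstructured(model, name="weight", amount=0.5)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Sanity check under these assumptions: the mask is a buffer, not a parameter.
print([n for n, _ in model.named_parameters()])      # ['bias', 'weight_orig'] (order may vary)
print("weight_mask" in dict(model.named_buffers()))  # True
```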
test/test_nn.py (Outdated)
This works because you explicitly call m.parameters() after masking has happened (I guess?).
test/test_nn.py (Outdated)
*optimizers
torch/nn/utils/prune.py (Outdated)
numpy hard-dependency is a no-go in pytorch core
torch/nn/utils/prune.py (Outdated)
np comment as above
torch/nn/utils/prune.py (Outdated)
np comment as above
torch/nn/utils/prune.py (Outdated)
numpy as a hard-dependency is a no-go in pytorch core
torch/nn/utils/prune.py (Outdated)
While all of these methods are exposed to take in a module and a name, what if we want to use them functionally, directly on a weight Tensor?
For example:
weight = torch.randn(20, 30)
weight_pruned = prune.random_structured(weight, amount=0.4, axis=1)
as of now, these three lines will do it:
weight = torch.randn(20, 30)
rsp = prune.RandomStructuredPruningMethod(amount=0.4, dim=1)
pruned_weight = weight * rsp.compute_mask(weight, default_mask=torch.ones_like(weight))
Honestly, it doesn't look great.
Can we check if Module or Tensor and pack the Tensor if needed?
I asked the question because I suspect people will want to do it more often than you think.
hmmk, so what's your suggested api here? what does "pack the Tensor" mean?
First, I think prune.RandomStructuredPruningMethod is so much more verbose than prune.random_structured. Can we have them align better? For example, the class could be called prune.RandomStructured.
Second, I was wondering about something like pruned_weight = rsp.prune(weight) that implicitly does weight * rsp.compute_mask(weight, default_mask=torch.ones_like(weight)).
I've changed them all to align with the function names and added a .prune method to the base pruning class. I added a quick test too.
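For reference, a short sketch of the aligned naming and the new .prune helper discussed above, assuming the class and function names as they landed in torch.nn.utils.prune:

```python
import torch
import torch.nn.utils.prune as prune

weight = torch.randn(20, 30)

# The class name now mirrors the functional entry point
# (prune.random_structured <-> prune.RandomStructured), and the base pruning
# class gained a .prune() helper that applies compute_mask() for you.
method = prune.RandomStructured(amount=0.4, dim=1)
pruned_weight = method.prune(weight)  # ~ weight * method.compute_mask(weight, torch.ones_like(weight))

print(pruned_weight.shape)  # torch.Size([20, 30]), with roughly 40% of the columns zeroed out
```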
f260785 → 222e4f1
3ae7fa9 → 86be625
soumith left a comment
go go go!
facebook-github-bot left a comment
@mickypaganini is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
* `axis` to `dim`
* better support negative indexing
* remove `numpy` dependence in favor of `numbers` (see the sketch below)
* add serialization tests
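A hypothetical sketch of the numpy-free validation that the `numbers` bullet above refers to; the helper name and error messages here are assumptions, not necessarily what the PR uses:

```python
import numbers

def _validate_pruning_amount(amount):
    # Hypothetical check: `amount` must be a non-negative int (absolute number
    # of units to prune) or a float in [0, 1] (fraction of units to prune),
    # expressed via the stdlib `numbers` ABCs instead of numpy types.
    if not isinstance(amount, numbers.Real):
        raise TypeError("amount must be a number, got {}".format(type(amount)))
    if isinstance(amount, numbers.Integral):
        if amount < 0:
            raise ValueError("an integral amount must be non-negative, got {}".format(amount))
    elif not 0.0 <= amount <= 1.0:
        raise ValueError("a float amount must be in [0, 1], got {}".format(amount))
```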
d52bb29 → f060f98
facebook-github-bot left a comment
@mickypaganini is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@mickypaganini merged this pull request in 8e8a5e0.
Provides an implementation for feature request #20402.
Adds pruning functionality (structured and unstructured, local and global, as well as pruning from a user-provided mask); a usage sketch follows below.
Associated tutorial here: pytorch/tutorials#605
cc: @soumith
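To make the scope above concrete, a hedged end-to-end sketch touching each flavor of pruning the PR describes, using entry point names from the merged torch.nn.utils.prune module:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Sequential(nn.Linear(16, 8), nn.Linear(8, 4))

# Local unstructured: zero 30% of the first layer's weights by L1 magnitude.
prune.l1_unstructured(model[0], name="weight", amount=0.3)

# Local structured: drop 50% of the output channels of the second layer (L2 norm).
prune.ln_structured(model[1], name="weight", amount=0.5, n=2, dim=0)

# Global unstructured: prune 20% of all remaining weights across both layers at once.
prune.global_unstructured(
    [(model[0], "weight"), (model[1], "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

# Pruning from a user-provided mask.
mask = torch.ones(4)
mask[::2] = 0
prune.custom_from_mask(model[1], name="bias", mask=mask)
```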