
Conversation

@mickypaganini (Contributor) commented Aug 9, 2019

Provides an implementation for feature request issue #20402.

Adds pruning functionality (structured and unstructured, local and global, as well as pruning from a user-provided mask).
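For context, a rough usage sketch of these entry points (the function names below follow the associated tutorial; exact signatures here are illustrative, not authoritative):

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(1, 8, 3)
fc = nn.Linear(8, 10)

# local unstructured pruning: zero out 30% of the conv weights, chosen by lowest L1 magnitude
prune.l1_unstructured(conv, name="weight", amount=0.3)

# local structured pruning: remove 50% of the conv's output channels, ranked by L2 norm along dim 0
prune.ln_structured(conv, name="weight", amount=0.5, n=2, dim=0)

# global unstructured pruning: prune 20% of all listed weights, pooled together
prune.global_unstructured(
    [(conv, "weight"), (fc, "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

# pruning from a user-provided mask
prune.custom_from_mask(fc, name="bias", mask=torch.ones(10))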

Associated tutorial here: pytorch/tutorials#605

cc: @soumith

@pytorchbot added the module: nn (Related to torch.nn) label on Aug 9, 2019
@mickypaganini force-pushed the pruning-stage branch 2 times, most recently from 8ca0a3a to a6cf1f4 on August 15, 2019 02:44
@ailzhang requested a review from soumith on August 15, 2019 07:17
@ailzhang (Contributor)

AFAIK @soumith might be the best one to do an initial review of this PR; please feel free to add others as well. Thanks!

@soumith (Contributor) commented Aug 20, 2019

I've allocated time this Friday to review this

@mickypaganini (Contributor, Author)

@soumith any feedback?

@soumith (Contributor) commented Sep 7, 2019

@mickypaganini I am more than halfway through the review; the diff is big, so it took time. Will finish as soon as possible.

@soumith (Contributor) left a comment

This looks like a great first iteration of a PR. Thanks for working on it.

Docs need to be auto-generated by creating an appropriate rst file in docs/source and adding autoclass / automethod entries.

Made several comments inline.

Please put changes to the PR as new commits (don't force-push). As guidance, you'd want a future large PR to be several small commits rather than a single giant commit, as that makes things significantly easier to review.

test/test_nn.py Outdated
Contributor

why did you put them in brackets?

Contributor Author

reverted (it was a legacy change from when I imported everything from prune here)

test/test_nn.py Outdated
Contributor

while I'm here, I think the whole scheme of conv.weight holding the pruned weight, conv.weight_mask holding the mask, and conv.weight_orig holding the original weight might not work well with the optimizer.
I'll dig into it as I review the PR further.
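For reference, a minimal sketch of the scheme under discussion, assuming the module/name entry points from this PR (the attribute names below mirror the ones mentioned above):

import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(3, 8, 3)
prune.random_unstructured(conv, name="weight", amount=0.3)

# conv.weight_orig: the original weight, still an nn.Parameter
# conv.weight_mask: the binary mask, registered as a buffer
# conv.weight:      a plain Tensor attribute, recomputed as weight_orig * weight_mask by a forward pre-hook
print(sorted(name for name, _ in conv.named_parameters()))  # ['bias', 'weight_orig']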

test/test_nn.py Outdated
Contributor

yeah, this is doing the weight_norm trick.
This is going to break in silent and bad ways unless users prune the weights of the network before initializing the optimizers. Actually, even then it won't work, as the optimizer will report the parameters() of a Conv2d to still be .weight and .bias, which won't be a leaf node (and that's not valid)

test/test_nn.py Outdated
Contributor

this works because you explicitly call m.parameters() after masking has happened (I guess?)
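A short sketch of the ordering this relies on (hypothetical module m; assuming the module-level API from this PR): the optimizer only sees the post-pruning parameters if it is constructed after pruning.

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

m = nn.Linear(20, 10)
prune.l1_unstructured(m, name="weight", amount=0.5)

# 'weight' is no longer in m.parameters(); 'weight_orig' has taken its place
print(sorted(name for name, _ in m.named_parameters()))  # ['bias', 'weight_orig']

# so the optimizer has to be built after pruning in order to train weight_orig
opt = torch.optim.SGD(m.parameters(), lr=0.1)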

test/test_nn.py Outdated
Contributor

*optimizers

Contributor

numpy hard-dependency is a no-go in pytorch core

Contributor

np comment as above

Contributor

np comment as above

Contributor

numpy as a hard-dependency is a no-go in pytorch core

Contributor

while all of these methods are exposed to take in a module and a parameter name, what if we want to use them functionally, directly on a weight Tensor?

For example:

weight = torch.randn(20, 30)
weight_pruned = prune.random_structured(weight, amount=0.4, axis=1)

Contributor Author

as of now, these three lines will do it:

weight = torch.randn(20, 30)
rsp = prune.RandomStructuredPruningMethod(amount=0.4, dim=1)
pruned_weight = weight * rsp.compute_mask(weight, default_mask=torch.ones_like(weight))

Contributor

honestly, doesn't look great.
Can we check whether the input is a Module or a Tensor, and pack the Tensor if needed?

I asked the question because I suspect people will want to do it more often than you think.

Contributor Author

hmmk, so what's your suggested api here? what does "pack the Tensor" mean?

Contributor

first, I think prune.RandomStructuredPruningMethod is far more verbose than prune.random_structured. Can we have them align better? For example, the class could be called prune.RandomStructured.

Second, I was wondering about something like pruned_weight = rsp.prune(weight) that implicitly does weight * rsp.compute_mask(weight, default_mask=torch.ones_like(weight))

Contributor Author

I've changed them all to align with the function names and added a .prune method to the base pruning class. I added a quick test too.
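For reference, a short sketch of the aligned usage after these renames (class name matching the functional name, plus the new .prune convenience method; treat the exact signatures as illustrative):

import torch
import torch.nn.utils.prune as prune

weight = torch.randn(20, 30)
rs = prune.RandomStructured(amount=0.4, dim=1)
# .prune() computes the mask internally, equivalent to
# weight * rs.compute_mask(weight, default_mask=torch.ones_like(weight))
pruned_weight = rs.prune(weight)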

@soumith (Contributor) left a comment

go go go!

@facebook-github-bot (Contributor) left a comment

@mickypaganini is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@mickypaganini merged this pull request in 8e8a5e0.

Labels: Merged, module: nn (Related to torch.nn)
6 participants