-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Unity] Pattern-based rewriting for dataflow block #14446
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
|
|
||
| from .pattern import * | ||
| from .context import * | ||
| from .rewrite import rewrite_call, rewrite_bindings |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The existing rewrite function is renamed to rewrite_call to make it clear that it is CallNode rewriting. And together with the new dataflow block rewriting function, it is put under the new file.
| # the matmul1 pattern. For example, lv0 in lv0 = R.matmul(x1, w0). | ||
| # We want to replace the RHS of this binding with Q. | ||
| return {matchings[matmul1]: Q, matchings[matmul2]: K, matchings[matmul3]: V} | ||
| ``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I hope the API for the rewriter makes sense and the usage is intuitive. It took me a while to workout this interface together with how the rewriting mutator should be implemented in dataflow_matcher.cc.
cc @ganler
| : ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {} | ||
|
|
||
| template <typename PatternType> | ||
| static Expr Run(PatternType pat, PackedFunc rewriter_func, Function f) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Depending on the passed pattern type (DFPattern or PatternContext), it does either call node rewriting or dataflow block rewriting. It never does both in a single pass (obvious from the constructors).
| } | ||
|
|
||
| // Repeat until all matchable subsets of bindings are rewritten. | ||
| BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to apply rewriting repeatedly since, for example, the same QKV projection pattern appears a number of times in a single DFB.
| var_bind && unemitted_vars.count(var_bind->var.get())) { | ||
| // var_bind->value may also depend on other unemitted vars in this range | ||
| Array<Binding> prev_bindings(pending_bindings.begin(), pending_bindings.begin() + i); | ||
| EmitUsedVars(var_bind->value, prev_bindings, emitted_vars); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I want to get rid of this recursive call and make sure we traverse pending_bindings only once. The issue is that PostOrderVisit does not look into subexpressions when it encounters the corresponding bound variable. For example, given the contrived input bindings below,
with R.dataflow():
w0_t = R.permute_dims(w0, axes=None)
lv0 = R.matmul(x1, w0_t)
w1_t = R.permute_dims(w1, axes=None)
w1_t_t = R.permute_dims(w1_t, axes=None)
lv1 = R.matmul(x1, w1_t_t)
w2_t = R.permute_dims(w2, axes=None)
lv2 = R.matmul(x1, w2_t)
we need to emit all permute_dims binding before emitting concat and the combined matmul, since concat depends on all weights some of which are defined after the first matmul. When PostOrderVisit is applied on R.matmul(x1, w1_t_t), w1_t is not visited. So even though we need to emit w1_t before w1_t_t, w1_t is not put into the initial unemitted_vars set.
I think we can use AnalyzeVar2Value on the input function to get bindings, and recursively traverse the bound expression when we encounter a new unemitted var. But I find that a bit complicated for a simple job like this, so I'm looking for a simpler solution. For now I'm keeping this recursive solution that is not efficient but extremely simple.
Compared to the existing call node based matching & rewriting that requires a common post-dominator in a pattern (introduced in #14312), it lets us match a tree structure and replace leaf nodes or branches with new expression.
This can be immediately used for combining any number of multiple matmuls sharing the same LHS into one matmul. In Relay, we have a dedicated pass for that purpose (
CombineParallelDense), but we can achieve the same thing via graph (tree) matching and rewrite.For example, in SD UNet we have many three parallel matmul for QKV projections. In addition, there are also highly non-obvious parallel matmuls consisting of 32 or 22 of them. Those patterns can all be matched and rewritten via the following generic pattern and rewriter.
https://github.com/masahi/web-stable-diffusion/blob/unet-opt/test.py#L37-L68
I got all matmul combining for SD UNet working. It reduces the number of
R.matmulfrom 200 to 116.Original: https://gist.github.com/masahi/0dab4b8f53115da9c33f4352c9175a87
Rewritten: https://gist.github.com/masahi/57ab925a43d2343cbfdb79a31c5b9946 (look for
R.concatfollowed byR.matmulto see where rewriting happened)This is exactly the same reduction possible using Relay passes on an equivalent Relay mod.
@ganler @sunggg @psrivas2 @cyx-6 @vinx13