
@masahi masahi commented Mar 31, 2023

This PR adds pattern-based rewriting for dataflow blocks. Compared to the existing call-node-based matching and rewriting (introduced in #14312), which requires a common post-dominator in the pattern, it lets us match a tree structure and replace leaf nodes or branches with new expressions.

This can be used right away to combine any number of matmuls that share the same LHS into a single matmul. Relay has a dedicated pass for that purpose (CombineParallelDense), but we can achieve the same thing with graph (tree) matching and rewriting.

For example, SD UNet has many groups of three parallel matmuls for the QKV projections. There are also far less obvious groups of parallel matmuls, consisting of 32 or 22 of them. All of those patterns can be matched and rewritten with the following generic pattern and rewriter:
https://github.com/masahi/web-stable-diffusion/blob/unet-opt/test.py#L37-L68

I got matmul combining working for all of SD UNet. It reduces the number of R.matmul calls from 200 to 116.
Original: https://gist.github.com/masahi/0dab4b8f53115da9c33f4352c9175a87
Rewritten: https://gist.github.com/masahi/57ab925a43d2343cbfdb79a31c5b9946 (look for R.concat followed by R.matmul to see where rewriting happened)

This is exactly the same reduction that Relay passes achieve on an equivalent Relay module.
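
For readers who want to check why the rewrite is valid, here is a minimal sketch of the underlying identity: several matmuls sharing the same LHS are equivalent to one matmul against the column-wise concatenation of their weights, followed by a split. It is not the pattern or rewriter from this PR. Plain Python lists stand in for tensors so the example has no dependencies, and the helper names (`matmul`, `concat_columns`, `split_columns`) are hypothetical.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def concat_columns(mats):
    """Concatenate matrices with equal row counts along the column axis."""
    return [[v for row in rows for v in row] for rows in zip(*mats)]


def split_columns(mat, widths):
    """Split a matrix column-wise into pieces of the given widths."""
    pieces, start = [], 0
    for w in widths:
        pieces.append([row[start:start + w] for row in mat])
        start += w
    return pieces


x = [[1, 2, 3], [4, 5, 6]]        # shared LHS, shape (2, 3)
w_q = [[1, 0], [0, 1], [1, 1]]    # shape (3, 2)
w_k = [[2, 1], [1, 2], [0, 1]]    # shape (3, 2)
w_v = [[1], [2], [3]]             # shape (3, 1)

# Original form: three parallel matmuls with the same LHS.
separate = [matmul(x, w) for w in (w_q, w_k, w_v)]

# Rewritten form: concat the weights, do one matmul, split the result.
w_qkv = concat_columns([w_q, w_k, w_v])                  # shape (3, 5)
combined = split_columns(matmul(x, w_qkv), [2, 2, 1])

assert combined == separate
```

The same argument holds for any number of weights, which is why groups of 3, 22, or 32 parallel matmuls can all be handled by one generic pattern.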

@ganler @sunggg @psrivas2 @cyx-6 @vinx13


tvm-bot commented Mar 31, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@masahi masahi changed the title from "[Unity] Pattern-based rewriting for dataflow block." to "[Unity] Pattern-based rewriting for dataflow block" Mar 31, 2023

from .pattern import *
from .context import *
from .rewrite import rewrite_call, rewrite_bindings

The existing rewrite function is renamed to rewrite_call to make it clear that it rewrites CallNodes. Together with the new dataflow block rewriting function, rewrite_bindings, it now lives in the new rewrite file.

# the matmul1 pattern. For example, lv0 in lv0 = R.matmul(x1, w0).
# We want to replace the RHS of this binding with Q.
return {matchings[matmul1]: Q, matchings[matmul2]: K, matchings[matmul3]: V}
```
@masahi masahi Mar 31, 2023


I hope the API for the rewriter makes sense and its usage is intuitive. It took me a while to work out this interface, together with how the rewriting mutator should be implemented in dataflow_matcher.cc.

cc @ganler

: ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {}

template <typename PatternType>
static Expr Run(PatternType pat, PackedFunc rewriter_func, Function f) {

Depending on the passed pattern type (DFPattern or PatternContext), it does either call node rewriting or dataflow block rewriting. It never does both in a single pass (obvious from the constructors).

}

// Repeat until all matchable subsets of bindings are rewritten.
BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) {

We need to apply rewriting repeatedly since, for example, the same QKV projection pattern appears a number of times in a single DFB.
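
To illustrate the idea only, here is a minimal sketch of a fixed-point rewrite loop on a toy "block" represented as a list of op names. It does not mirror the C++ implementation in dataflow_matcher.cc. `rewrite_once` and `rewrite_fixed_point` are hypothetical names, and the assumption is that a single rewrite step handles one match and reports when nothing matches.

```python
from typing import Callable, List, Optional

Block = List[str]


def rewrite_once(block: Block) -> Optional[Block]:
    """Rewrite the first Q/K/V triple found, or return None if nothing matches."""
    for i in range(len(block) - 2):
        if block[i:i + 3] == ["matmul_q", "matmul_k", "matmul_v"]:
            return block[:i] + ["concat", "matmul_qkv", "split"] + block[i + 3:]
    return None


def rewrite_fixed_point(block: Block, step: Callable[[Block], Optional[Block]]) -> Block:
    """Apply `step` repeatedly until it no longer finds a match."""
    while True:
        rewritten = step(block)
        if rewritten is None:
            return block
        block = rewritten


# The same pattern appears twice in one block, so one step is not enough.
block = ["matmul_q", "matmul_k", "matmul_v", "softmax",
         "matmul_q", "matmul_k", "matmul_v"]
after_one_step = rewrite_once(block)
result = rewrite_fixed_point(block, rewrite_once)

assert after_one_step.count("matmul_q") == 1   # second occurrence still present
assert "matmul_q" not in result                # all occurrences rewritten
```

The loop terminates here because each step removes a matched triple and the replacement cannot match again.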

var_bind && unemitted_vars.count(var_bind->var.get())) {
// var_bind->value may also depend on other unemitted vars in this range
Array<Binding> prev_bindings(pending_bindings.begin(), pending_bindings.begin() + i);
EmitUsedVars(var_bind->value, prev_bindings, emitted_vars);

I want to get rid of this recursive call and make sure we traverse pending_bindings only once. The issue is that when PostOrderVisit encounters a bound variable, it does not descend into the expression bound to that variable. For example, given the contrived input bindings below,

with R.dataflow():
    w0_t = R.permute_dims(w0, axes=None)
    lv0 = R.matmul(x1, w0_t)
    w1_t = R.permute_dims(w1, axes=None)
    w1_t_t = R.permute_dims(w1_t, axes=None)
    lv1 = R.matmul(x1, w1_t_t)
    w2_t = R.permute_dims(w2, axes=None)
    lv2 = R.matmul(x1, w2_t)

we need to emit all permute_dims bindings before emitting concat and the combined matmul, since concat depends on all of the weights, some of which are defined after the first matmul. When PostOrderVisit is applied to R.matmul(x1, w1_t_t), w1_t is not visited. So even though we need to emit w1_t before w1_t_t, w1_t is not put into the initial unemitted_vars set.

I think we can use AnalyzeVar2Value on the input function to get the bindings, and recursively traverse the bound expression whenever we encounter a new unemitted var. But that seems a bit complicated for a simple job like this, so I'm looking for a simpler solution. For now I'm keeping this recursive solution, which is not efficient but extremely simple.
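
A rough sketch of the var-to-value alternative described above, in plain Python rather than the C++ of this PR. The dict `var2value` is a hypothetical stand-in for whatever AnalyzeVar2Value would return, each bound value is reduced to an op name plus argument names, and `emit_with_dependencies` is an illustrative helper, not an existing function.

```python
from typing import Dict, List, Set, Tuple

# var name -> (op name, argument var names); a toy stand-in for a var-to-value map.
Var2Value = Dict[str, Tuple[str, List[str]]]


def emit_with_dependencies(var: str, var2value: Var2Value,
                           visited: Set[str], out: List[str]) -> None:
    """Emit `var` after everything it depends on, visiting each var at most once."""
    if var in visited or var not in var2value:
        return  # already handled, or a function parameter with no binding
    visited.add(var)
    _, args = var2value[var]
    for arg in args:
        emit_with_dependencies(arg, var2value, visited, out)
    out.append(var)


# The contrived bindings from the example above.
var2value: Var2Value = {
    "w0_t": ("permute_dims", ["w0"]),
    "lv0": ("matmul", ["x1", "w0_t"]),
    "w1_t": ("permute_dims", ["w1"]),
    "w1_t_t": ("permute_dims", ["w1_t"]),
    "lv1": ("matmul", ["x1", "w1_t_t"]),
    "w2_t": ("permute_dims", ["w2"]),
    "lv2": ("matmul", ["x1", "w2_t"]),
}

# Emit everything the concat of the three weights needs, in dependency order.
emitted: List[str] = []
visited: Set[str] = set()
for weight in ["w0_t", "w1_t_t", "w2_t"]:
    emit_with_dependencies(weight, var2value, visited, emitted)

assert emitted == ["w0_t", "w1_t", "w1_t_t", "w2_t"]
```

Because the traversal follows each variable to its bound value, w1_t is picked up before w1_t_t even though the matmul only refers to w1_t_t, and the `visited` set keeps every binding from being processed more than once.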
