Skip to content

Conversation

@moskomule
Copy link
Contributor

Hi, I thought nn.init is great but initializing variables in layers using nn.init is a bit complicated. So I try to implement easier way of initialization like below.

m = nn.Linear(20, 30, initializer=lambda x: init.xavier_normal(x, 1))

I do not think this is the best way and thus implemented only for Linear. If you think this is useful, give me advice and help me. Thank you.

@Kaixhin
Copy link
Contributor

Kaixhin commented Jun 17, 2017

The method you suggested imposes the same initialisation on weights and biases, and more generally, might be ambiguous with parameters in other layers (like the mean of batch normalisation).

I would suggest chaining, along the lines of the original nninit.

@apaszke
Copy link
Contributor

apaszke commented Jun 17, 2017

I agree that this is not an ideal solution, as it's quite ambiguous. Chaining is a bit annoying in Python, because it has to happen inside parenthesis (otherwise it's a syntax error because the whitespace has meaning).

@Kaixhin
Copy link
Contributor

Kaixhin commented Jun 17, 2017

@apaszke what do you mean exactly? I was thinking along the lines of:

m = nn.Linear(20, 30).init('weight', nn.init.xavier_normal, 1)

# Linear.init pseudocode:
def init(self, accessor, method, **kwargs):
    method(self[accessor], **kwargs)
    return self

First argument accesses the field, second specifies the method, rest are passed into the method. I still haven't used Python that much, but is there a problem with the above?

@apaszke
Copy link
Contributor

apaszke commented Jun 17, 2017

the above is fine, however this is a syntax error:

nn.Linear(20, 30).init('weight', ...)
                 .init('bias', ...)

You need extra parenthesis to fix it:

m = (nn.Linear(20, 30).init('weight', ...)
                      .init('bias', ...))

@apaszke
Copy link
Contributor

apaszke commented Jun 17, 2017

But I don't have any better ideas, so I guess we could add that

@Kaixhin
Copy link
Contributor

Kaixhin commented Jun 17, 2017

Ah I see what you mean now. Yeah that is a bit annoying, but having longer lines or introducing an extra parenthesis isn't the end of the world. I'm obviously biased, but I think chaining is a neat way to approach this. Sure was an improvement on my first API for nninit.

@moskomule
Copy link
Contributor Author

Thank you for your ideas. Yes, as @Kaixhin mentions my original idea did not consider bias well.
And @apaszke 's

m = (nn.Linear(20, 30).init('weight', ...).init('bias', ...))

is nice, but I don't think this is Python's way. What do you think of using dictionary instead, for example

m = nn.Linear(20, 30, initializer={'weight': ..., 'bias': ...})

?

@Kaixhin
Copy link
Contributor

Kaixhin commented Jun 18, 2017

Yep that seems more Pythonic but has the right level of flexibility 👍

Now this can be added into nn.Module directly so it can be used with convolutions, RNNs, everything. Loop over the keys and throw an error if the module doesn't contain it.

@moskomule
Copy link
Contributor Author

Which do you think is better, adding reset_paramters to nn.Module directly or modifying reset_parameters like I pushed?

@moskomule moskomule changed the title added simple initializer to Linear module added simple initializer to some modules Jun 19, 2017
def __init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding, groups, bias):
padding, dilation, transposed, output_padding, groups, bias,
initializer=dict()):

This comment was marked as off-topic.

This comment was marked as off-topic.

if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)

if isinstance(self.initializer, function):

This comment was marked as off-topic.

else:
bias_initializer(self.bias)
weight_initializer(self.weight)
if self.bias is None:

This comment was marked as off-topic.

@moskomule
Copy link
Contributor Author

Now these layers seem not be able to pickle but I don't know how to fix...

@fmassa
Copy link
Member

fmassa commented Jun 27, 2017

@moskomule I think you can't pickle lambda functions. Do you need to add initializers to self?

@moskomule
Copy link
Contributor Author

While Python 3.6's test has passed, others haven't. Do you have any idea to solve it?

@Kaixhin
Copy link
Contributor

Kaixhin commented Jul 1, 2017

@moskomule that's not you, sorry - looks like a test broke on some other versions. Pasting stack trace here temporarily so others don't need to dive in:

======================================================================
FAIL: test_AdaptiveMaxPool2d_tuple (__main__.TestNN)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test_nn.py", line 2852, in <lambda>
    setattr(TestNN, test_name, lambda self, test=test: test(self))
  File "/home/travis/build/pytorch/pytorch/test/common_nn.py", line 571, in __call__
    self._do_test(test_case, module, input)
  File "test_nn.py", line 82, in _do_test
    test_case.check_jacobian(module, input, self.jacobian_input)
  File "/home/travis/build/pytorch/pytorch/test/common_nn.py", line 459, in check_jacobian
    PRECISION
AssertionError: 0.10383865983731866 not less than or equal to 1e-05

def __init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, transposed, output_padding, groups, bias):
padding, dilation, transposed, output_padding, groups, bias,
initializer):

This comment was marked as off-topic.

This comment was marked as off-topic.

dilation (int or tuple, optional): Spacing between kernel elements
groups (int, optional): Number of blocked connections from input channels to output channels
bias (bool, optional): If True, adds a learnable bias to the output
initializer(dict, optional): dictionary of initializer of weights and bias, if None (default),

This comment was marked as off-topic.

This comment was marked as off-topic.

@moskomule
Copy link
Contributor Author

Is this PR still alive, I mean, do I need to do something?

@github-actions
Copy link
Contributor

github-actions bot commented Mar 1, 2022

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
Stale pull requests will automatically be closed 30 days after being marked Stale

@github-actions github-actions bot added the Stale label Mar 1, 2022
@pytorchbot pytorchbot removed the Stale label Apr 12, 2022
@github-actions
Copy link
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jun 11, 2022
@github-actions github-actions bot closed this Jul 11, 2022
jjsjann123 pushed a commit to jjsjann123/pytorch that referenced this pull request Jul 15, 2022
jjsjann123 added a commit that referenced this pull request Jul 21, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 21, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

ghstack-source-id: f24793f
Pull Request resolved: #81861
jjsjann123 added a commit that referenced this pull request Jul 21, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 21, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

ghstack-source-id: cfd5278
Pull Request resolved: #81861
jjsjann123 added a commit that referenced this pull request Jul 23, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 23, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

ghstack-source-id: 93c6b1e
Pull Request resolved: #81861
jjsjann123 added a commit that referenced this pull request Jul 26, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 26, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 27, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 27, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 27, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 27, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 27, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

ghstack-source-id: a74f653
Pull Request resolved: #81861
pytorchmergebot pushed a commit that referenced this pull request Jul 28, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)
Pull Request resolved: #81861
Approved by: https://github.com/davidberard98
facebook-github-bot pushed a commit that referenced this pull request Jul 28, 2022
Summary:
Pull Request resolved: #81861

Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Test Plan: Imported from OSS

Reviewed By: samdow

Differential Revision: D38043938

Pulled By: davidberard98

fbshipit-source-id: b94245f83dab6faee31e0c154d3b969bddeb3d47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants