
Conversation

@ezyang (Contributor) commented Jun 15, 2017

This commit series turns on forward trace construction in PyTorch. Forward traces are recorded automatically as you execute your Torch code, and can then be rerun later with different input values to recompute the same operations on new inputs. This is a PROOF OF CONCEPT, because you don't actually want tracing to be turned on all the time.

Some primary characteristics of the implementation:

  • The first four commits are just refactoring commits; I'll submit them separately.
  • The basic IR model is a multigraph, where each Node represents a computation, taking in some number of tensors, and producing some number of tensors. An Output represents a particular output of a function. Unlike the current backwards representation, the pointers in this representation go "the opposite way." Input nodes are identified by pointer equality; every Variable is associated with a unique input node which lets you get a handle into the parameters of the graph.
  • The IR is exposed at the Python level using proxy objects; this means that if you don't access the trace from Python, Python objects are never materialized. Note that PyNode operators themselves contain a pointer to the Python class which implements their forward pass, but this cannot cause cycles to arise.
  • The IR today uses shared pointers. When we add "regions" to trace over, it should become possible to region allocate the IR and then remove the uses of the shared pointers, to make things faster.
  • The execution engine is a simple but inefficient memoized, recursive interpreter. In principle we can parallelize it in the same way today's engine is parallelized, but I have not done this yet. The interpreter basically calls into THPFunction_apply, doing the dumbest possible thing (constructing Python objects to pass in, and then destroying them on the way out). A minimal sketch of this evaluation strategy follows the list.
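
Here is a minimal Python sketch of that memoized, recursive evaluation strategy, for illustration only; the Node class, its (producer, output index) input pairs, apply_fn, and run_trace are hypothetical stand-ins for the actual C++ IR and THPFunction_apply, not the API added in this PR:

  # Hypothetical sketch of a memoized, recursive trace interpreter.
  class Node:
      def __init__(self, apply_fn, inputs):
          self.apply_fn = apply_fn   # callable: tuple of tensors -> tuple of tensors
          self.inputs = inputs       # list of (producer Node, output index) pairs

  def run_trace(outputs, input_values):
      # outputs: list of (Node, output index) to compute
      # input_values: dict mapping each input Node to its tuple of tensors
      memo = dict(input_values)      # each Node's results are computed at most once
      def eval_node(node):
          if node not in memo:
              args = tuple(eval_node(prod)[i] for (prod, i) in node.inputs)
              memo[node] = node.apply_fn(args)
          return memo[node]
      return tuple(eval_node(node)[i] for (node, i) in outputs)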

Where it's going from here:

  • Turn tracing on/off by the user
  • Add frontend support for "dumb" just-in-time compilation of operators, with no correctness checking
  • Implement a peephole rewriter on the resulting trace, allowing us to apply fusion
  • Port backward traces to use the new IR format
  • Bikeshed the names Output/Node
  • Better trace pretty-printing capabilities and a textual language design
  • Figure out how to get speed increases even when operators are implemented in Python

CC @lantiga @apaszke @zdevito

@apaszke (Contributor) left a comment

I wonder if there's really no way to reuse the Engine we have for backward. It seems that all interpret does is something very similar, but it's much simpler (and less efficient).

@soumith changed the title from "Proof of concept: Forward tracing in PyTorch" to "[WIP] Proof of concept: Forward tracing in PyTorch" Jun 22, 2017
ezyang and others added 24 commits July 16, 2017 18:32
Also, add a new trace_fn field to attach forward IR to Variables.

Signed-off-by: Edward Z. Yang <[email protected]>
Simple test:

  import torch
  from torch.autograd import Variable
  import torch._C as _C

  x = Variable(torch.Tensor([4]), requires_grad=True)
  y = Variable(torch.Tensor([7]), requires_grad=True)
  z = x * y
  z.sum().backward()

  print(x.grad)
  print(y.grad)

  x.data[0] = 2
  y.data[0] = 3

  (z,) = z._execution_engine.run_forward((x, y), (z,))
  z.sum().backward()

  print(x.grad)
  print(y.grad)

Signed-off-by: Edward Z. Yang <[email protected]>
Signed-off-by: Edward Z. Yang <[email protected]>
Signed-off-by: Edward Z. Yang <[email protected]>
Previously, our AST was a DAG, where shared Nodes indicated a computation
should be reused.  This commit rewrites the IR into a new functional
representation which represents sharing explicitly using variable
bindings.

We offer a few justifications for this new style:

1. The new representation is not all that different from the
old one; it is about as easy to construct, and the lack of an
explicit graph doesn't negatively impact our ability to interpret
the graph, since we've chosen, as a matter of design, to NOT have
the IR participate in the actual execution of a graph.

2. The new let-binding representation has an implicit ordering,
which we can use to conveniently keep track of the order in which
the trace was originally recorded.  This automatically gives us a
topsort, and an easier-to-read textual representation of our IR:

  %14 = Embedding %11, %0, -1, None, 2, False, False
  %15 = Dropout %14, 0.2, True, False
  %16 = Index %12, 0
  %17 = Index %12, 1
  %18 = Index %13, 0
  %19 = Index %13, 1
  %20 = Index %15, 0
  %21 = Linear %20, %1, %3
  %22 = Linear %16, %2, %4

3. It moves us closer to a Futhark style language
(http://futhark-lang.org/publications/pldi17.pdf).

Major aspects of the diff

- Node is replaced with Expr and Arg, a pair of mutually recursive
  structures which represent our new language.  In BNF, the language
  looks like this:

    a ::= c | %i
    e ::= %i, ... = e
        | PyOp e, ...
        | Ret %i, ...

  Technically, Ret is not actually a return (no control flow is involved);
  it just tuples up a series of tensors (identified by variables).

  One important invariant is that locals are always tensors; they
  are never constants (this is asymmetric with Args).  A hypothetical
  Python rendering of these structures appears after this list.

- Arguments support Python constants.  This is an important piece because
  many operators take extra Python literals like integers and tuples in
  order to specify extra parameters about how an operator operates.  Adding
  this was essential to getting word_language_model to work.

- As both Expr and Arg have multiple variants, there is new infrastructure
  for doing case on the variants using ExprVisitor and ArgVisitor.  The
  strategy here is adapted from WebAssembly's visitors, although we have
  generalized to permit arbitrary argument forwarding, which is necessary
  to support tail-recursive visitor calls.  TCO is important because our
  interpreter may recurse arbitrarily deep into a stack of nested lets.
  If users wish, they can also manually case on the type tag.

- Tracing is now turned on and off using _tracer_enter/_tracer_exit in
  torch._C.  _tracer_enter accepts a list of variables which are to be
  treated as arguments; _tracer_exit accepts the list of traced variables
  which should be returned when you reexecute the trace, and returns
  the trace expression which can be reexecuted.  GlobalTracingState
  is a global variable which tracks whether or not we are currently tracing.

- You use run_forward to execute a trace on some set of parameters.

- While tracing, variables keep track, via trace_local, of the names
  of their corresponding locals in the IR.
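
For concreteness, here is a hypothetical Python rendering of the Expr/Arg
structures and the visitor described above; the real implementation is in
C++, and the names and fields here are illustrative only:

  from dataclasses import dataclass
  from typing import Any, List, Union

  @dataclass
  class Local:            # %i -- a tensor-valued variable
      index: int

  @dataclass
  class Constant:         # c -- a Python literal argument (int, tuple, ...)
      value: Any

  Arg = Union[Local, Constant]

  @dataclass
  class PyOp:             # apply a Python autograd function to some arguments
      fn: Any
      args: List[Arg]

  @dataclass
  class Ret:              # tuple up the traced output tensors
      results: List[Local]

  @dataclass
  class Let:              # %i, ... = rhs; body is the rest of the trace
      targets: List[Local]
      rhs: "Expr"
      body: "Expr"

  Expr = Union[PyOp, Ret, Let]

  class ExprVisitor:      # dispatch on the Expr variant, forwarding extra args
      def visit(self, e, *args):
          return getattr(self, "visit_" + type(e).__name__)(e, *args)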

Here is a simple runner which leaks memory but can be used to JIT models:

  import torch.autograd.function as F
  import torch._C
  from torch.autograd import Variable

  def jit(model):
      import types
      real_forward = model.forward
      def forward(self, *args):
          def flatten(x):
              return tuple(F._iter_variables(x))
          if not hasattr(self, "saved_trace"):
              # First call: run the real forward pass under the tracer and save the trace.
              torch._C._tracer_enter(tuple(self.parameters()) + flatten(args))
              out = real_forward(*args)
              self.saved_trace = torch._C._tracer_exit(flatten(out))
              self.saved_outs = out
              return out
          else:
              # Subsequent calls: replay the saved trace on the new inputs.
              flat_out = Variable._execution_engine.run_forward(self.saved_trace, tuple(self.parameters()) + flatten(args))
              return F._unflatten(flat_out, self.saved_outs)
      # Attach the wrapper so future calls go through the trace.
      model.forward = types.MethodType(forward, model)
      return model
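
A hypothetical usage of this runner (MyModel and inp are placeholders, not
part of this PR): the first call records a trace, later calls replay it:

  model = jit(MyModel())
  out = model(inp)   # traced on the first call
  out = model(inp)   # replayed from the saved trace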

Major problems:

- Sanity checking is spotty at best, especially when users pass in variables.

- The interpreter leaks tensor memory from the store.  When we add back def-use
  we should be able to deallocate tensors as soon as we know they are no longer
  necessary.

- The interpreter needs to reach feature parity with the old execution engine.
  From there, we need to see if backwards can be subsumed as well.

- I still have no confidence that memory is managed correctly everywhere.
  This requires a close look.

- Rather than return an *open* expression as a trace, we should return a
  *lambda* instead, which knows about how many formal parameters it
  requires.

- The IR is not introspectable from Python at the moment, but this is simply a
  matter of implementing all the binding code.

- The tracer is NOT reentrant (you can't trace while you're inside a trace).
  Furthermore, no sanity checking is done if you try to incorrectly reuse
  things from one trace in another.

Signed-off-by: Edward Z. Yang <[email protected]>
Signed-off-by: Edward Z. Yang <[email protected]>
Signed-off-by: Edward Z. Yang <[email protected]>
Signed-off-by: Edward Z. Yang <[email protected]>
Although ANF-style developments traditionally stratify syntactic
classes into atomic (Arg) and complex (Expr) expressions, where
atomic expressions could be variables, constants or lambdas, Zach has
successfully convinced me that we should do away with the variant here and
always require arguments to be variables.  There are a few reasons for
this:

1) Tensor constants, not currently supported, could be modeled using a
"Constant" instruction, removing the need for them to be representable
directly inline.  An inline constant is marginally more convenient
for peephole optimizations, but since we have gone full ANF, we are going
to need to be able to see across def-uses in any case, and it is not
too much worse to need to handle constants this way.  By the way,
Swift Intermediate Language also made a similar choice; see
the slide on "Literal Instructions" in
http://llvm.org/devmtg/2015-10/slides/GroffLattner-SILHighLevelIR.pdf

2) Scalar constants, which are quite important for passing non-tensor
arguments to Python operators, are now stored out-of-band as NON
first-class values.  This more closely matches the ToffeeIR design,
and makes it clear what parameters are "first class" (tensors only)
and which ones are not.  However, we need to be able to unswizzle
the separate scalar/tensor lists into a unified list in the correct
format; this is what PyFunctionCConv is for.

Also, Locals got renamed to Tuple.
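
As a rough illustration of the unswizzling step (the mask-based calling
convention below is an assumption made for the sketch, not the actual
PyFunctionCConv implementation):

  # Hypothetical sketch: rebuild the unified Python argument list by
  # interleaving first-class tensors with out-of-band scalar constants,
  # driven by a per-position "is tensor" mask.
  def unswizzle(is_tensor_mask, tensors, scalars):
      tensors, scalars = iter(tensors), iter(scalars)
      return tuple(next(tensors) if is_tensor else next(scalars)
                   for is_tensor in is_tensor_mask)

  # e.g. unswizzle((True, False, False), (weight,), (-1, None))
  #      == (weight, -1, None)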

Signed-off-by: Edward Z. Yang <[email protected]>
Signed-off-by: Edward Z. Yang <[email protected]>
This prevents nested lets, which are not allowed in ANF.  We
basically have SSA now.

There's some niftiness with the visitor returning a lambda which
then gets fed the actual argument. I like it.

Signed-off-by: Edward Z. Yang <[email protected]>
Signed-off-by: Edward Z. Yang <[email protected]>
It is not an /expression/ we trace, but a /graph/: that is,
a closed expression which knows its parameters.  Knowing the list
of parameters up front is helpful and removes a hack when interpreting.
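
A minimal hypothetical rendering of the resulting shape (illustrative Python
only; the real structure is C++ and the field names are assumptions):

  from dataclasses import dataclass
  from typing import Any, List

  @dataclass
  class Graph:
      params: List[int]   # the %i locals that callers must supply, in order
      body: Any           # the traced expression (a chain of lets ending in Ret)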

Signed-off-by: Edward Z. Yang <[email protected]>
apaszke and others added 9 commits July 16, 2017 19:38
Signed-off-by: Edward Z. Yang <[email protected]>
Now it gets initialized in the constructor.  This results
in more boilerplate but is conceptually more correct, and fixes
an assert failure.

Signed-off-by: Edward Z. Yang <[email protected]>
Add an assert wrapper for easy porting.
@ezyang closed this Jul 21, 2017
@ryanleary

@ezyang are you still working on this?

@ezyang (Contributor, Author) commented Jul 24, 2017

Yep. Dev is still proceeding apace in the jit branch.

jjsjann123 pushed a commit to jjsjann123/pytorch that referenced this pull request Jul 15, 2022
jjsjann123 added a commit that referenced this pull request Jul 21, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that are actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 21, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

ghstack-source-id: f24793f
Pull Request resolved: #81861
jjsjann123 added a commit that referenced this pull request Jul 21, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 21, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

ghstack-source-id: cfd5278
Pull Request resolved: #81861
jjsjann123 added a commit that referenced this pull request Jul 23, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 23, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

ghstack-source-id: 93c6b1e
Pull Request resolved: #81861
jjsjann123 added a commit that referenced this pull request Jul 26, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 26, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 27, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 27, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 27, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 27, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes includes:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy`
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation

Squashed commits to WAR github API
Commits that's actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
03180aa improve broadcast resolution (#1792)
bee6c69 bug fix (#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (#1812)
de6b7ca Fix negative position in InlinePropagator (#1813)
10a996c Remove redundant check in schedulePointwise (#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (#1441)
3ed8330 Kernel args patch to show zero_init buffer (#1809)
037a75a Dropout prob extremal patch (#1804)
282c429 spam nvrtc options (#1783)
3ba6a5f Broadcast in dim with expand (#1794)
fd4be12 remove dead indexing code (#1806)
fa4e6a4 Check siblings in getMaxPosAll (#1805)
025c840 Grouping grid allreduces across iterations (#1755)
37c579e Temporarily disable test requring large shared memory. (#1802)
5f375d0 More cleanup on InlinePropagator (#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (#1756)
ef04f6c Coding style cleanups (#1798)
38c7f3c InlinePropagator please don't replay (#1797)
3f2c263 validateDomain in TransformPropagator (#1796)
c077085 Use TransformPropagatorWithCheck in many tests (#1795)
d0d0908 Some further cleanup for the new computeAt interface (#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (#1791)
28cbaf9 New compute at interface (#1743)
635ebfc Add SpanningTreePrinter (#1786)
59f3c32 Output allocate patch (#1790)
fe93bf5 Transform propagator skip replay when possible (#1782)
ebf23a5 Fix isIntegralType error msg (#1789)
0c82ecf Disable register reuse across serial broadcast ops (#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (#1776)
86f46aa Fix div(Val, TensorView) (#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (#1761)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)

[ghstack-poisoned]
jjsjann123 added a commit that referenced this pull request Jul 27, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

ghstack-source-id: a74f653
Pull Request resolved: #81861
pytorchmergebot pushed a commit that referenced this pull request Jul 28, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Differential Revision: [D38043938](https://our.internmc.facebook.com/intern/diff/D38043938)
Pull Request resolved: #81861
Approved by: https://github.com/davidberard98
facebook-github-bot pushed a commit that referenced this pull request Jul 28, 2022
Summary:
Pull Request resolved: #81861

Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Test Plan: Imported from OSS

Reviewed By: samdow

Differential Revision: D38043938

Pulled By: davidberard98

fbshipit-source-id: b94245f83dab6faee31e0c154d3b969bddeb3d47