Skip to content

Conversation

@jiayisunx
Copy link
Collaborator

@jiayisunx jiayisunx commented Sep 6, 2024

Stack from ghstack (oldest at bottom):

Summary
This PR Relaxes the conditions for loop split to support dynamic shape cases.
Now the conditions that need to be met to apply loop split optimization are as follows:

  1. No reduction and no mudular index for all nodes.
  2. The indexing_exprs of all nodes contain only one (or more, but all the same) division, where the divisor is an integer, the dividend is one of the iter_vars, and this var, i.e. the dimension that needs to be split, is contiguous in all other indexing_exprs.

Example:

import torch
import torch.nn as nn

class GN(torch.nn.Module):
    def __init__(self, num_groups, num_channels):
        super(GN, self).__init__()
        self.gn = nn.GroupNorm(num_groups, num_channels)

    def forward(self, x):
        return self.gn(x)
    
input = torch.randn(2, 960, 96, 96).to(memory_format=torch.channels_last)
m = GN(32, 960).eval()
compiled_m = torch.compile(m, dynamic=True)

with torch.no_grad():
    compiled_m(input)

Before loop split, the node's var_ranges: {z0: s0, z1: s2, z2: s2, z3: 960} and indexing_exprs: {'index0': 960*s2**2*z0 + 960*s2*z1 + 960*z2 + z3, 'index1': 32*z0 + (z3//30), 'index2': 30*s2**2, 'index3': z3, 'index4': 960*s2*z0*((s2**2//s2)) + 960*z1*((s2**2//s2)) + 960*z2 + z3}. After loop split z3 will split to 30*z3 + z4, then the node's var_ranges will be changed to {z0: s0, z1: s2, z2: s2, z3: 32, z4: 30} and indexing_exprs will be changed to {'index0': 960*s2**2*z0 + 960*s2*z1 + 960*z2 + 30*z3 + z4, 'index1': 32*z0 + z3, 'index2': 30*s2**2, 'index3': 30*z3 + z4, 'index4': 960*s2*z0*((s2**2//s2)) + 960*z1*((s2**2//s2)) + 960*z2 + 30*z3 + z4}

Generated code:

  • Before:
cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'const int64_t', 'const int64_t'], '''
#include "/tmp/torchinductor_jiayisun/32/c32dcqa3qidvmunis4lucp3dhoicleq5qjfjfgvpiadbbzfp6ofy.h"
extern "C"  void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2,
                       const int64_t ks0,
                       const int64_t ks1)
{
    #pragma omp parallel num_threads(112)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(ks0); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(32L); x1+=static_cast<int64_t>(1L))
                {
                    {
                        Welford<float> tmp_acc0 = Welford<float>();
                        Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        static WeightRecp<at::vec::Vectorized<float>> wrecps0(static_cast<int64_t>(c10::div_floor_integer(static_cast<int64_t>((15L*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(8L))));
                        for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(static_cast<int64_t>(ks1*ks1)); x2+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(16L); x3+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (30L*x1) + (960L*x2) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(16));
                                tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0);
                            }
                            for(int64_t x3=static_cast<int64_t>(16L); x3<static_cast<int64_t>(30L); x3+=static_cast<int64_t>(14L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (30L*x1) + (960L*x2) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(14L));
                                masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, tmp0, static_cast<int64_t>(14L), &wrecps0);
                            }
                        }
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec));
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                        out_ptr0[static_cast<int64_t>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.mean);
                        out_ptr1[static_cast<int64_t>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.m2);
                    }
                }
            }
        }
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(ks0); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(ks1); x1+=static_cast<int64_t>(1L))
                {
                    #pragma GCC ivdep
                    for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(ks1); x2+=static_cast<int64_t>(1L))
                    {
                        #pragma GCC ivdep
                        for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(960L); x3+=static_cast<int64_t>(1L))
                        {
                            auto tmp0 = in_ptr0[static_cast<int64_t>(x3 + (960L*x2) + (960L*ks1*x1) + (960L*x0*(static_cast<int64_t>(ks1*ks1))))];
                            auto tmp1 = out_ptr0[static_cast<int64_t>((32L*x0) + (c10::div_floor_integer(static_cast<int64_t>(x3), static_cast<int64_t>(30L))))];
                            auto tmp3 = out_ptr1[static_cast<int64_t>((32L*x0) + (c10::div_floor_integer(static_cast<int64_t>(x3), static_cast<int64_t>(30L))))];
                            auto tmp11 = in_ptr1[static_cast<int64_t>(x3)];
                            auto tmp13 = in_ptr2[static_cast<int64_t>(x3)];
                            auto tmp2 = decltype(tmp0)(tmp0 - tmp1);
                            auto tmp4 = 30L*(static_cast<int64_t>(ks1*ks1));
                            auto tmp5 = c10::convert<float>(tmp4);
                            auto tmp6 = tmp3 / tmp5;
                            auto tmp7 = static_cast<float>(1e-05);
                            auto tmp8 = decltype(tmp6)(tmp6 + tmp7);
                            auto tmp9 = 1 / std::sqrt(tmp8);
                            auto tmp10 = decltype(tmp2)(tmp2 * tmp9);
                            auto tmp12 = decltype(tmp10)(tmp10 * tmp11);
                            auto tmp14 = decltype(tmp12)(tmp12 + tmp13);
                            out_ptr2[static_cast<int64_t>(x3 + (960L*x2) + (960L*x1*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1)))) + (960L*ks1*x0*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1)))))] = tmp14;
                        }
                    }
                }
            }
        }
    }
}
''')


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1 = args
    args.clear()
    s0 = arg2_1
    s2 = arg3_1
    assert_size_stride(arg0_1, (960, ), (1, ))
    assert_size_stride(arg1_1, (960, ), (1, ))
    assert_size_stride(arg4_1, (s0, 960, s2, s2), (960*(s2*s2), 1, 960*s2, 960))
    buf0 = empty_strided_cpu((s0, 32, 1, 1), (32, 1, 32*s0, 32*s0), torch.float32)
    buf1 = empty_strided_cpu((s0, 32, 1, 1), (32, 1, 32*s0, 32*s0), torch.float32)
    buf3 = empty_strided_cpu((s0, 960, s2, s2), (960*s2*((s2*s2) // s2), 1, 960*((s2*s2) // s2), 960), torch.float32)
    cpp_fused_native_group_norm_0(arg4_1, arg0_1, arg1_1, buf0, buf1, buf3, s0, s2)
    del arg0_1
    del arg1_1
    del arg4_1
    return (buf3, )

After:

cpp_fused_native_group_norm_0 = async_compile.cpp_pybinding(['const float*', 'const float*', 'const float*', 'float*', 'float*', 'float*', 'const int64_t', 'const int64_t'], '''
#include "/tmp/torchinductor_jiayisun/32/c32dcqa3qidvmunis4lucp3dhoicleq5qjfjfgvpiadbbzfp6ofy.h"
extern "C"  void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2,
                       const int64_t ks0,
                       const int64_t ks1)
{
    #pragma omp parallel num_threads(112)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(ks0); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(32L); x1+=static_cast<int64_t>(1L))
                {
                    {
                        Welford<float> tmp_acc0 = Welford<float>();
                        Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        Welford<at::vec::Vectorized<float>> masked_tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                        static WeightRecp<at::vec::Vectorized<float>> wrecps0(static_cast<int64_t>(c10::div_floor_integer(static_cast<int64_t>((15L*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(8L))));
                        for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(static_cast<int64_t>(ks1*ks1)); x2+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(16L); x3+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (30L*x1) + (960L*x2) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(16));
                                tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0, &wrecps0);
                            }
                            for(int64_t x3=static_cast<int64_t>(16L); x3<static_cast<int64_t>(30L); x3+=static_cast<int64_t>(14L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x3 + (30L*x1) + (960L*x2) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(14L));
                                masked_tmp_acc0_vec = welford_combine(masked_tmp_acc0_vec, tmp0, static_cast<int64_t>(14L), &wrecps0);
                            }
                        }
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(masked_tmp_acc0_vec));
                        tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                        out_ptr0[static_cast<int64_t>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.mean);
                        out_ptr1[static_cast<int64_t>(x1 + (32L*x0))] = static_cast<float>(tmp_acc0.m2);
                    }
                }
            }
        }
        {
            #pragma omp for collapse(2)
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(ks0); x0+=static_cast<int64_t>(1L))
            {
                for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(ks1); x1+=static_cast<int64_t>(1L))
                {
                    #pragma GCC ivdep
                    for(int64_t x2=static_cast<int64_t>(0L); x2<static_cast<int64_t>(ks1); x2+=static_cast<int64_t>(1L))
                    {
                        #pragma GCC ivdep
                        for(int64_t x3=static_cast<int64_t>(0L); x3<static_cast<int64_t>(32L); x3+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x4=static_cast<int64_t>(0L); x4<static_cast<int64_t>(16L); x4+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x4 + (30L*x3) + (960L*x2) + (960L*ks1*x1) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(16));
                                auto tmp1 = out_ptr0[static_cast<int64_t>(x3 + (32L*x0))];
                                auto tmp4 = out_ptr1[static_cast<int64_t>(x3 + (32L*x0))];
                                auto tmp13 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x4 + (30L*x3)), static_cast<int64_t>(16));
                                auto tmp15 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x4 + (30L*x3)), static_cast<int64_t>(16));
                                auto tmp2 = at::vec::Vectorized<float>(tmp1);
                                auto tmp3 = tmp0 - tmp2;
                                auto tmp5 = 30L*(static_cast<int64_t>(ks1*ks1));
                                auto tmp6 = c10::convert<float>(tmp5);
                                auto tmp7 = tmp4 / tmp6;
                                auto tmp8 = static_cast<float>(1e-05);
                                auto tmp9 = decltype(tmp7)(tmp7 + tmp8);
                                auto tmp10 = 1 / std::sqrt(tmp9);
                                auto tmp11 = at::vec::Vectorized<float>(tmp10);
                                auto tmp12 = tmp3 * tmp11;
                                auto tmp14 = tmp12 * tmp13;
                                auto tmp16 = tmp14 + tmp15;
                                tmp16.store(out_ptr2 + static_cast<int64_t>(x4 + (30L*x3) + (960L*x2) + (960L*x1*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1)))) + (960L*ks1*x0*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1))))));
                            }
                            for(int64_t x4=static_cast<int64_t>(16L); x4<static_cast<int64_t>(30L); x4+=static_cast<int64_t>(14L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x4 + (30L*x3) + (960L*x2) + (960L*ks1*x1) + (960L*x0*(static_cast<int64_t>(ks1*ks1)))), static_cast<int64_t>(14L));
                                auto tmp1 = out_ptr0[static_cast<int64_t>(x3 + (32L*x0))];
                                auto tmp4 = out_ptr1[static_cast<int64_t>(x3 + (32L*x0))];
                                auto tmp13 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<int64_t>(x4 + (30L*x3)), static_cast<int64_t>(14L));
                                auto tmp15 = at::vec::Vectorized<float>::loadu(in_ptr2 + static_cast<int64_t>(x4 + (30L*x3)), static_cast<int64_t>(14L));
                                auto tmp2 = at::vec::Vectorized<float>(tmp1);
                                auto tmp3 = tmp0 - tmp2;
                                auto tmp5 = 30L*(static_cast<int64_t>(ks1*ks1));
                                auto tmp6 = c10::convert<float>(tmp5);
                                auto tmp7 = tmp4 / tmp6;
                                auto tmp8 = static_cast<float>(1e-05);
                                auto tmp9 = decltype(tmp7)(tmp7 + tmp8);
                                auto tmp10 = 1 / std::sqrt(tmp9);
                                auto tmp11 = at::vec::Vectorized<float>(tmp10);
                                auto tmp12 = tmp3 * tmp11;
                                auto tmp14 = tmp12 * tmp13;
                                auto tmp16 = tmp14 + tmp15;
                                tmp16.store(out_ptr2 + static_cast<int64_t>(x4 + (30L*x3) + (960L*x2) + (960L*x1*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1)))) + (960L*ks1*x0*(c10::div_floor_integer(static_cast<int64_t>((static_cast<int64_t>(ks1*ks1))), static_cast<int64_t>(ks1))))), static_cast<int64_t>(14L));
                            }
                        }
                    }
                }
            }
        }
    }
}
''')


async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1 = args
    args.clear()
    s0 = arg2_1
    s2 = arg3_1
    assert_size_stride(arg0_1, (960, ), (1, ))
    assert_size_stride(arg1_1, (960, ), (1, ))
    assert_size_stride(arg4_1, (s0, 960, s2, s2), (960*(s2*s2), 1, 960*s2, 960))
    buf0 = empty_strided_cpu((s0, 32, 1, 1), (32, 1, 32*s0, 32*s0), torch.float32)
    buf1 = empty_strided_cpu((s0, 32, 1, 1), (32, 1, 32*s0, 32*s0), torch.float32)
    buf3 = empty_strided_cpu((s0, 960, s2, s2), (960*s2*((s2*s2) // s2), 1, 960*((s2*s2) // s2), 960), torch.float32)
    cpp_fused_native_group_norm_0(arg4_1, arg0_1, arg1_1, buf0, buf1, buf3, s0, s2)
    del arg0_1
    del arg1_1
    del arg4_1
    return (buf3, )

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Sep 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/135335

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 3c82af4 with merge base bf68e16 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jiayisunx added a commit that referenced this pull request Sep 6, 2024
ghstack-source-id: 2172c23
Pull Request resolved: #135335
@jiayisunx jiayisunx marked this pull request as draft September 6, 2024 07:57
[ghstack-poisoned]
jiayisunx added a commit that referenced this pull request Sep 9, 2024
ghstack-source-id: 1b79f14
Pull Request resolved: #135335
@jiayisunx jiayisunx added topic: not user facing topic category ciflow/trunk Trigger trunk jobs on your pull request labels Sep 9, 2024
@jiayisunx jiayisunx marked this pull request as ready for review September 10, 2024 00:56
[ghstack-poisoned]
jiayisunx added a commit that referenced this pull request Sep 11, 2024
ghstack-source-id: 71b82ef
Pull Request resolved: #135335
for div_expr in expr.find(FloorDiv):
if any(div_expr.has(var) for var in original_body.iter_vars):
num_div += 1
if num_div > 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible that the div variable and the divider are the same even though there are multiple divs?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While I have not encountered this situation, it is possible and I have further relaxed the conditions to support it. Thanks!

[ghstack-poisoned]
jiayisunx added a commit that referenced this pull request Sep 13, 2024
ghstack-source-id: 106e247
Pull Request resolved: #135335
@jiayisunx jiayisunx requested a review from jgong5 September 13, 2024 09:19
@jiayisunx jiayisunx requested a review from jansel September 18, 2024 10:02
@jiayisunx
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

zou3519 added a commit that referenced this pull request Sep 25, 2024
Revert "[PT2][Inductor][Optmus] fix test_pad_mm_bf16 and reland to fix long computation kernel (#136349)"

This reverts commit e184391.

Revert "Fix clang-tidy warnings in torch/csrc/lazy (#134655)"

This reverts commit 0287146.

Revert "Remove duplicate line (#136383)"

This reverts commit 0b91e7e.

Revert "[TF32] Account for TF32 in `test_conv_double_backward` (#135716)"

This reverts commit 29f7b8d.

Revert "Fix `Vectorized<double>::next_after` SVE compilation (#136388)"

This reverts commit 7936584.

Revert "Upgrade pybind11 API calls for 3.13t (#136370)"

This reverts commit 067d203.

Revert "[AOTI][Tooling] Filter out kernels based off lowercase names (#135395)"

This reverts commit 1a10751.

Revert "Add decomps for max_unpool (#133146)"

This reverts commit 0c936c3.

Revert "add TORCH_CUDA_CPP_API for AutoNcclGroup (#130012)"

This reverts commit 293fccf.

Revert "Use cpython declaration of _PyWeakref_ClearRef (#136300)"

This reverts commit d2455b9.

Revert "fix mypi in utils/_sympy/functions.py (#136339)"

This reverts commit 7f9c064.

Revert "[Inductor] Fix test_profiler_mark_wrapper_call_cuda_cuda_wrapper (#136356)"

This reverts commit f53a0f9.

Revert "Add more distributed examples (#130427)"

This reverts commit 5997354.

Revert "return instead of using skipTest (#136244)"

This reverts commit 29affa6.

Reapply "[PT2/Profiler] Add Context Info to Torch-Compiled Regions (#132765)"

This reverts commit 783c5ba.

Revert "Enable torch build with SLEEF on ARM by default (#133339)"

This reverts commit 4842f0f.

Revert "[inductor] Relax the conditions for loop split (#135335)"

This reverts commit 687e5cf.

[ghstack-poisoned]
zou3519 added a commit that referenced this pull request Sep 25, 2024
Revert "[PT2][Inductor][Optmus] fix test_pad_mm_bf16 and reland to fix long computation kernel (#136349)"

This reverts commit e184391.

Revert "Fix clang-tidy warnings in torch/csrc/lazy (#134655)"

This reverts commit 0287146.

Revert "Remove duplicate line (#136383)"

This reverts commit 0b91e7e.

Revert "[TF32] Account for TF32 in `test_conv_double_backward` (#135716)"

This reverts commit 29f7b8d.

Revert "Fix `Vectorized<double>::next_after` SVE compilation (#136388)"

This reverts commit 7936584.

Revert "Upgrade pybind11 API calls for 3.13t (#136370)"

This reverts commit 067d203.

Revert "[AOTI][Tooling] Filter out kernels based off lowercase names (#135395)"

This reverts commit 1a10751.

Revert "Add decomps for max_unpool (#133146)"

This reverts commit 0c936c3.

Revert "add TORCH_CUDA_CPP_API for AutoNcclGroup (#130012)"

This reverts commit 293fccf.

Revert "Use cpython declaration of _PyWeakref_ClearRef (#136300)"

This reverts commit d2455b9.

Revert "fix mypi in utils/_sympy/functions.py (#136339)"

This reverts commit 7f9c064.

Revert "[Inductor] Fix test_profiler_mark_wrapper_call_cuda_cuda_wrapper (#136356)"

This reverts commit f53a0f9.

Revert "Add more distributed examples (#130427)"

This reverts commit 5997354.

Revert "return instead of using skipTest (#136244)"

This reverts commit 29affa6.

Reapply "[PT2/Profiler] Add Context Info to Torch-Compiled Regions (#132765)"

This reverts commit 783c5ba.

Revert "Enable torch build with SLEEF on ARM by default (#133339)"

This reverts commit 4842f0f.

Revert "[inductor] Relax the conditions for loop split (#135335)"

This reverts commit 687e5cf.

ghstack-source-id: b0fb91e
Pull Request resolved: #136668
@github-actions github-actions bot deleted the gh/jiayisunx/28/head branch October 21, 2024 02:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants