Skip to content

Conversation

@jgong5
Copy link
Collaborator

@jgong5 jgong5 commented Jun 21, 2024

Stack from ghstack (oldest at bottom):

This PR makes sure the current kernel is used for generating CSE variables when nested kernel codegen is involved, e.g., nested CppKernel is used to generate epilogue of CppTemplateKernel. Without the fix, the epilogue with indirect indexing would fail to run.

pytest -k test_linear_with_embedding_bias_False_cpu test_cpu_select_algorithm.py

Epilogue code Before:

                {
                    #pragma GCC ivdep
                    for(long x0=static_cast<long>(0L); x0<static_cast<long>(m_end + ((-1L)*m_start)); x0+=static_cast<long>(1L))
                    {
                        for(long x1=static_cast<long>(0L); x1<static_cast<long>(16L*(c10::div_floor_integer(N0, 16L))); x1+=static_cast<long>(16L))
                        {
                            auto tmp0 = in_ptr2[static_cast<long>(m_start + x0)];
                            auto tmp11 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<long>(x1 + (N0*x0)), 16);
                            auto tmp1 = 64L;
                            auto tmp2 = c10::convert<int64_t>(tmp1);
                            auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
                            auto tmp4 = tmp0 ? tmp3 : tmp0;
                            auto tmp5 = decltype(tmp4)(tmp4 + tmp2);
                            auto tmp6 = tmp1 ? tmp5 : tmp4;
                            auto tmp7 = tmp6;
                            auto tmp8 = c10::convert<int64_t>(tmp7);
                            TORCH_CHECK((0 <= tmp8) & (tmp8 < 64L), "index out of bounds: 0 <= tmp8 < 64L");
                            auto tmp10 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<long>(n_start + x1 + (384L*tmp6)), 16);
                            auto tmp12 = (tmp11);
                            auto tmp13 = tmp10 + tmp12;
                            tmp13.store(Y + static_cast<long>(n_start + x1 + (384L*m_start) + (384L*x0)));
                        }
                        #pragma omp simd simdlen(8) 
                        for(long x1=static_cast<long>(16L*(c10::div_floor_integer(N0, 16L))); x1<static_cast<long>(N0); x1+=static_cast<long>(1L))
                        {
                            auto tmp0 = in_ptr2[static_cast<long>(m_start + x0)];
                            auto tmp11 = local_acc_buf[static_cast<long>(x1 + (N0*x0))];
                            auto tmp1 = 64L;
                            auto tmp2 = c10::convert<int64_t>(tmp1);
                            auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
                            auto tmp4 = tmp0 ? tmp3 : tmp0;
                            auto tmp5 = decltype(tmp4)(tmp4 + tmp2);
                            auto tmp6 = tmp1 ? tmp5 : tmp4;
                            auto tmp7 = tmp6;
                            auto tmp8 = c10::convert<int64_t>(tmp7);
                            TORCH_CHECK((0 <= tmp8) & (tmp8 < 64L), "index out of bounds: 0 <= tmp8 < 64L");
                            TORCH_CHECK((0 <= tmp8) & (tmp8 < 64L), "index out of bounds: 0 <= tmp8 < 64L");
                            auto tmp10 = in_ptr3[static_cast<long>(n_start + x1 + (384L*tmp6))];
                            auto tmp12 = c10::convert<float>(tmp11);
                            auto tmp13 = decltype(tmp10)(tmp10 + tmp12);
                            Y[static_cast<long>(n_start + x1 + (384L*m_start) + (384L*x0))] = tmp13;
                        }
                    }
                }

Epilogue code After:

                {
                    #pragma GCC ivdep
                    for(long x0=static_cast<long>(0L); x0<static_cast<long>(m_end + ((-1L)*m_start)); x0+=static_cast<long>(1L))
                    {
                        for(long x1=static_cast<long>(0L); x1<static_cast<long>(16L*(c10::div_floor_integer(N0, 16L))); x1+=static_cast<long>(16L))
                        {
                            auto tmp0 = in_ptr2[static_cast<long>(m_start + x0)];
                            auto tmp13 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<long>(x1 + (N0*x0)), 16);
                            auto tmp1 = 64L;
                            auto tmp2 = c10::convert<int64_t>(tmp1);
                            auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
                            auto tmp4 = tmp0 < 0;
                            auto tmp5 = tmp4 ? tmp3 : tmp0;
                            auto tmp6 = decltype(tmp5)(tmp5 + tmp2);
                            auto tmp7 = tmp5 < 0;
                            auto tmp8 = tmp7 ? tmp6 : tmp5;
                            auto tmp9 = tmp8;
                            auto tmp10 = c10::convert<int64_t>(tmp9);
                            TORCH_CHECK((0 <= tmp10) & (tmp10 < 64L), "index out of bounds: 0 <= tmp10 < 64L");
                            auto tmp12 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<long>(n_start + x1 + (384L*tmp8)), 16);
                            auto tmp14 = (tmp13);
                            auto tmp15 = tmp12 + tmp14;
                            tmp15.store(Y + static_cast<long>(n_start + x1 + (384L*m_start) + (384L*x0)));
                        }
                        #pragma omp simd simdlen(8) 
                        for(long x1=static_cast<long>(16L*(c10::div_floor_integer(N0, 16L))); x1<static_cast<long>(N0); x1+=static_cast<long>(1L))
                        {
                            auto tmp0 = in_ptr2[static_cast<long>(m_start + x0)];
                            auto tmp13 = local_acc_buf[static_cast<long>(x1 + (N0*x0))];
                            auto tmp1 = 64L;
                            auto tmp2 = c10::convert<int64_t>(tmp1);
                            auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
                            auto tmp4 = tmp0 < 0;
                            auto tmp5 = tmp4 ? tmp3 : tmp0;
                            auto tmp6 = decltype(tmp5)(tmp5 + tmp2);
                            auto tmp7 = tmp5 < 0;
                            auto tmp8 = tmp7 ? tmp6 : tmp5;
                            auto tmp9 = tmp8;
                            auto tmp10 = c10::convert<int64_t>(tmp9);
                            TORCH_CHECK((0 <= tmp10) & (tmp10 < 64L), "index out of bounds: 0 <= tmp10 < 64L");
                            TORCH_CHECK((0 <= tmp10) & (tmp10 < 64L), "index out of bounds: 0 <= tmp10 < 64L");
                            auto tmp12 = in_ptr3[static_cast<long>(n_start + x1 + (384L*tmp8))];
                            auto tmp14 = c10::convert<float>(tmp13);
                            auto tmp15 = decltype(tmp12)(tmp12 + tmp14);
                            Y[static_cast<long>(n_start + x1 + (384L*m_start) + (384L*x0))] = tmp15;
                        }
                    }
                }

cc @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Jun 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/129223

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 4e309e4 with merge base c012013 (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

[ghstack-poisoned]
if dtype == torch.half or dtype == torch.bfloat16:
atol, rtol = 1e-2, 1e-2
with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
with verify(dtype) as (atol, rtol):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplify the checking with tolerance code here.

Comment on lines +1741 to +1743
csevar = V.kernel.cse.generate(
V.kernel.compute, v, bounds=bounds
)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sure we are working on the current kernel when the nested kernel is being generated, e.g., CppKernel generated as the epilogue of CppTemplateKernel.

@jgong5 jgong5 requested a review from leslie-fang-intel June 21, 2024 09:02
@jgong5 jgong5 added the topic: not user facing topic category label Jun 21, 2024
[ghstack-poisoned]
jgong5 pushed a commit that referenced this pull request Jun 21, 2024
Copy link
Collaborator

@leslie-fang-intel leslie-fang-intel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

def inner(*args, **kwargs):
bounds = CSEProxy._bound_variable(name, *args, **kwargs)

value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type]
Copy link
Collaborator

@leslie-fang-intel leslie-fang-intel Jun 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix. For ops name like lt, getattr(parent_handler, name) will back trace until MockHandler to do the codegen. So, for this case of nested kernel, it back trace to use instance of CSEProxy binding with parent kernel (CppTemplateKernel in this case).

Attached the scalar op list using MockHandler to do the codegen:

for name, format_string in {
"add": "{} + {}",
"sub": "{} - {}",
"mul": "{} * {}",
"floordiv": "{} // {}",
"truediv": "{} / {}",
"mod": "{} % {}", # careful, depending on target semantics varies
"pow": "{} ** {}",
"lshift": "{} << {}",
"rshift": "{} >> {}",
"and_": "{} & {}",
"or_": "{} | {}",
"xor": "{} ^ {}",
"eq": "{} == {}",
"ne": "{} != {}",
"lt": "{} < {}",
"gt": "{} > {}",
"le": "{} <= {}",
"ge": "{} >= {}",
"neg": "-{}",
}.items():
setattr(cls, name, make_handler(format_string))

@jgong5 jgong5 requested review from jansel and peterbell10 June 22, 2024 04:03
[ghstack-poisoned]
jgong5 pushed a commit that referenced this pull request Jun 22, 2024
[ghstack-poisoned]
jgong5 pushed a commit that referenced this pull request Jun 25, 2024
@jgong5
Copy link
Collaborator Author

jgong5 commented Jun 25, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 25, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions github-actions bot deleted the gh/jgong5/55/head branch July 26, 2024 01:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants