[inductor][cpp] support nested kernel with indirect indexing #129223
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/129223
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures) As of commit 4e309e4 with merge base c012013:
BROKEN TRUNK - The following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
-        if dtype == torch.half or dtype == torch.bfloat16:
-            atol, rtol = 1e-2, 1e-2
-        with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
+        with verify(dtype) as (atol, rtol):
```
Simplify the tolerance-checking code here.
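For reference, a minimal sketch of what such a `verify` helper could look like (the default tolerances and the exact structure are assumptions, not the helper from this PR):

```python
# Sketch only: bundle the dtype-dependent tolerance selection with the
# select_algorithm.VERIFY patch behind a single context manager.
from contextlib import contextmanager
from unittest.mock import patch

import torch
from torch._inductor import select_algorithm


@contextmanager
def verify(dtype):
    # Assumed defaults; reduced-precision dtypes get looser tolerances.
    atol, rtol = 1e-4, 1e-4
    if dtype in (torch.half, torch.bfloat16):
        atol, rtol = 1e-2, 1e-2
    with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
        yield atol, rtol
```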
```python
csevar = V.kernel.cse.generate(
    V.kernel.compute, v, bounds=bounds
)
```
This makes sure we are working on the current kernel when a nested kernel is being generated, e.g., a CppKernel generated as the epilogue of a CppTemplateKernel.
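To illustrate the point, a toy sketch (made-up names, not the real inductor classes) of why CSE generation has to go through the virtualized current kernel rather than the object the proxy was created under:

```python
# With nested codegen, V.kernel points at whichever kernel is being generated
# right now, while `self` inside the proxy can still be bound to the outer
# template kernel.
class Kernel:
    def __init__(self, name):
        self.name = name
        self.compute = []  # stands in for the kernel's compute buffer


class Virtualized:
    kernel = None  # the kernel currently being generated


V = Virtualized()


def generate_cse_var(expr):
    # Correct behavior: emit into whatever kernel is active *right now*.
    V.kernel.compute.append(expr)


outer = Kernel("CppTemplateKernel")
inner = Kernel("CppKernel (epilogue)")

V.kernel = outer
# Nested epilogue codegen temporarily swaps the active kernel.
prev, V.kernel = V.kernel, inner
generate_cse_var("auto tmp4 = tmp0 < 0;")  # lands in the epilogue kernel
V.kernel = prev

print(inner.compute)  # ['auto tmp4 = tmp0 < 0;']
```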
leslie-fang-intel left a comment
LGTM
```python
def inner(*args, **kwargs):
    bounds = CSEProxy._bound_variable(name, *args, **kwargs)
    ...
    value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]
```
Thanks for the fix. For op names like lt, getattr(parent_handler, name) falls back until it reaches MockHandler to do the codegen. So, in this nested-kernel case, it fell back to an instance of CSEProxy bound to the parent kernel (CppTemplateKernel in this case).
Attached is the list of scalar ops that use MockHandler for codegen (a short sketch of this fallback follows the snippet):
pytorch/torch/_inductor/ops_handler.py
Lines 798 to 819 in 5b14943
```python
for name, format_string in {
    "add": "{} + {}",
    "sub": "{} - {}",
    "mul": "{} * {}",
    "floordiv": "{} // {}",
    "truediv": "{} / {}",
    "mod": "{} % {}",  # careful, depending on target semantics varies
    "pow": "{} ** {}",
    "lshift": "{} << {}",
    "rshift": "{} >> {}",
    "and_": "{} & {}",
    "or_": "{} | {}",
    "xor": "{} ^ {}",
    "eq": "{} == {}",
    "ne": "{} != {}",
    "lt": "{} < {}",
    "gt": "{} > {}",
    "le": "{} <= {}",
    "ge": "{} >= {}",
    "neg": "-{}",
}.items():
    setattr(cls, name, make_handler(format_string))
```
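For context, a rough sketch (simplified relative to the real ops_handler.py) of how this format-string fallback turns an op name like lt into an expression string:

```python
# Each op name is bound to a handler that simply formats its arguments into
# the corresponding expression text.
def make_handler(format_string):
    @staticmethod
    def inner(*args):
        return format_string.format(*args)

    return inner


class MockHandler:
    pass


for name, format_string in {"lt": "{} < {}", "add": "{} + {}"}.items():
    setattr(MockHandler, name, make_handler(format_string))

print(MockHandler.lt("tmp0", "0"))      # "tmp0 < 0"
print(MockHandler.add("tmp5", "tmp2"))  # "tmp5 + tmp2"
```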
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
This PR makes sure the current kernel is used for generating CSE variables when nested kernel codegen is involved, e.g., a nested CppKernel used to generate the epilogue of a CppTemplateKernel. Without the fix, an epilogue with indirect indexing would fail to run.
pytest -k test_linear_with_embedding_bias_False_cpu test_cpu_select_algorithm.py
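For context, a minimal sketch of the kind of model this test exercises (the shapes, option names, and module structure here are assumptions, not the actual test code): an embedding lookup, i.e. indirect indexing, fused as the epilogue of a linear under max-autotune on CPU.

```python
import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Sizes chosen to mirror the generated code below (embedding table 64 x 384).
        self.emb = torch.nn.Embedding(64, 384)
        self.linear = torch.nn.Linear(384, 384, bias=False)

    def forward(self, idx, x):
        # The embedding add becomes an epilogue of the GEMM template; its index
        # computation is the indirect indexing handled by this PR.
        return self.emb(idx) + self.linear(x)


m = M().eval()
idx = torch.randint(0, 64, (4, 32))
x = torch.randn(4, 32, 384)
with torch.no_grad():
    compiled = torch.compile(m, options={"max_autotune": True})
    out = compiled(idx, x)
```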
Epilogue code Before:

```cpp
{
    #pragma GCC ivdep
    for(long x0=static_cast<long>(0L); x0<static_cast<long>(m_end + ((-1L)*m_start)); x0+=static_cast<long>(1L))
    {
        for(long x1=static_cast<long>(0L); x1<static_cast<long>(16L*(c10::div_floor_integer(N0, 16L))); x1+=static_cast<long>(16L))
        {
            auto tmp0 = in_ptr2[static_cast<long>(m_start + x0)];
            auto tmp11 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<long>(x1 + (N0*x0)), 16);
            auto tmp1 = 64L;
            auto tmp2 = c10::convert<int64_t>(tmp1);
            auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
            auto tmp4 = tmp0 ? tmp3 : tmp0;
            auto tmp5 = decltype(tmp4)(tmp4 + tmp2);
            auto tmp6 = tmp1 ? tmp5 : tmp4;
            auto tmp7 = tmp6;
            auto tmp8 = c10::convert<int64_t>(tmp7);
            TORCH_CHECK((0 <= tmp8) & (tmp8 < 64L), "index out of bounds: 0 <= tmp8 < 64L");
            auto tmp10 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<long>(n_start + x1 + (384L*tmp6)), 16);
            auto tmp12 = (tmp11);
            auto tmp13 = tmp10 + tmp12;
            tmp13.store(Y + static_cast<long>(n_start + x1 + (384L*m_start) + (384L*x0)));
        }
        #pragma omp simd simdlen(8)
        for(long x1=static_cast<long>(16L*(c10::div_floor_integer(N0, 16L))); x1<static_cast<long>(N0); x1+=static_cast<long>(1L))
        {
            auto tmp0 = in_ptr2[static_cast<long>(m_start + x0)];
            auto tmp11 = local_acc_buf[static_cast<long>(x1 + (N0*x0))];
            auto tmp1 = 64L;
            auto tmp2 = c10::convert<int64_t>(tmp1);
            auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
            auto tmp4 = tmp0 ? tmp3 : tmp0;
            auto tmp5 = decltype(tmp4)(tmp4 + tmp2);
            auto tmp6 = tmp1 ? tmp5 : tmp4;
            auto tmp7 = tmp6;
            auto tmp8 = c10::convert<int64_t>(tmp7);
            TORCH_CHECK((0 <= tmp8) & (tmp8 < 64L), "index out of bounds: 0 <= tmp8 < 64L");
            TORCH_CHECK((0 <= tmp8) & (tmp8 < 64L), "index out of bounds: 0 <= tmp8 < 64L");
            auto tmp10 = in_ptr3[static_cast<long>(n_start + x1 + (384L*tmp6))];
            auto tmp12 = c10::convert<float>(tmp11);
            auto tmp13 = decltype(tmp10)(tmp10 + tmp12);
            Y[static_cast<long>(n_start + x1 + (384L*m_start) + (384L*x0))] = tmp13;
        }
    }
}
```

Epilogue code After:

```cpp
{
    #pragma GCC ivdep
    for(long x0=static_cast<long>(0L); x0<static_cast<long>(m_end + ((-1L)*m_start)); x0+=static_cast<long>(1L))
    {
        for(long x1=static_cast<long>(0L); x1<static_cast<long>(16L*(c10::div_floor_integer(N0, 16L))); x1+=static_cast<long>(16L))
        {
            auto tmp0 = in_ptr2[static_cast<long>(m_start + x0)];
            auto tmp13 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<long>(x1 + (N0*x0)), 16);
            auto tmp1 = 64L;
            auto tmp2 = c10::convert<int64_t>(tmp1);
            auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
            auto tmp4 = tmp0 < 0;
            auto tmp5 = tmp4 ? tmp3 : tmp0;
            auto tmp6 = decltype(tmp5)(tmp5 + tmp2);
            auto tmp7 = tmp5 < 0;
            auto tmp8 = tmp7 ? tmp6 : tmp5;
            auto tmp9 = tmp8;
            auto tmp10 = c10::convert<int64_t>(tmp9);
            TORCH_CHECK((0 <= tmp10) & (tmp10 < 64L), "index out of bounds: 0 <= tmp10 < 64L");
            auto tmp12 = at::vec::Vectorized<float>::loadu(in_ptr3 + static_cast<long>(n_start + x1 + (384L*tmp8)), 16);
            auto tmp14 = (tmp13);
            auto tmp15 = tmp12 + tmp14;
            tmp15.store(Y + static_cast<long>(n_start + x1 + (384L*m_start) + (384L*x0)));
        }
        #pragma omp simd simdlen(8)
        for(long x1=static_cast<long>(16L*(c10::div_floor_integer(N0, 16L))); x1<static_cast<long>(N0); x1+=static_cast<long>(1L))
        {
            auto tmp0 = in_ptr2[static_cast<long>(m_start + x0)];
            auto tmp13 = local_acc_buf[static_cast<long>(x1 + (N0*x0))];
            auto tmp1 = 64L;
            auto tmp2 = c10::convert<int64_t>(tmp1);
            auto tmp3 = decltype(tmp0)(tmp0 + tmp2);
            auto tmp4 = tmp0 < 0;
            auto tmp5 = tmp4 ? tmp3 : tmp0;
            auto tmp6 = decltype(tmp5)(tmp5 + tmp2);
            auto tmp7 = tmp5 < 0;
            auto tmp8 = tmp7 ? tmp6 : tmp5;
            auto tmp9 = tmp8;
            auto tmp10 = c10::convert<int64_t>(tmp9);
            TORCH_CHECK((0 <= tmp10) & (tmp10 < 64L), "index out of bounds: 0 <= tmp10 < 64L");
            TORCH_CHECK((0 <= tmp10) & (tmp10 < 64L), "index out of bounds: 0 <= tmp10 < 64L");
            auto tmp12 = in_ptr3[static_cast<long>(n_start + x1 + (384L*tmp8))];
            auto tmp14 = c10::convert<float>(tmp13);
            auto tmp15 = decltype(tmp12)(tmp12 + tmp14);
            Y[static_cast<long>(n_start + x1 + (384L*m_start) + (384L*x0))] = tmp15;
        }
    }
}
```

cc @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang