Commit 2267787
[AOTI] Fix a special case compile time data type codegen for sym int variables (#138106)
Summary:
This change unblocks the CFR AOTI lowering runtime error.
TL;DR:
In this model, one triton kernel expects a scalar input dtype as i64, but getting an i32. The reason is "auto" can infer a smaller data type if the variable get passed in e.g. is i32. thus cause CUDA IMA.
Original problematic kernel: `triton_poi_fused_add_ge_logical_and_logical_or_lt_46_grid_100`. and third input `auto var_402 = u0`.
This diff explicitly specifies it to i64 for all symbolic arguments in compile time for i64 triton kernel inputs, instead of use `auto var_x = {arg}` in cpp wrapper code.
Test Plan:
Verified in FLB locally:
```
PYTORCH_NO_CUDA_MEMORY_CACHING=1 AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=3 TORCH_LOGS="output_code" TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCH_SHOW_CPP_STACKTRACES=1 CUDA_LAUNCH_BLOCKING=1 ~/fbsource/buck-out/v2/gen/fbcode/98e643f8bb44fe9d/hpc/new/models/feed/benchmark/__feed_lower_benchmark__/feed_lower_benchmark.par --skip-eager --skip-flop-estimation --lower-backend="AOT_INDUCTOR" --sync-mode=0 --precision bf16 --output-precision bf16 --lower-presets="ifr_cint;disable_new_lowering_weights;disable_dper_passes:passes=fuse_parallel_linear_no_weight_change" --remove-unexpected-type-cast=False --load="manifold://ads_storage_fblearner/tree/user/facebook/fblearner/predictor/924293663/0/gpu_lowering/input.merge"```
Differential Revision: D644900391 parent 620039c commit 2267787
File tree
2 files changed
+87
-6
lines changed- test/inductor
- torch/_inductor/codegen
2 files changed
+87
-6
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
14 | 14 | | |
15 | 15 | | |
16 | 16 | | |
| 17 | + | |
17 | 18 | | |
18 | 19 | | |
19 | 20 | | |
| |||
3608 | 3609 | | |
3609 | 3610 | | |
3610 | 3611 | | |
| 3612 | + | |
| 3613 | + | |
| 3614 | + | |
| 3615 | + | |
| 3616 | + | |
| 3617 | + | |
| 3618 | + | |
| 3619 | + | |
| 3620 | + | |
| 3621 | + | |
| 3622 | + | |
| 3623 | + | |
| 3624 | + | |
| 3625 | + | |
| 3626 | + | |
| 3627 | + | |
| 3628 | + | |
| 3629 | + | |
| 3630 | + | |
| 3631 | + | |
| 3632 | + | |
| 3633 | + | |
| 3634 | + | |
| 3635 | + | |
| 3636 | + | |
| 3637 | + | |
| 3638 | + | |
| 3639 | + | |
| 3640 | + | |
| 3641 | + | |
| 3642 | + | |
| 3643 | + | |
| 3644 | + | |
| 3645 | + | |
| 3646 | + | |
| 3647 | + | |
| 3648 | + | |
| 3649 | + | |
| 3650 | + | |
| 3651 | + | |
| 3652 | + | |
| 3653 | + | |
| 3654 | + | |
| 3655 | + | |
| 3656 | + | |
| 3657 | + | |
| 3658 | + | |
| 3659 | + | |
3611 | 3660 | | |
3612 | 3661 | | |
3613 | 3662 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | 3 | | |
4 | | - | |
| 4 | + | |
5 | 5 | | |
6 | 6 | | |
7 | 7 | | |
| |||
286 | 286 | | |
287 | 287 | | |
288 | 288 | | |
289 | | - | |
| 289 | + | |
290 | 290 | | |
291 | | - | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
| 298 | + | |
| 299 | + | |
292 | 300 | | |
293 | 301 | | |
294 | 302 | | |
| |||
312 | 320 | | |
313 | 321 | | |
314 | 322 | | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
| 330 | + | |
| 331 | + | |
| 332 | + | |
| 333 | + | |
315 | 334 | | |
316 | 335 | | |
317 | 336 | | |
318 | 337 | | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
319 | 343 | | |
320 | 344 | | |
321 | 345 | | |
| |||
392 | 416 | | |
393 | 417 | | |
394 | 418 | | |
| 419 | + | |
395 | 420 | | |
396 | 421 | | |
397 | | - | |
398 | | - | |
| 422 | + | |
| 423 | + | |
399 | 424 | | |
400 | 425 | | |
401 | 426 | | |
402 | 427 | | |
403 | 428 | | |
404 | 429 | | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
| 434 | + | |
405 | 435 | | |
406 | | - | |
| 436 | + | |
| 437 | + | |
| 438 | + | |
407 | 439 | | |
408 | 440 | | |
409 | 441 | | |
| |||
0 commit comments