Commit 61a7c83
[Inductor] fix device error for NopKernelSchedulerNode (#141372)
This PR adds device guard support for `NopKernelSchedulerNode`, which may still allocate a tensor even though it launches no kernel. Prior to this PR, we did not codegen a device guard for `NopKernelSchedulerNode`, so its allocation could land on the wrong device, leading to errors.
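For context, a minimal reproducer sketch. The shapes and the `flex_attention` call are inferred from the generated wrapper code below, not copied from the issue:
```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Run flex_attention on a non-default device (cuda:1); shapes assumed to
# mirror the generated wrapper below.
q, k, v = (
    torch.randn(1, 1, 2048, 128, device="cuda:1", dtype=torch.bfloat16)
    for _ in range(3)
)
compiled = torch.compile(flex_attention)
out = compiled(q, k, v)  # before the fix: one buffer was allocated on cuda:0
```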
Prior to the PR:
```python
def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args
    args.clear()
    assert_size_stride(arg0_1, (1, 1, 2048, 128), (262144, 262144, 128, 1))
    assert_size_stride(arg1_1, (1, 1, 2048, 128), (262144, 262144, 128, 1))
    assert_size_stride(arg2_1, (1, 1, 2048, 128), (262144, 262144, 128, 1))
    assert_size_stride(arg3_1, (1, 1, 16), (16, 16, 1))
    assert_size_stride(arg4_1, (1, 1, 16, 16), (256, 256, 16, 1))
    assert_size_stride(arg5_1, (1, 1, 16), (16, 16, 1))
    assert_size_stride(arg6_1, (1, 1, 16, 16), (256, 256, 16, 1))
    assert_size_stride(arg7_1, (1, 1, 16), (16, 16, 1))
    assert_size_stride(arg8_1, (1, 1, 16, 16), (256, 256, 16, 1))
    assert_size_stride(arg9_1, (1, 1, 16), (16, 16, 1))
    assert_size_stride(arg10_1, (1, 1, 16, 16), (256, 256, 16, 1))
    buf0 = empty_strided_cuda((1, 1, 2048), (2048, 2048, 1), torch.float32)  # ERROR: allocated on the current device (cuda:0); should be on cuda:1
    with torch.cuda._DeviceGuard(1):
        torch.cuda.set_device(1)
        buf1 = empty_strided_cuda((1, 1, 2048, 128), (262144, 262144, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
        stream1 = get_raw_stream(1)
        triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, arg3_1, arg4_1, arg5_1, arg6_1, buf1, grid=torch._inductor.kernel.flex_attention.flex_attention_grid(1, 1, 2048, 128, meta0), stream=stream1)
        del arg0_1
        del arg1_1
        del arg2_1
        del arg3_1
        del arg4_1
        del arg5_1
        del arg6_1
        del buf0
    return (buf1, )
```
After the PR:
```python
def call(args):
    arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1 = args
    args.clear()
    assert_size_stride(arg0_1, (1, 1, 2048, 128), (262144, 262144, 128, 1))
    assert_size_stride(arg1_1, (1, 1, 2048, 128), (262144, 262144, 128, 1))
    assert_size_stride(arg2_1, (1, 1, 2048, 128), (262144, 262144, 128, 1))
    assert_size_stride(arg3_1, (1, 1, 16), (16, 16, 1))
    assert_size_stride(arg4_1, (1, 1, 16, 16), (256, 256, 16, 1))
    assert_size_stride(arg5_1, (1, 1, 16), (16, 16, 1))
    assert_size_stride(arg6_1, (1, 1, 16, 16), (256, 256, 16, 1))
    assert_size_stride(arg7_1, (1, 1, 16), (16, 16, 1))
    assert_size_stride(arg8_1, (1, 1, 16, 16), (256, 256, 16, 1))
    assert_size_stride(arg9_1, (1, 1, 16), (16, 16, 1))
    assert_size_stride(arg10_1, (1, 1, 16, 16), (256, 256, 16, 1))
    with torch.cuda._DeviceGuard(1):
        torch.cuda.set_device(1)
        buf0 = empty_strided_cuda((1, 1, 2048), (2048, 2048, 1), torch.float32)  # New: allocation moved inside the device guard
        buf1 = empty_strided_cuda((1, 1, 2048, 128), (262144, 262144, 128, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [flex_attention], Original ATen: []
        stream1 = get_raw_stream(1)
        triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, arg3_1, arg4_1, arg5_1, arg6_1, buf1, grid=torch._inductor.kernel.flex_attention.flex_attention_grid(1, 1, 2048, 128, meta0), stream=stream1)
        del arg0_1
        del arg1_1
        del arg2_1
        del arg3_1
        del arg4_1
        del arg5_1
        del arg6_1
        del buf0
    return (buf1, )
```
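The `_DeviceGuard`/`set_device` pair is what makes `buf0` land on `cuda:1`. In plain eager terms, a sketch of what the generated guard does (it uses the same private `torch.cuda._DeviceGuard` the wrapper emits; requires at least two CUDA devices):
```python
import torch

prev = torch.cuda.current_device()
with torch.cuda._DeviceGuard(1):
    torch.cuda.set_device(1)
    # Any allocation that resolves "cuda" to the current device now lands
    # on cuda:1, matching the stream the Triton kernel runs on.
    t = torch.empty(4, device="cuda")
assert t.device == torch.device("cuda", 1)
assert torch.cuda.current_device() == prev  # guard restores the previous device
```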
Fixes #141010
Pull Request resolved: #141372
Approved by: https://github.com/eellison
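For readers navigating the diff summary below (the change spans test/inductor and torch/_inductor), here is a self-contained toy model of the codegen decision. All names are hypothetical and this is not the PR's actual code; it only illustrates that nop nodes with device-bound allocations now get a guard too:
```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Node:
    name: str
    device_index: Optional[int]  # None => no CUDA device involved
    is_nop: bool                 # nop kernel: launches nothing, may allocate

def emit(node: Node, lines: List[str]) -> None:
    # Before the fix, nop nodes skipped the guard, so their allocations
    # ran on whatever device happened to be current.
    if node.device_index is None:
        lines.append(f"{node.name} = alloc_cpu()")
        return
    lines.append(f"with torch.cuda._DeviceGuard({node.device_index}):")
    lines.append(f"    torch.cuda.set_device({node.device_index})")
    lines.append(f"    {node.name} = empty_strided_cuda(...)")
    if not node.is_nop:
        lines.append(f"    kernel.run(..., stream=get_raw_stream({node.device_index}))")

lines: List[str] = []
emit(Node("buf0", 1, is_nop=True), lines)   # now guarded as well
emit(Node("buf1", 1, is_nop=False), lines)
print("\n".join(lines))
```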
6 files changed: +52 −9 lines (test/inductor, torch/_inductor)