
Commit 5b51a4f

Update on "Support generic stream/event on XPU backend"
# Motivation
According to #123611, we support generic stream/event on the XPU backend.

# Additional Context
New methods/attributes on `torch.Event`:
- torch.Event.event_id
- torch.Event.elapsed_time
- torch.Event.synchronize

New methods on `c10::Event`:
- c10::Event::event_id
- c10::Event::elapsed_time
- c10::Event::synchronize

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 gujinghui EikanWang fengyuan14

[ghstack-poisoned]
2 parents 76999d5 + c295579 commit 5b51a4f
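For readers of the commit message above, here is a minimal, hypothetical sketch of how the new `torch.Event` surface might be exercised. It is not part of this diff; the constructor arguments are assumptions based on the generic Stream/Event API, and it requires a torch build with XPU support.

    import torch

    # Assumed generic API: a stream and two timing-enabled events on the XPU backend.
    stream = torch.Stream(device="xpu")
    start = torch.Event(device="xpu", enable_timing=True)
    end = torch.Event(device="xpu", enable_timing=True)

    start.record(stream)
    # ... enqueue work on `stream` here ...
    end.record(stream)

    end.synchronize()                  # new: block the host until the event completes
    print(start.elapsed_time(end))     # new: elapsed milliseconds between the two events
    print(end.event_id)                # new: backend-specific id of the underlying event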

File tree

22 files changed: +246, -1403 lines

.github/ci_commit_pins/vision.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-06ad737628abc3a1e617571dc03cbdd5b36ea96a
+d23a6e1664d20707c11781299611436e1f0c104f

aten/src/ATen/cpu/Utils.cpp

Lines changed: 0 additions & 16 deletions
@@ -5,22 +5,6 @@
 
 namespace at::cpu {
 
-bool is_cpu_support_avx2() {
-#if !defined(__s390x__) && !defined(__powerpc__)
-  return cpuinfo_initialize() && cpuinfo_has_x86_avx2();
-#else
-  return false;
-#endif
-}
-
-bool is_cpu_support_avx512() {
-#if !defined(__s390x__) && !defined(__powerpc__)
-  return cpuinfo_initialize() && cpuinfo_has_x86_avx512f() && cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512bw() && cpuinfo_has_x86_avx512dq();
-#else
-  return false;
-#endif
-}
-
 bool is_cpu_support_vnni() {
 #if !defined(__s390x__) && !defined(__powerpc__)
   return cpuinfo_initialize() && cpuinfo_has_x86_avx512vnni();

aten/src/ATen/cpu/Utils.h

Lines changed: 0 additions & 3 deletions
@@ -4,9 +4,6 @@
 
 namespace at::cpu {
 
-TORCH_API bool is_cpu_support_avx2();
-TORCH_API bool is_cpu_support_avx512();
-
 // Detect if CPU support Vector Neural Network Instruction.
 TORCH_API bool is_cpu_support_vnni();

c10/xpu/impl/XPUGuardImpl.h

Lines changed: 1 addition & 2 deletions
@@ -158,8 +158,7 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
     const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
     if (C10_UNLIKELY(interp)) {
       (*interp)->trace_gpu_event_synchronization(
-          c10::kXPU,
-          reinterpret_cast<uintptr_t>(xpu_event));
+          c10::kXPU, reinterpret_cast<uintptr_t>(xpu_event));
     }
     xpu_event->wait_and_throw();
   }

test/distributed/_tensor/test_dtensor_compile.py

Lines changed: 27 additions & 1 deletion
@@ -60,7 +60,7 @@ def forward(self, input):
 
 
 def extract_graph(fx_g, _, graph_cell):
-    graph_cell[0] = fx_g
+    graph_cell[0] = fx_g.code
     return fx_g
 
 
@@ -481,6 +481,32 @@ def fn(x_dt):
         res = opt_fn(x_dt)
         self.assertEqual(ref, res)
 
+    def test_graph_input_is_async(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+        def fn(x):
+            return x.sin().sin()
+
+        opt_fn = torch.compile(fn, backend=aot_eager_graph, fullgraph=True)
+
+        x = torch.randn(4, 4, requires_grad=True)
+        x_dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
+        x2 = x_dt.redistribute(mesh, [Replicate()], async_op=True)
+        x2 = x2.to_local()
+        out = opt_fn(x2)
+        # The important part: we get a wait_tensor() in the graph.
+        # At runtime, the input to the graph is an AsyncCollectiveTensor,
+        # and inside the graph we need to issue a wait() to synchronize.
+        self.assertExpectedInline(
+            str(fw_graph_cell[0]).strip(),
+            """\
+def forward(self, primals_1):
+    wait_tensor = torch.ops._c10d_functional.wait_tensor.default(primals_1)
+    sin = torch.ops.aten.sin.default(wait_tensor)
+    sin_1 = torch.ops.aten.sin.default(sin); sin = None
+    return [sin_1, primals_1, wait_tensor]""",
+        )
+
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     def test_dtensor_partial_placement_graph_output(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
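The new test above hinges on one behavior worth spelling out: with `async_op=True`, `redistribute` returns before the collective finishes, so the compiled graph receives an `AsyncCollectiveTensor` and must insert an explicit `wait_tensor()`. A rough eager-mode sketch of that flow, reusing the names from the test (illustrative only, not part of the diff):

    # redistribute(..., async_op=True) issues the collective without waiting,
    # so to_local() hands back an AsyncCollectiveTensor rather than a plain tensor.
    x2 = x_dt.redistribute(mesh, [Replicate()], async_op=True).to_local()

    # The first op that consumes x2 must synchronize first; torch.compile makes
    # that explicit by placing wait_tensor() ahead of sin() in the traced graph.
    y = x2.sin().sin()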

test/inductor/test_debug_trace.py

Lines changed: 25 additions & 0 deletions
@@ -9,6 +9,8 @@
 
 import torch
 from torch._inductor import config, test_operators
+from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.utils._triton import has_triton
 
 try:
     try:
@@ -168,6 +170,29 @@ def body(self, ops):
         # intentionally only cleanup on success so debugging test is easier
         shutil.rmtree(filename)
 
+    @unittest.skipIf(not TEST_CUDA or not has_triton(), "requires cuda")
+    def test_debug_multi_tempalte(self):
+        class ToyModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.l = torch.nn.Linear(100, 100)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                return self.relu(self.l(x))
+
+        # no failure
+
+        from torch._inductor.utils import fresh_inductor_cache
+
+        with self.assertLogs(
+            logging.getLogger("torch._inductor.debug"), level=logging.WARNING
+        ), fresh_inductor_cache():
+            m = ToyModel().to(device="cuda:0")
+            m = torch.compile(m, mode="max-autotune")
+            input_tensor = torch.randn(100).to(device="cuda:0")
+            m(input_tensor)
+
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests

test/inductor/test_torchinductor.py

Lines changed: 0 additions & 5 deletions
@@ -6245,11 +6245,6 @@ def fn(x):
 
         self.common(fn, [torch.randn(64, 64)])
 
-    def test_new_cpp_build_logical(self):
-        from torch._inductor.codecache import validate_new_cpp_commands
-
-        validate_new_cpp_commands()
-
     def test_as_strided(self):
         def fn(x):
             return (

test/run_test.py

Lines changed: 0 additions & 2 deletions
@@ -229,9 +229,7 @@ def __contains__(self, item):
     "nn/test_pooling",
     "nn/test_convolution",  # Doesn't respect set_per_process_memory_fraction, results in OOM for other tests in slow gradcheck
     "distributions/test_distributions",
-    "functorch/test_vmap",  # OOM
     "test_fx",  # gets SIGKILL
-    "test_dataloader",  # frequently hangs for ROCm
     "functorch/test_memory_efficient_fusion",  # Cause CUDA OOM on ROCm
     "test_utils",  # OOM
     "test_sort_and_select",  # OOM

test/test_linalg.py

Lines changed: 2 additions & 1 deletion
@@ -18,7 +18,7 @@
     TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, iter_indices,
     make_fullrank_matrices_with_distinct_singular_values,
     freeze_rng_state, IS_ARM64, IS_SANDCASTLE, TEST_OPT_EINSUM, parametrize, skipIfTorchDynamo,
-    setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally)
+    setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally, serialTest)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes, has_cusolver, has_hipsolver,
      onlyCPU, skipCUDAIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, precisionOverride,
@@ -2485,6 +2485,7 @@ def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **option
     @precisionOverride({torch.float: 1e-4, torch.cfloat: 2e-4})
     @setLinalgBackendsToDefaultFinally
     @dtypes(*floating_and_complex_types())
+    @serialTest()
     def test_svd(self, device, dtype):
         # tests linalg.svd, svd, linalg.svdvals
         make_arg = partial(make_tensor, dtype=dtype, device=device)

torch/_C/_cpu.pyi

Lines changed: 0 additions & 2 deletions
@@ -2,6 +2,4 @@ from torch.types import _bool
 
 # Defined in torch/csrc/cpu/Module.cpp
 
-def _is_cpu_support_avx2() -> _bool: ...
-def _is_cpu_support_avx512() -> _bool: ...
 def _is_cpu_support_vnni() -> _bool: ...
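With the AVX2/AVX512 stubs gone, `_is_cpu_support_vnni` is the only capability probe left in this module. A minimal sketch of calling the surviving private binding (assuming a build that still exposes it):

    import torch

    # Private binding kept by this commit; the avx2/avx512 variants were removed.
    print(torch._C._cpu._is_cpu_support_vnni())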
