Closed
Labels
high priority, module: numpy, oncall: pt2, triaged
Description
🐛 Describe the bug
With torch 2.4.1 and numpy 2.0.0 (numpy 2.0.1 gives the same error), torch.compile fails when tracing a call to np.random.uniform(). The same code works with numpy 1.x (verified with numpy 1.26.0).
Minimal repro script:
import torch
import numpy as np

# minimal repro extracted from: test/dynamo/test_unspec.py#test_to_tensor()
def f1():
    a = np.random.uniform(low=-1, high=1, size=(20, 1))
    return torch.tensor([a, a, a, a], dtype=torch.float64, device="cpu")

optimize = torch.compile(backend="inductor", fullgraph=True)
result = optimize(f1)()
print(result)
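One possible workaround for the repro script above (a sketch, not a confirmed fix) is to keep the np.random.uniform call outside the compiled region and pass its result in as a tensor, so Dynamo never has to trace the numpy callable. The name `f1_compiled_part` is hypothetical, and `backend="eager"` is an assumption chosen only so the sketch runs without a C++ toolchain; swap in `backend="inductor"` to match the original repro.

```python
import numpy as np
import torch


def f1_compiled_part(a: torch.Tensor) -> torch.Tensor:
    # Only torch ops remain inside the compiled region.
    return torch.stack([a, a, a, a]).to(dtype=torch.float64, device="cpu")


# Hypothetical setup: backend="eager" avoids needing a compiler for inductor.
compiled = torch.compile(f1_compiled_part, backend="eager", fullgraph=True)


def f1():
    # np.random.uniform runs eagerly here, so Dynamo never sees it.
    a = np.random.uniform(low=-1, high=1, size=(20, 1))
    return compiled(torch.from_numpy(a))


result = f1()
print(result.shape, result.dtype)
```

np.random.uniform already returns float64, so stacking four (20, 1) copies gives the same (4, 20, 1) float64 result that torch.tensor([a, a, a, a], dtype=torch.float64) produces in eager mode. The trade-off is that the numpy call is no longer part of the compiled graph.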
Error logs
Traceback (most recent call last):
File "/usr/local/google/home/kiuk/tmp/numpy_repro.py", line 11, in <module>
result = optimize(f1)()
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
return fn(*args, **kwargs)
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
return self._torchdynamo_orig_callable(
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
return _compile(
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
return StrobelightCompileTimeProfiler.profile_compile_time(
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
return func(*args, **kwargs)
File "/usr/local/google/home/kiuk/.pyenv/versions/3.10.12/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
out_code = transform_code_object(code, transform)
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
transformations(instructions, code_options)
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
return fn(*args, **kwargs)
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
tracer.run()
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
super().run()
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
while self.step():
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
self.dispatch_table[inst.opcode](self, inst)
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 499, in wrapper
return inner_fn(self, inst)
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1512, in CALL_FUNCTION_KW
self.call_function(fn, args, kwargs)
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 743, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 749, in call_function
return func_var.call_function(tx, [obj_var] + args, kwargs)
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 789, in call_function
return self.call_method(tx, "__call__", args, kwargs)
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/variables/user_defined.py", line 649, in call_method
return super().call_method(tx, name, args, kwargs)
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/variables/base.py", line 320, in call_method
unimplemented(f"call_method {self} {name} {args} {kwargs}")
File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/exc.py", line 221, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_method UserDefinedObjectVariable(cython_function_or_method) __call__ [UserDefinedObjectVariable()] {'low': ConstantVariable(), 'high': ConstantVariable(), 'size': TupleVariable()}
from user code:
File "/usr/local/google/home/kiuk/tmp/numpy_repro.py", line 7, in f1
a = np.random.uniform(low=-1, high=1, size=(20, 1))
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
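The Unsupported message reports the callee as `UserDefinedObjectVariable(cython_function_or_method)`, which suggests that under numpy 2.x np.random.uniform reaches Dynamo as a different kind of callable than under numpy 1.x. We assume this comes from how numpy 2 builds its Cython extensions, but that is not confirmed here. The sketch below uses a hypothetical helper, `describe_callable`, that only inspects attributes, so the output can be compared between a numpy 1.26 and a numpy 2.0 install. No particular output is claimed for either version.

```python
import numpy as np


def describe_callable(fn):
    """Report the Python types behind a (possibly bound) callable.

    Hypothetical diagnostic helper: it inspects attributes and never calls fn.
    """
    bound_to = getattr(fn, "__self__", None)
    # Bound `method` objects expose __func__; builtin methods do not.
    underlying = getattr(fn, "__func__", fn)
    return {
        "numpy_version": np.__version__,
        "callable_type": type(fn).__name__,
        "underlying_type": type(underlying).__name__,
        "bound_to": type(bound_to).__name__ if bound_to is not None else None,
    }


info = describe_callable(np.random.uniform)
print(info)
```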
Minified repro
NOTE: There is no minifier output because the failure happens during tracing. Below is the output with TORCH_LOGS enabled for dynamo.
COMMAND:
$ TORCH_LOGS="+dynamo" TORCHDYNAMO_VERBOSE=1 python numpy_repro.py
OUTPUT
V0924 11:45:05.283000 139941291740032 torch/_dynamo/convert_frame.py:776] [0/0] torchdynamo start compiling f1 /usr/local/google/home/kiuk/tmp/numpy_repro.py:6, stack (elided 5 frames):
V0924 11:45:05.283000 139941291740032 torch/_dynamo/convert_frame.py:776] [0/0] File "/usr/local/google/home/kiuk/tmp/numpy_repro.py", line 11, in <module>
V0924 11:45:05.283000 139941291740032 torch/_dynamo/convert_frame.py:776] [0/0] result = optimize(f1)()
V0924 11:45:05.283000 139941291740032 torch/_dynamo/convert_frame.py:776] [0/0] File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
V0924 11:45:05.283000 139941291740032 torch/_dynamo/convert_frame.py:776] [0/0] return fn(*args, **kwargs)
V0924 11:45:05.283000 139941291740032 torch/_dynamo/convert_frame.py:776] [0/0] File "/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
V0924 11:45:05.283000 139941291740032 torch/_dynamo/convert_frame.py:776] [0/0] return self._torchdynamo_orig_callable(
V0924 11:45:05.283000 139941291740032 torch/_dynamo/convert_frame.py:776] [0/0]
I0924 11:45:05.316000 139941291740032 torch/_dynamo/logging.py:56] [0/0] Step 1: torchdynamo start tracing f1 /usr/local/google/home/kiuk/tmp/numpy_repro.py:6
V0924 11:45:05.317000 139941291740032 torch/fx/experimental/symbolic_shapes.py:2529] [0/0] create_env
V0924 11:45:05.324000 139941291740032 torch/_dynamo/symbolic_convert.py:775] [0/0] [__trace_source] TRACE starts_line /usr/local/google/home/kiuk/tmp/numpy_repro.py:7 in f1 (f1)
V0924 11:45:05.324000 139941291740032 torch/_dynamo/symbolic_convert.py:775] [0/0] [__trace_source] a = np.random.uniform(low=-1, high=1, size=(20, 1))
V0924 11:45:05.325000 139941291740032 torch/_dynamo/symbolic_convert.py:798] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL np []
V0924 11:45:05.326000 139941291740032 torch/_dynamo/symbolic_convert.py:798] [0/0] [__trace_bytecode] TRACE LOAD_ATTR random [PythonModuleVariable(<module 'numpy' from '/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/numpy/__init__.py'>)]
V0924 11:45:05.327000 139941291740032 torch/_dynamo/symbolic_convert.py:798] [0/0] [__trace_bytecode] TRACE LOAD_ATTR uniform [PythonModuleVariable(<module 'numpy.random' from '/usr/local/google/home/kiuk/.pyenv/versions/venv310/lib/python3.10/site-packages/numpy/random/__init__.py'>)]
V0924 11:45:05.327000 139941291740032 torch/_dynamo/symbolic_convert.py:798] [0/0] [__trace_bytecode] TRACE LOAD_CONST -1 [UserDefinedObjectVariable()]
V0924 11:45:05.327000 139941291740032 torch/_dynamo/symbolic_convert.py:798] [0/0] [__trace_bytecode] TRACE LOAD_CONST 1 [UserDefinedObjectVariable(), ConstantVariable()]
V0924 11:45:05.327000 139941291740032 torch/_dynamo/symbolic_convert.py:798] [0/0] [__trace_bytecode] TRACE LOAD_CONST (20, 1) [UserDefinedObjectVariable(), ConstantVariable(), ConstantVariable()]
V0924 11:45:05.327000 139941291740032 torch/_dynamo/symbolic_convert.py:798] [0/0] [__trace_bytecode] TRACE LOAD_CONST ('low', 'high', 'size') [UserDefinedObjectVariable(), ConstantVariable(), ConstantVariable(), TupleVariable()]
V0924 11:45:05.328000 139941291740032 torch/_dynamo/symbolic_convert.py:798] [0/0] [__trace_bytecode] TRACE CALL_FUNCTION_KW 3 [UserDefinedObjectVariable(), ConstantVariable(), ConstantVariable(), TupleVariable(), TupleVariable()]
V0924 11:45:05.328000 139941291740032 torch/_dynamo/symbolic_convert.py:814] [0/0] empty checkpoint
I0924 11:45:05.336000 139941291740032 torch/_dynamo/utils.py:335] TorchDynamo compilation metrics:
I0924 11:45:05.336000 139941291740032 torch/_dynamo/utils.py:335] Function, Runtimes (s)
I0924 11:45:05.336000 139941291740032 torch/_dynamo/utils.py:335] _compile.<locals>.compile_inner, 0.0000
V0924 11:45:05.336000 139941291740032 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats constrain_symbol_range: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0924 11:45:05.336000 139941291740032 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats evaluate_expr: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0924 11:45:05.336000 139941291740032 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats _simplify_floor_div: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0924 11:45:05.337000 139941291740032 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats _maybe_guard_rel: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0924 11:45:05.337000 139941291740032 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats _find: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0924 11:45:05.337000 139941291740032 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats has_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0924 11:45:05.337000 139941291740032 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats size_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0924 11:45:05.337000 139941291740032 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats simplify: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0924 11:45:05.337000 139941291740032 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats _update_divisible: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0924 11:45:05.337000 139941291740032 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats replace: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0924 11:45:05.337000 139941291740032 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats _maybe_evaluate_static: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0924 11:45:05.337000 139941291740032 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats get_implications: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0924 11:45:05.337000 139941291740032 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats get_axioms: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0924 11:45:05.337000 139941291740032 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats safe_expand: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0924 11:45:05.337000 139941291740032 torch/fx/experimental/symbolic_shapes.py:116] lru_cache_stats uninteresting_files: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
Versions
Collecting environment information...
PyTorch version: 2.4.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Debian GNU/Linux rodete (x86_64)
GCC version: (Debian 13.2.0-13) 13.2.0
Clang version: 16.0.6 (26)
CMake version: version 3.29.2
Libc version: glibc-2.38
Python version: 3.10.12 (main, Apr 16 2024, 16:17:06) [GCC 13.2.0] (64-bit runtime)
Python platform: Linux-6.9.10-1rodete4-amd64-x86_64-with-glibc2.38
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX A5000
Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 24
On-line CPU(s) list: 0-23
Vendor ID: AuthenticAMD
Model name: AMD Ryzen Threadripper PRO 3945WX 12-Cores
CPU family: 23
Model: 49
Thread(s) per core: 2
Core(s) per socket: 12
Socket(s): 1
Stepping: 0
Frequency boost: enabled
CPU(s) scaling MHz: 86%
CPU max MHz: 4425.7808
CPU min MHz: 2200.0000
BogoMIPS: 7985.02
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sev sev_es
Virtualization: AMD-V
L1d cache: 384 KiB (12 instances)
L1i cache: 384 KiB (12 instances)
L2 cache: 6 MiB (12 instances)
L3 cache: 64 MiB (4 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-23
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.15.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.10.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.0.0
[pip3] optree==0.12.1
[pip3] pytorch-triton==3.0.0+989adb9a29
[pip3] torch==2.4.1
[pip3] torch-xla==2.4.0
[pip3] triton==3.0.0
[conda] magma-cuda110 2.5.2 1 pytorch
[conda] mkl-include 2024.1.0 intel_691 intel
[conda] mkl-static 2024.1.0 intel_691 intel
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @mruberry @rgommers @chauhang @penguinwu, @atalman, @qlzh727, @haifeng-jin