I am trying to run this repo on my Jetson Orin AGX 64GB developer kit: GitHub - vincekurtz/hydrax: Sampling-based model predictive control on GPU with JAX/MJX. It runs MuJoCo MPC on the GPU using JAX, but when I try to run it I get a load of errors, listed below. The failed warp import is not the issue: I get the same error after installing warp, and the repo works on other x86_64 machines that show the same failed import.

After a lot of debugging I ended up trying to build the JAX wheels locally, but when building jax==0.5.0 in the conda environment I created for this repo I get data type errors where a uint8 is expected but an svuint8 is received. I don’t know how to proceed from here to get JAX working with CUDA support.

I am running JetPack 6.2.1 on Ubuntu 22.04, since I need ROS 2 Humble. The full error output and some diagnostics are included below.
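In case it helps with reproducing this, the call that dies is mjx.put_model, so something like the following minimal snippet (a sketch, assuming only that mujoco with its MJX bindings and JAX are installed in the same environment) should hit the same code path without needing hydrax at all:

import jax
import mujoco
from mujoco import mjx

# Show which backend JAX selected (should be "gpu" if CUDA is picked up).
print("jax", jax.__version__, "backend:", jax.default_backend(), jax.devices())

# Put a trivial MuJoCo model on the device via MJX. This goes through the
# same astype()/convert_element_type path that fails with
# "DNN library initialization failed" in the traceback below.
mj_model = mujoco.MjModel.from_xml_string("<mujoco><worldbody/></mujoco>")
mjx_model = mjx.put_model(mj_model)
print(type(mjx_model))

Below is the full output from running examples/cart_pole.py from the repo: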
Failed to import warp: No module named 'warp'
Failed to import mujoco_warp: No module named 'warp'
E1120 11:39:58.116684 14373 cuda_executor.cc:1784] Nvml call failed with 3(Not Supported). Assuming PCIe gen 3 x16 bandwidth.
W1120 11:39:58.116796 14373 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML library doesn't have required functions.
E1120 11:39:58.121420 14323 cuda_executor.cc:1784] Nvml call failed with 3(Not Supported). Assuming PCIe gen 3 x16 bandwidth.
W1120 11:39:58.121453 14323 cuda_executor.cc:1802] GPU interconnect information not available: INTERNAL: NVML library doesn't have required functions.
E1120 11:39:58.131061 14323 cuda_dnn.cc:456] Loaded runtime CuDNN library: 9.3.0 but source was compiled with: 9.8.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
E1120 11:39:58.132785 14323 cuda_dnn.cc:456] Loaded runtime CuDNN library: 9.3.0 but source was compiled with: 9.8.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
E1120 11:40:13.433815 14323 cuda_dnn.cc:456] Loaded runtime CuDNN library: 9.3.0 but source was compiled with: 9.8.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
E1120 11:40:13.449384 14323 cuda_dnn.cc:456] Loaded runtime CuDNN library: 9.3.0 but source was compiled with: 9.8.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
Traceback (most recent call last):
File "/home/jetson/hydrax/examples/cart_pole.py", line 14, in <module>
task = CartPole()
^^^^^^^^^^
File "/home/jetson/hydrax/hydrax/tasks/cart_pole.py", line 18, in __init__
super().__init__(mj_model, trace_sites=["tip"])
File "/home/jetson/hydrax/hydrax/task_base.py", line 38, in __init__
self.model = mjx.put_model(mj_model)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jetson/miniconda3/envs/hydrax/lib/python3.12/site-packages/mujoco/mjx/_src/io.py", line 513, in put_model
return _put_model_jax(m, device)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jetson/miniconda3/envs/hydrax/lib/python3.12/site-packages/mujoco/mjx/_src/io.py", line 421, in _put_model_jax
return _strip_weak_type(model)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jetson/miniconda3/envs/hydrax/lib/python3.12/site-packages/mujoco/mjx/_src/io.py", line 166, in _strip_weak_type
return jax.tree_util.tree_map(f, tree)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jetson/miniconda3/envs/hydrax/lib/python3.12/site-packages/jax/_src/tree_util.py", line 361, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jetson/miniconda3/envs/hydrax/lib/python3.12/site-packages/jax/_src/tree_util.py", line 361, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
^^^^^^
File "/home/jetson/miniconda3/envs/hydrax/lib/python3.12/site-packages/mujoco/mjx/_src/io.py", line 163, in f
return leaf.astype(jax.dtypes.canonicalize_dtype(leaf.dtype))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jetson/miniconda3/envs/hydrax/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 125, in _astype
return lax_numpy.astype(self, dtype, copy=copy, device=device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jetson/miniconda3/envs/hydrax/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 5350, in astype
result = lax._convert_element_type(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jetson/miniconda3/envs/hydrax/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 1727, in _convert_element_type
return convert_element_type_p.bind(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jetson/miniconda3/envs/hydrax/lib/python3.12/site-packages/jax/_src/core.py", line 632, in bind
return self._true_bind(*args, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jetson/miniconda3/envs/hydrax/lib/python3.12/site-packages/jax/_src/core.py", line 648, in _true_bind
return self.bind_with_trace(prev_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jetson/miniconda3/envs/hydrax/lib/python3.12/site-packages/jax/_src/lax/lax.py", line 5035, in _convert_element_type_bind_with_trace
operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jetson/miniconda3/envs/hydrax/lib/python3.12/site-packages/jax/_src/core.py", line 660, in bind_with_trace
return trace.process_primitive(self, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jetson/miniconda3/envs/hydrax/lib/python3.12/site-packages/jax/_src/core.py", line 1189, in process_primitive
return primitive.impl(*args, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jetson/miniconda3/envs/hydrax/lib/python3.12/site-packages/jax/_src/dispatch.py", line 92, in apply_primitive
outs = fun(*args)
^^^^^^^^^^
jax.errors.JaxRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
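Since the log complains that the runtime cuDNN is 9.3.0 while jaxlib was compiled against 9.8.0, one thing that can be checked from inside the environment is which libcudnn the loader actually resolves. A rough sketch (the soname "libcudnn.so.9" is an assumption about how cuDNN 9.x is exposed here, and the nvidia.cudnn import assumes the pip wheel exposes that module):

import ctypes

# NOTE: the soname "libcudnn.so.9" is an assumption; adjust if it can't be found.
libcudnn = ctypes.CDLL("libcudnn.so.9")
libcudnn.cudnnGetVersion.restype = ctypes.c_size_t
print("cuDNN picked up by the loader:", libcudnn.cudnnGetVersion())  # e.g. 90300 == 9.3.0

# The pip wheel (nvidia-cudnn-cu12) ships its own copy under site-packages;
# this just shows where that copy lives.
import nvidia.cudnn
print("pip cuDNN package:", nvidia.cudnn.__file__)

For reference, here is what is installed in the environment and on the system: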
(hydrax) jetson@ubuntu:~/hydrax$ pip list | grep nvidia
nvidia-cublas-cu12 12.9.1.4
nvidia-cuda-cupti-cu12 12.9.79
nvidia-cuda-nvcc-cu12 12.9.86
nvidia-cuda-nvrtc-cu12 12.9.86
nvidia-cuda-runtime-cu12 12.9.79
nvidia-cudnn-cu12 9.16.0.29
nvidia-cufft-cu12 11.4.1.4
nvidia-cusolver-cu12 11.7.5.82
nvidia-cusparse-cu12 12.5.10.65
nvidia-nccl-cu12 2.28.9
nvidia-nvjitlink-cu12 12.9.86
nvidia-nvshmem-cu12 3.4.5
(hydrax) jetson@ubuntu:~/hydrax$ conda list | grep cudnn
conda list | grep cudatoolkit
cudnn 9.10.2.21 h32c1c63_0 conda-forge
libcudnn 9.10.2.21 hd88968f_0 conda-forge
libcudnn-dev 9.10.2.21 h76cf850_0 conda-forge
nvidia-cudnn-cu12 9.16.0.29 pypi_0 pypi
(hydrax) jetson@ubuntu:~/hydrax$ cat /proc/driver/nvidia/version
NVRM version: NVIDIA UNIX Open Kernel Module for aarch64 540.4.0 Release Build (buildbrain@mobile-u64-6354-d6000) Thu Sep 18 15:33:49 PDT 2025
GCC version: collect2: error: ld returned 1 exit status
(hydrax) jetson@ubuntu:~/hydrax$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Fri_Sep__8_19:17:54_PDT_2023
Cuda compilation tools, release 12.3, V12.3.52
Build cuda_12.3.r12.3/compiler.33281558_0
(hydrax) jetson@ubuntu:~/hydrax$ nvidia-smi
Thu Nov 20 11:55:52 2025
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 540.4.0 Driver Version: 540.4.0 CUDA Version: 12.6 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 Orin (nvgpu) N/A | N/A N/A | N/A |
| N/A N/A N/A N/A / N/A | Not Supported | N/A N/A |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+
(hydrax) jetson@ubuntu:~/hydrax$ cat /usr/include/cudnn_version.h | grep CUDNN_MAJOR -A 2
#define CUDNN_MAJOR 9
#define CUDNN_MINOR 3
#define CUDNN_PATCHLEVEL 0
--
#define CUDNN_VERSION (CUDNN_MAJOR * 10000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
/* cannot use constexpr here since this is a C-only file */
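For completeness, plugging those numbers into the header's own version macro gives the build that the JAX log says was loaded at runtime (a quick check, just restating the macro in Python):

# CUDNN_VERSION macro from cudnn_version.h, evaluated for the system headers above
CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL = 9, 3, 0
print(CUDNN_MAJOR * 10000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)  # 90300 -> cuDNN 9.3.0

So it looks like JAX is picking up the system cuDNN 9.3.0 rather than the newer copies from pip (9.16.0.29) or conda (9.10.2.21), and I am not sure how to fix that on this platform.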