Description
Inside jax.jit, traced JAX arrays lack the .device attribute, and calling the .devices() method raises a concretization error:
>>> import jax, jax.numpy as jnp
>>> x = jnp.asarray(0)
>>> def f(x): return jnp.zeros(0, device=x.device)
>>> f(x)
Array([], shape=(0,), dtype=float32)
>>> jax.jit(f)(x)
AttributeError: DynamicJaxprTracer has no attribute device
>>> def g(x): return jnp.zeros(0, device=next(iter(x.devices())))
>>> jax.jit(g)(x)
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]
The devices() method was called on traced array with shape int32[].
The error occurred while tracing the function g at <ipython-input-6-5e654833a7fe>:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
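This breaks Array API compatibility and hinders Array API-compliant libraries that use the following pattern:
from array_api_compat import array_namespace

def f(x):
    xp = array_namespace(x)
    # Create the new array on the same device as the input
    return xp.asarray(123, device=x.device)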
You can see two such use cases in array-api-extra: https://github.com/search?q=repo%3Adata-apis%2Farray-api-extra+device%3D_compat.device&type=code
Workaround
As a provisional workaround, I'm changing array-api-compat so that device(x) returns None and to_device(x, device) accepts None. However, this produces outputs on the wrong device whenever x is not on the default device.
System info (python version, jaxlib version, accelerator, etc.)
JAX 0.4.35