Arrays typing annotations
API
Function inputs & outputs can be annotated to help the reader better understand the intended shape/dtype.

```python
import numpy as np
from etils.array_types import f32, ui8

def _normalize_image(img: ui8['h w c']) -> f32['h w c']:
    # np.interp(x, xp, fp) linearly maps [0, 255] onto [-1, 1] (and returns float64).
    return np.interp(img, (0, 255), (-1, 1)).astype(np.float32)
```

This tells the reader that the function takes a 3D `uint8` array and returns a 3D `float32` array of the same shape.

Note: These typing annotations are not (yet) understood by static type checkers. However, they are already useful as documentation.
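For reference, the interpolation in `_normalize_image` is the linear map from [0, 255] onto [-1, 1], i.e. `out = -1 + v * (1 - (-1)) / 255 = v / 127.5 - 1`. The pure-Python sketch below spells out that arithmetic for a single pixel value; `normalize_pixel` is a hypothetical helper for illustration, not part of etils.

```python
def normalize_pixel(value: int) -> float:
    """Maps one uint8 pixel value from [0, 255] linearly onto [-1.0, 1.0].

    Hypothetical helper (not an etils API). For in-range inputs this matches
    np.interp(value, (0, 255), (-1, 1)) = -1 + value * 2 / 255 = value / 127.5 - 1.
    """
    if not 0 <= value <= 255:
        # np.interp would clamp instead; a real uint8 input never reaches this branch.
        raise ValueError(f"expected a value in [0, 255], got {value}")
    return value / 127.5 - 1.0


print(normalize_pixel(0), normalize_pixel(255))  # -1.0 1.0
```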
Annotation conventions

Typing annotation shapes follow these conventions:

- Valid symbols:
  - `str`: Named axis (e.g. `f32['batch height width']`)
  - `int`: Static axis (e.g. `f32[28, 28]`, `f32['h w 3']`)
  - `_`: Anonymous axis (e.g. `f32['batch _ _ c']`, `f32[None, 3]`)
  - `...`: Anonymous zero or more axes (e.g. `f32['... h w c']`, `f32[..., 3]`)
  - `*name`: Named zero or more axes (e.g. `f32['*batch_dims h w c']`)
  - `+`, `-`, `/`, `*` operators (e.g. `f32['h/2 w/2 c1+c2']`)
- Axis names only need to be consistent within a single function call, so a function `f32['h w'] -> f32['h w']` can be called twice with 2 different image sizes.
- Passing multiple values is the same as concatenating the string (e.g. `f32[..., 'h', 'w', 3] == f32['... h w 3']`).
- DType can be:
  - `Array[...]`: Any dtype accepted
  - `FloatArray` (accepts `f32`, `bf16`, …), `IntArray` (accepts `ui8`, `i32`, `i64`, …): Each accepts a union of multiple dtypes
  - `f32`, `ui8`, …: Specific dtype
- `ArrayLike[f32[...]]` indicates that any value convertible to an array (`list`, `tuple`, …) is accepted.
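To make the per-call consistency rule concrete, here is a minimal pure-Python sketch of how a shape spec could be matched against a concrete shape. `match_shape` and its `bound` dict are hypothetical illustrations, not the etils implementation. The sketch handles named, static, `_`, `...` and `*name` axes and assumes at most one variadic axis; arithmetic operators (`h/2`) and the multi-value form (`f32[..., 3]`) are deliberately left out.

```python
from typing import Dict, List, Sequence, Tuple


def _bind(name: str, value: object, bound: Dict[str, object]) -> None:
    """Records name -> value, failing if name was already bound to something else."""
    if name in bound and bound[name] != value:
        raise ValueError(f"{name!r} is {bound[name]} but got {value}")
    bound[name] = value


def match_shape(spec: str, shape: Sequence[int], bound: Dict[str, object]) -> None:
    """Checks `shape` against a spec such as '*batch h w 3' (hypothetical helper).

    Named dims are stored in `bound`; sharing one dict across all arrays of a
    single function call enforces consistency within that call only.
    """
    tokens = spec.split()
    shape = tuple(shape)
    variadic = [i for i, t in enumerate(tokens) if t == "..." or t.startswith("*")]
    if len(variadic) > 1:
        raise ValueError(f"at most one variadic axis is supported: {spec!r}")
    pairs: List[Tuple[str, int]]
    if variadic:
        i = variadic[0]
        n_var = len(shape) - (len(tokens) - 1)  # dims absorbed by '...' or '*name'
        if n_var < 0:
            raise ValueError(f"{shape} has too few dims for {spec!r}")
        if tokens[i].startswith("*"):
            _bind(tokens[i], shape[i:i + n_var], bound)  # named variadic dims
        pairs = list(zip(tokens[:i], shape[:i])) + list(
            zip(tokens[i + 1:], shape[i + n_var:])
        )
    else:
        if len(shape) != len(tokens):
            raise ValueError(f"{shape} does not match {spec!r}")
        pairs = list(zip(tokens, shape))
    for token, dim in pairs:
        if token == "_":
            continue  # anonymous axis: any size is accepted
        if token.isdigit():
            if int(token) != dim:
                raise ValueError(f"expected static dim {token}, got {dim}")
        else:
            _bind(token, dim, bound)  # named axis


# One call of a function f32['h w'] -> f32['h w']: input and output share `bound`.
bound: Dict[str, object] = {}
match_shape("h w", (28, 28), bound)
match_shape("h w", (28, 28), bound)
# A later call starts from a fresh dict, so a different image size is fine.
match_shape("h w", (32, 32), {})
```

A fresh `bound` dict per call is what lets `f32['h w'] -> f32['h w']` accept a 28x28 image in one call and a 32x32 image in the next, while still rejecting an output whose size differs from its own input.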
Runtime shape/dtype checking

You can decorate your function with `@enp.check_and_normalize_arrays` so that array shapes/dtypes are validated at runtime:

```python
from etils import enp
from etils.array_types import IntArray

@enp.check_and_normalize_arrays
def add(x: IntArray, y: IntArray) -> IntArray:
    return x + y
```
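Conceptually, a checking decorator reads the annotations from the function signature and validates each argument before calling the function. The sketch below shows that mechanism in pure Python under explicit assumptions: `ToyArray`, `DTypeSpec` and `check_arrays` are hypothetical stand-ins (real ndarrays, the etils annotation types and `enp.check_and_normalize_arrays` work differently), and only dtypes are checked.

```python
import functools
import inspect
from dataclasses import dataclass
from typing import Any, Callable, FrozenSet, Tuple


@dataclass(frozen=True)
class ToyArray:
    """Hypothetical stand-in for an ndarray: a dtype name plus flat values."""
    dtype: str
    values: Tuple[float, ...]

    def __add__(self, other: "ToyArray") -> "ToyArray":
        return ToyArray(self.dtype, tuple(a + b for a, b in zip(self.values, other.values)))


@dataclass(frozen=True)
class DTypeSpec:
    """Hypothetical annotation object listing the dtype names it accepts."""
    name: str
    accepted: FrozenSet[str]

    def check(self, arg_name: str, value: Any) -> None:
        if not isinstance(value, ToyArray) or value.dtype not in self.accepted:
            got = getattr(value, "dtype", type(value).__name__)
            raise TypeError(f"{arg_name}: expected {self.name}, got {got}")


IntArray = DTypeSpec("IntArray", frozenset({"uint8", "int32", "int64"}))


def check_arrays(fn: Callable) -> Callable:
    """Validates annotated arguments and the return value on every call."""
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        for name, value in sig.bind(*args, **kwargs).arguments.items():
            annotation = sig.parameters[name].annotation
            if isinstance(annotation, DTypeSpec):
                annotation.check(name, value)
        out = fn(*args, **kwargs)
        if isinstance(sig.return_annotation, DTypeSpec):
            sig.return_annotation.check("return value", out)
        return out

    return wrapper


@check_arrays
def add(x: IntArray, y: IntArray) -> IntArray:
    return x + y


print(add(ToyArray("int32", (1, 2)), ToyArray("int32", (3, 4))))
# ToyArray(dtype='int32', values=(4, 6))
```

Passing a `ToyArray("float32", ...)` to this `add` raises a `TypeError` before the body runs, which is the behaviour the decorator above is meant to provide for real arrays.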
TF / Jax / Numpy compatibility

Functions decorated with `enp.check_and_normalize_arrays` support `np`, `jnp`, and `tnp` (`tf.experimental.numpy`):

- If args are mixed between `jnp` and `tnp`, an error is raised.
- If args mix `xnp` with `np`, the `np` arrays are auto-cast to `xnp`.
- You can force usage of TF / Jax / Numpy by passing an `xnp=` kwarg (automatically added).

```python
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

add(np.array(1), jnp.array(2))  # np auto-cast to jnp
add(tf.constant(1), jnp.array(2))  # Error: jnp / TF conflict
add(tf.constant(1), jnp.array(2), xnp=jnp)  # Force jnp usage
```
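The three rules above amount to a small inference function over the backends of the arguments. Below is a hypothetical pure-Python sketch that works on backend names (`'np'`, `'jnp'`, `'tnp'`) instead of real arrays; `infer_backend` is not an etils function, and the actual casting of arrays is omitted.

```python
from typing import Iterable, Optional


def infer_backend(arg_backends: Iterable[str], xnp: Optional[str] = None) -> str:
    """Picks the backend a decorated function should run with (hypothetical sketch).

    arg_backends: backend name of each array argument: 'np', 'jnp' or 'tnp'.
    xnp: explicitly requested backend, mirroring the `xnp=` kwarg.
    """
    if xnp is not None:
        return xnp  # An explicit xnp= always wins; all args would be cast to it.
    non_np = set(arg_backends) - {"np"}
    if len(non_np) > 1:
        # e.g. jnp mixed with tnp: no automatic choice is made.
        raise ValueError(f"conflicting array backends: {sorted(non_np)}")
    if non_np:
        return non_np.pop()  # np arrays get auto-cast to the other backend.
    return "np"  # only numpy arrays (or no arrays at all)


print(infer_backend(["np", "jnp"]))  # jnp
print(infer_backend(["tnp", "jnp"], xnp="jnp"))  # jnp
```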
With `strict=False`, the function automatically converts `list`, `int`, … inputs to `xnp.ndarray`:

```python
@enp.check_and_normalize_arrays(strict=False)
def add(x: IntArray, y: IntArray):
    return x + y

add([1, 2, 3], 10)  # == np.array([11, 12, 13])
add([1, 2, 3], 10, xnp=jnp)  # == jnp.array([11, 12, 13])
add([1, 2, 3], tf.constant(10))  # == tnp.array([11, 12, 13])
```
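The `strict` flag boils down to whether non-array inputs are rejected or converted before the computation. Here is a hypothetical pure-Python sketch of that choice: `ToyArray` (a `list` subclass) stands in for `xnp.ndarray`, `to_array` and `toy_add` are illustrative names only, and broadcasting is reduced to stretching length-1 inputs.

```python
from typing import Any


class ToyArray(list):
    """Hypothetical stand-in for xnp.ndarray (a 1-D list of numbers)."""


def to_array(value: Any, *, strict: bool) -> ToyArray:
    """Returns value as a ToyArray, or rejects non-arrays when strict."""
    if isinstance(value, ToyArray):
        return value
    if strict:
        raise TypeError(f"expected an array, got {type(value).__name__}")
    if isinstance(value, (int, float)):
        return ToyArray([value])  # scalar -> length-1 array (toy simplification)
    if isinstance(value, (list, tuple)):
        return ToyArray(value)
    raise TypeError(f"cannot convert {type(value).__name__} to an array")


def toy_add(x: Any, y: Any, *, strict: bool = True) -> ToyArray:
    x, y = to_array(x, strict=strict), to_array(y, strict=strict)
    # Minimal broadcasting: stretch a length-1 operand to the other's length.
    if len(x) == 1:
        x = ToyArray(x * len(y))
    if len(y) == 1:
        y = ToyArray(y * len(x))
    if len(x) != len(y):
        raise ValueError(f"cannot broadcast lengths {len(x)} and {len(y)}")
    return ToyArray(a + b for a, b in zip(x, y))


print(toy_add([1, 2, 3], 10, strict=False))  # [11, 12, 13]
```

With the default `strict=True`, `toy_add([1, 2, 3], 10)` raises a `TypeError` because neither input is already an array.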
You can add an `xnp: enp.NpModule = ...` kwarg to your function; it is automatically set to the inferred `xnp` module:

```python
@enp.check_and_normalize_arrays(strict=False)
def add(x: IntArray, y: IntArray, *, xnp: enp.NpModule = ...):
    return xnp.add(x, y)

add(1, [1, 2, 3])  # Inside the function, `xnp=np`
add(tf.constant(1), tf.constant(2))  # Inside the function, `xnp=tnp`
```
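The injection mechanism itself is simple: inspect the signature once, and on each call fill in `xnp` when the function declares it with an `...` default and the caller did not pass it. The sketch below is hypothetical and uses `math` / `cmath` as stand-ins for the `np` / `jnp` / `tnp` modules; `auto_xnp` and `_infer_module` are illustrative names, not etils APIs.

```python
import cmath
import functools
import inspect
import math
from types import ModuleType
from typing import Any, Callable, Sequence


def _infer_module(args: Sequence[Any]) -> ModuleType:
    # Toy stand-in for backend inference: complex inputs -> cmath, else math.
    return cmath if any(isinstance(a, complex) for a in args) else math


def auto_xnp(fn: Callable) -> Callable:
    """Fills an `xnp=...` keyword-only parameter with the inferred module."""
    param = inspect.signature(fn).parameters.get("xnp")
    wants_xnp = param is not None and param.default is Ellipsis

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Only fill xnp when the function asks for it and the caller didn't pass one.
        if wants_xnp and kwargs.get("xnp", ...) is ...:
            kwargs["xnp"] = _infer_module(args)
        return fn(*args, **kwargs)

    return wrapper


@auto_xnp
def root(x, *, xnp: ModuleType = ...):
    return xnp.sqrt(x)


print(root(4.0))  # 2.0 (inside the function, xnp=math)
print(root(-4 + 0j))  # 2j (inside the function, xnp=cmath)
```

An explicit `root(4.0, xnp=cmath)` bypasses inference, in the same spirit as forcing a backend with `xnp=` above.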
DType checking

There are 2 levels of checking:

- Using a type union: `IntArray` (accepts `ui8`, `i32`, `i64`, …), `FloatArray` (accepts `f32`, `bf16`, …)
- Using a specific type: `f32`, `ui8`, …

Type unions let a single function accept several dtypes, e.g. to support quantization, …
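The difference between the two levels can be pictured as set membership versus equality. In this hypothetical sketch a union is a `frozenset` of dtype names and a specific type is a single name; the sets only contain the dtypes listed above and are not the exact etils definitions.

```python
from typing import FrozenSet, Union

# Hypothetical dtype tables; only the dtypes named in the text are listed.
INT_ARRAY: FrozenSet[str] = frozenset({"uint8", "int32", "int64"})
FLOAT_ARRAY: FrozenSet[str] = frozenset({"float32", "bfloat16"})

DTypeAnnotation = Union[FrozenSet[str], str]


def accepts(annotation: DTypeAnnotation, dtype: str) -> bool:
    """Union annotations accept any member; specific annotations need an exact match."""
    if isinstance(annotation, frozenset):
        return dtype in annotation
    return dtype == annotation


# A bfloat16 (reduced-precision) input passes the union but not a specific f32.
print(accepts(FLOAT_ARRAY, "bfloat16"), accepts("float32", "bfloat16"))  # True False
```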
Shape checking

Shape checking is not supported yet (it is planned).