cr.sparse.promote_arg_dtypes
- cr.sparse.promote_arg_dtypes(*args)
Promotes the arguments to a common inexact (floating-point or complex) dtype.
- Parameters
*args – JAX arrays to be promoted to a common inexact type
- Returns
The input arrays, in the same order, each cast to the common inexact dtype. A single argument is returned as a bare array, as the first example shows.
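One plausible way to get this behavior is sketched below as a hypothetical helper, `promote_inexact` (not the cr.sparse source). It assumes the common dtype comes from JAX's ordinary promotion rules via `jnp.result_type`, with integer or boolean results replaced by JAX's default floating-point dtype. Returning a bare array for a single argument is an assumption chosen to match the first example.

```python
import jax
import jax.numpy as jnp


def promote_inexact(*args):
    """Hypothetical re-implementation for illustration; not the cr.sparse source."""
    arrays = [jnp.asarray(a) for a in args]
    # Common dtype under JAX's standard type-promotion rules.
    dtype = jnp.result_type(*arrays)
    if not jnp.issubdtype(dtype, jnp.inexact):
        # Integer/bool inputs: use the default float dtype
        # (float32, or float64 when jax_enable_x64 is on).
        dtype = jax.dtypes.canonicalize_dtype(jnp.float64)
    promoted = [a.astype(dtype) for a in arrays]
    # Assumption: a single argument comes back as a bare array.
    return promoted[0] if len(promoted) == 1 else promoted


x = promote_inexact(jnp.arange(2))
a, b = promote_inexact(jnp.arange(2), jnp.arange(1.5, 3.5))
c, d = promote_inexact(jnp.arange(2) + 0.j, jnp.arange(1.5, 3.5))
print(x.dtype, a.dtype, b.dtype, c.dtype, d.dtype)
```

Delegating to `jnp.result_type` keeps the helper consistent with how JAX itself combines dtypes in arithmetic; only the integer-to-float fallback is extra.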
Examples
Promoting a single argument:
>>> import jax
>>> import jax.numpy as jnp
>>> import cr.sparse
>>> cr.sparse.promote_arg_dtypes(jnp.arange(2))
DeviceArray([0., 1.], dtype=float32)
>>> jax.config.update("jax_enable_x64", True)
>>> cr.sparse.promote_arg_dtypes(jnp.arange(2))
DeviceArray([0., 1.], dtype=float64)
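The float32/float64 switch above reflects JAX's default floating-point dtype, which depends on the `jax_enable_x64` flag. The sketch below, which assumes that integer inputs are cast to this default (consistent with the example output), shows how to query it with `jax.dtypes.canonicalize_dtype`.

```python
import jax
import jax.numpy as jnp

# JAX's default float dtype: float32 normally, float64 when jax_enable_x64 is on.
default_float = jax.dtypes.canonicalize_dtype(jnp.float64)

x = jnp.arange(2)                             # integer array
print(jnp.issubdtype(x.dtype, jnp.inexact))   # integers are not inexact
y = x.astype(default_float)                   # assumed result of single-argument promotion
print(default_float, y.dtype)
```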
Promoting two arguments to a common floating-point type:
>>> a = jnp.arange(2)
>>> b = jnp.arange(1.5, 3.5)
>>> a, b = cr.sparse.promote_arg_dtypes(a, b)
>>> print(a)
[0. 1.]
>>> print(b)
[1.5 2.5]
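Assuming the common dtype follows JAX's standard promotion rules, it can be inspected directly with `jnp.result_type`. This minimal sketch reproduces the two-argument case by hand: the integer array is lifted to the float array's dtype, and the float array is left unchanged.

```python
import jax.numpy as jnp

a = jnp.arange(2)               # integer array: [0, 1]
b = jnp.arange(1.5, 3.5)        # float array: [1.5, 2.5]
common = jnp.result_type(a, b)  # JAX promotion: integer + float -> float
a2, b2 = a.astype(common), b.astype(common)
print(common, a2, b2)
```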
Promoting a mix of real and complex arrays to a common complex type:
>>> a = jnp.arange(2) + 0.j
>>> b = jnp.arange(1.5, 3.5)
>>> a, b = cr.sparse.promote_arg_dtypes(a, b)
>>> print(a)
[0.+0.j 1.+0.j]
>>> print(b)
[1.5+0.j 2.5+0.j]
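Complex dtypes also count as inexact, so when any argument is complex the common type is complex and real inputs simply gain a zero imaginary part. The sketch below checks this by hand, again assuming the promotion matches `jnp.result_type`.

```python
import jax.numpy as jnp

a = jnp.arange(2) + 0.j          # complex array
b = jnp.arange(1.5, 3.5)         # real float array
common = jnp.result_type(a, b)   # complex wins over real
print(jnp.issubdtype(common, jnp.inexact))  # complex is an inexact dtype
b2 = b.astype(common)            # real values keep their value, imaginary part is 0
print(common, b2.real, b2.imag)
```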