cr.sparse.promote_arg_dtypes

cr.sparse.promote_arg_dtypes(*args)

Promotes the arguments to a common inexact (floating-point or complex) type.

Parameters

*args – JAX arrays to be promoted to a common inexact type

Returns

The promoted array if a single argument is passed (as in the first example below); otherwise a list of the arrays, each converted to the common inexact type
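The promotion rule can be sketched with public JAX APIs only. This is a minimal illustration, not the cr.sparse implementation: the name `promote_to_inexact` is hypothetical, and the sketch assumes that integer and boolean inputs are lifted to JAX's default floating-point dtype, as the examples below show.

```python
import jax
import jax.numpy as jnp


def promote_to_inexact(*args):
    """Hypothetical sketch: promote arrays to a common inexact dtype.

    Not the cr.sparse implementation; uses only public JAX APIs.
    """
    # Combine the argument dtypes using JAX's standard type-promotion rules.
    dtype = jnp.result_type(*args)
    # Integer and boolean results are lifted to the default floating-point
    # dtype (float32, or float64 when jax_enable_x64 is on).
    if not jnp.issubdtype(dtype, jnp.inexact):
        dtype = jax.dtypes.canonicalize_dtype(jnp.float64)
    promoted = [jnp.asarray(arg, dtype=dtype) for arg in args]
    # Mirror the documented behaviour: a single argument comes back as an
    # array rather than as a one-element list.
    return promoted[0] if len(promoted) == 1 else promoted
```

Using `jnp.result_type` keeps the sketch consistent with JAX's own promotion lattice, so a complex argument pulls every other argument up to a complex dtype.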

Example

Promoting a single argument:

>>> import jax.numpy as jnp
>>> import cr.sparse
>>> cr.sparse.promote_arg_dtypes(jnp.arange(2))
DeviceArray([0., 1.], dtype=float32)
>>> import jax
>>> jax.config.update("jax_enable_x64", True)
>>> cr.sparse.promote_arg_dtypes(jnp.arange(2))
DeviceArray([0., 1.], dtype=float64)
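The output dtype changes between the two calls above because JAX's default floating-point dtype depends on the `jax_enable_x64` flag. A minimal sketch for inspecting that default; `default_float_dtype` is a hypothetical helper, not part of cr.sparse.

```python
import jax
import jax.numpy as jnp


def default_float_dtype():
    """Hypothetical helper: the floating-point dtype JAX uses by default."""
    # canonicalize_dtype maps float64 down to float32 unless
    # jax_enable_x64 has been enabled.
    return jax.dtypes.canonicalize_dtype(jnp.float64)


print(default_float_dtype())
```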

Promoting two arguments to a common floating-point type:

>>> a = jnp.arange(2)
>>> b = jnp.arange(1.5, 3.5)
>>> a, b = cr.sparse.promote_arg_dtypes(a, b)
>>> print(a)
[0. 1.]
>>> print(b)
[1.5 2.5]

Promoting a mix of real and complex arguments to a common complex type:

>>> a = jnp.arange(2) + 0.j
>>> b = jnp.arange(1.5, 3.5)
>>> a, b = cr.sparse.promote_arg_dtypes(a, b)
>>> print(a)
[0.+0.j 1.+0.j]
>>> print(b)
[1.5+0.j 2.5+0.j]