jax.numpy.fromfunction
jax.numpy.fromfunction(function, shape, *, dtype=<class 'float'>, **kwargs)
Create an array from a function applied over indices.
JAX implementation of numpy.fromfunction(). The JAX implementation differs in that it dispatches via jax.vmap(). As a result, unlike in NumPy, the function logically operates on scalar inputs and need not explicitly handle broadcasted inputs (see Examples below).
- Parameters:
function (Callable[..., Array]) – a function that takes N dynamic scalars and outputs a scalar.
shape (Any) – a length-N tuple of integers specifying the output shape.
dtype (DTypeLike) – optionally specify the dtype of the index values passed to function. Defaults to floating-point.
kwargs – additional keyword arguments are passed statically to function.
- Returns:
An array of shape shape if function returns a scalar, or in general a pytree of arrays with leading dimensions shape, as determined by the output of function.
- Return type:
See also
jax.vmap(): the core transformation that the fromfunction() API is built on.
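To make the role of jax.vmap() concrete, here is a minimal sketch of how a fromfunction()-style helper could be assembled from nested jax.vmap() calls, one per output axis. The name fromfunction_sketch is hypothetical and this is an illustration, not the library's actual source; it assumes extra keyword arguments should be bound with a closure so that vmap never tries to map over them.

```python
import jax
import jax.numpy as jnp


def fromfunction_sketch(function, shape, *, dtype=float, **kwargs):
  """Hypothetical helper: builds an array of shape `shape` via nested vmap."""
  # Bind extra keyword arguments up front so they stay static (unmapped).
  fn = lambda *indices: function(*indices, **kwargs)
  ndim = len(shape)
  # Wrap from the last axis inward; the final (outermost) vmap maps axis 0.
  for axis in reversed(range(ndim)):
    # Map only the index argument for `axis`; pass all others through unmapped.
    in_axes = tuple(0 if j == axis else None for j in range(ndim))
    fn = jax.vmap(fn, in_axes=in_axes)
  # Each output axis gets its own 1-D vector of index values.
  return fn(*(jnp.arange(n, dtype=dtype) for n in shape))


table = fromfunction_sketch(jnp.multiply, (3, 6), dtype=int)
print(table.shape)  # (3, 6)
```

Because each vmap maps exactly one index argument, the innermost call of function only ever sees scalars, which is why the function does not need to handle broadcasting itself.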
Examples
Generate a multiplication table of a given shape:
>>> import numpy as np
>>> import jax.numpy as jnp
>>> jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)
Array([[ 0,  0,  0,  0,  0,  0],
       [ 0,  1,  2,  3,  4,  5],
       [ 0,  2,  4,  6,  8, 10]], dtype=int32)
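The dtype argument sets the dtype of the index values that function receives, which here also determines the dtype of the table. As a quick check, the sketch below builds the table with both integer and default floating-point indices, and compares it against plain broadcasting of jnp.arange() vectors; the variable names are illustrative only.

```python
import jax.numpy as jnp

# Integer indices (dtype=int) give an integer-valued table.
int_table = jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)
# With the default dtype, the indices (and so the table) are floating point.
float_table = jnp.fromfunction(jnp.multiply, shape=(3, 6))

# The same multiplication table via broadcasting a column against a row.
broadcast_table = jnp.arange(3)[:, None] * jnp.arange(6)

print(jnp.issubdtype(int_table.dtype, jnp.integer))     # True
print(jnp.issubdtype(float_table.dtype, jnp.floating))  # True
print(jnp.array_equal(int_table, broadcast_table))      # True
```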
When function returns a non-scalar, the output will have leading dimensions of shape:

>>> def f(x):
...   return (x + 1) * jnp.arange(3)
>>> jnp.fromfunction(f, shape=(2,))
Array([[0., 1., 2.],
       [0., 2., 4.]], dtype=float32)
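The same rule holds in more than one dimension: the shape of each result of function is appended after shape. As a small illustrative example (the function name coords is hypothetical), a function that stacks its two scalar indices yields a grid of (row, column) coordinates with a trailing axis of length 2:

```python
import jax.numpy as jnp

def coords(i, j):
  # Each call sees two scalar indices and returns a length-2 vector.
  return jnp.stack([i, j])

grid = jnp.fromfunction(coords, shape=(2, 3), dtype=int)
print(grid.shape)  # (2, 3, 2)
print(grid[1, 2])  # [1 2]
```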
function may return multiple results, in which case each is mapped independently:

>>> def f(x, y):
...   return x + y, x * y
>>> x_plus_y, x_times_y = jnp.fromfunction(f, shape=(3, 5))
>>> print(x_plus_y)
[[0. 1. 2. 3. 4.]
 [1. 2. 3. 4. 5.]
 [2. 3. 4. 5. 6.]]
>>> print(x_times_y)
[[0. 0. 0. 0. 0.]
 [0. 1. 2. 3. 4.]
 [0. 2. 4. 6. 8.]]
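Tuples are one case of the general pytree output described under Returns. A minimal sketch with a dictionary output, where each leaf gains the leading dimensions shape (the function name and dictionary keys are arbitrary choices for illustration):

```python
import jax.numpy as jnp

def stats(x):
  # A dict pytree: one scalar leaf and one length-2 vector leaf.
  return {"square": x ** 2, "pair": jnp.stack([x, -x])}

out = jnp.fromfunction(stats, shape=(4,))
# Every leaf gains the leading dimension (4,).
print(out["square"].shape)  # (4,)
print(out["pair"].shape)    # (4, 2)
```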
The JAX implementation differs slightly from NumPy’s implementation. In numpy.fromfunction(), the function is expected to explicitly operate element-wise on the full grid of input values:

>>> def f(x, y):
...   print(f"{x.shape = }\n{y.shape = }")
...   return x + y
...
>>> np.fromfunction(f, (2, 3))
x.shape = (2, 3)
y.shape = (2, 3)
array([[0., 1., 2.],
       [1., 2., 3.]])
In jax.numpy.fromfunction(), the function is vectorized via jax.vmap(), and so is expected to operate on scalar values:

>>> jnp.fromfunction(f, (2, 3))
x.shape = ()
y.shape = ()
Array([[0., 1., 2.],
       [1., 2., 3.]], dtype=float32)
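One practical consequence of the jax.vmap()-based design is that function receives traced scalars, so Python control flow on their values (for example `if x > y:`) raises an error during tracing. A minimal sketch of the usual workaround, selecting values with jnp.where() instead (the function name lower_triangle_diff is illustrative):

```python
import jax.numpy as jnp

def lower_triangle_diff(x, y):
  # `x` and `y` are traced scalars here; branch with jnp.where rather
  # than a Python `if`, which cannot convert a tracer to a bool.
  return jnp.where(x > y, x - y, 0.0)

out = jnp.fromfunction(lower_triangle_diff, (3, 3))
print(out)
# [[0. 0. 0.]
#  [1. 0. 0.]
#  [2. 1. 0.]]
```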