jax.numpy.fromfunction#
- jax.numpy.fromfunction(function, shape, *, dtype=<class 'float'>, **kwargs)[源代码]#
从应用于索引的函数创建数组。
numpy.fromfunction()
的 JAX 实现。JAX 实现的不同之处在于它通过jax.vmap()
进行分派,因此与 NumPy 不同,该函数在逻辑上对标量输入进行操作,并且不需要显式处理广播输入(请参阅下面的示例)。- 参数:
function (Callable[..., Array]) – 一个接受 N 个动态标量并输出标量的函数。
shape (Any) – 一个长度为 N 的整数元组,指定输出形状。
dtype (DTypeLike) – 可选地指定输入的 dtype。默认为浮点数。
kwargs – 额外的关键字参数静态传递给
function
。
- 返回:
如果
function
返回标量,则返回形状为shape
的数组,或者通常是由function
的输出决定的、前导维度为shape
的数组的 pytree。- 返回类型:
另请参阅
jax.vmap()
:fromfunction()
API 构建于其上的核心转换。
示例
生成给定形状的乘法表
>>> jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int) Array([[ 0, 0, 0, 0, 0, 0], [ 0, 1, 2, 3, 4, 5], [ 0, 2, 4, 6, 8, 10]], dtype=int32)
当
function
返回非标量时,输出将具有shape
的前导维度>>> def f(x): ... return (x + 1) * jnp.arange(3) >>> jnp.fromfunction(f, shape=(2,)) Array([[0., 1., 2.], [0., 2., 4.]], dtype=float32)
function
可能会返回多个结果,在这种情况下,每个结果都会被独立映射>>> def f(x, y): ... return x + y, x * y >>> x_plus_y, x_times_y = jnp.fromfunction(f, shape=(3, 5)) >>> print(x_plus_y) [[0. 1. 2. 3. 4.] [1. 2. 3. 4. 5.] [2. 3. 4. 5. 6.]] >>> print(x_times_y) [[0. 0. 0. 0. 0.] [0. 1. 2. 3. 4.] [0. 2. 4. 6. 8.]]
JAX 实现与 NumPy 的实现略有不同。在
numpy.fromfunction()
中,该函数应显式地对输入值的完整网格逐元素进行操作>>> def f(x, y): ... print(f"{x.shape = }\n{y.shape = }") ... return x + y ... >>> np.fromfunction(f, (2, 3)) x.shape = (2, 3) y.shape = (2, 3) array([[0., 1., 2.], [1., 2., 3.]])
在
jax.numpy.fromfunction()
中,该函数通过jax.vmap()
进行向量化,因此应在标量值上进行操作>>> jnp.fromfunction(f, (2, 3)) x.shape = () y.shape = () Array([[0., 1., 2.], [1., 2., 3.]], dtype=float32)