jax.numpy.fromfunction#

jax.numpy.fromfunction(function, shape, *, dtype=<class 'float'>, **kwargs)[源代码]#

从应用于索引的函数创建数组。

numpy.fromfunction() 的 JAX 实现。JAX 实现的不同之处在于它通过 jax.vmap() 进行分派,因此与 NumPy 不同,该函数在逻辑上对标量输入进行操作,并且不需要显式处理广播输入(请参阅下面的示例)。

参数:
  • function (Callable[..., Array]) – 一个接受 N 个动态标量并输出标量的函数。

  • shape (Any) – 一个长度为 N 的整数元组,指定输出形状。

  • dtype (DTypeLike) – 可选地指定输入的 dtype。默认为浮点数。

  • kwargs – 额外的关键字参数静态传递给 function

返回:

如果 function 返回标量,则返回形状为 shape 的数组,或者通常是由 function 的输出决定的、前导维度为 shape 的数组的 pytree。

返回类型:

数组

另请参阅

示例

生成给定形状的乘法表

>>> jnp.fromfunction(jnp.multiply, shape=(3, 6), dtype=int)
Array([[ 0,  0,  0,  0,  0,  0],
       [ 0,  1,  2,  3,  4,  5],
       [ 0,  2,  4,  6,  8, 10]], dtype=int32)

function 返回非标量时,输出将具有 shape 的前导维度

>>> def f(x):
...   return (x + 1) * jnp.arange(3)
>>> jnp.fromfunction(f, shape=(2,))
Array([[0., 1., 2.],
       [0., 2., 4.]], dtype=float32)

function 可能会返回多个结果,在这种情况下,每个结果都会被独立映射

>>> def f(x, y):
...   return x + y, x * y
>>> x_plus_y, x_times_y = jnp.fromfunction(f, shape=(3, 5))
>>> print(x_plus_y)
[[0. 1. 2. 3. 4.]
 [1. 2. 3. 4. 5.]
 [2. 3. 4. 5. 6.]]
>>> print(x_times_y)
[[0. 0. 0. 0. 0.]
 [0. 1. 2. 3. 4.]
 [0. 2. 4. 6. 8.]]

JAX 实现与 NumPy 的实现略有不同。在 numpy.fromfunction() 中,该函数应显式地对输入值的完整网格逐元素进行操作

>>> def f(x, y):
...   print(f"{x.shape = }\n{y.shape = }")
...   return x + y
...
>>> np.fromfunction(f, (2, 3))
x.shape = (2, 3)
y.shape = (2, 3)
array([[0., 1., 2.],
       [1., 2., 3.]])

jax.numpy.fromfunction() 中,该函数通过 jax.vmap() 进行向量化,因此应在标量值上进行操作

>>> jnp.fromfunction(f, (2, 3))
x.shape = ()
y.shape = ()
Array([[0., 1., 2.],
       [1., 2., 3.]], dtype=float32)