jax.numpy.fromfunction

jax.numpy.fromfunction#

jax.numpy.fromfunction(function, shape, *, dtype=<class 'float'>, **kwargs)[source]#

通过对每个坐标执行函数来构建数组。

LAX 后端实现的 numpy.fromfunction().

原始文档字符串如下。

因此,结果数组在坐标 (x, y, z) 处具有值 fn(x, y, z)

参数:
  • function (callable) – 该函数使用 N 个参数调用,其中 N 是 shape 的秩。每个参数表示沿特定轴变化的数组的坐标。例如,如果 shape(2, 2),则参数将为 array([[0, 0], [1, 1]])array([[0, 1], [0, 1]])

  • shape ((N,) tuple of ints) – 输出数组的形状,这也决定了传递给 function 的坐标数组的形状。

  • dtype (data-type, optional) – 传递给 function 的坐标数组的数据类型。默认情况下,dtype 为浮点数。

返回值:

fromfunction – 对 function 的调用结果直接传递回。因此,fromfunction 的形状完全由 function 决定。如果 function 返回一个标量值,则 fromfunction 的形状将不匹配 shape 参数。

返回类型:

任意