jax.numpy.frompyfunc#

jax.numpy.frompyfunc(func, /, nin, nout, *, identity=None)[源代码]#

从任意 JAX 兼容的标量函数创建 JAX ufunc。

参数
  • func (Callable[..., Any]) – 一个可调用对象,它接受 nin 个标量参数并返回 nout 个输出。

  • nin (int) – 指定标量输入数量的整数

  • nout (int) – 指定标量输出数量的整数

  • identity (Any | None) – (可选) 一个标量,指定操作的单位元(如果有)。

返回:

func 的 jax.numpy.ufunc 包装器。

返回类型:

wrapped

示例

这是一个创建类似于 jax.numpy.add 的 ufunc 的示例

>>> import operator
>>> add = frompyfunc(operator.add, nin=2, nout=1, identity=0)

现在所有标准的 jax.numpy.ufunc 方法都可用

>>> x = jnp.arange(4)
>>> add(x, 10)
Array([10, 11, 12, 13], dtype=int32)
>>> add.outer(x, x)
Array([[0, 1, 2, 3],
       [1, 2, 3, 4],
       [2, 3, 4, 5],
       [3, 4, 5, 6]], dtype=int32)
>>> add.reduce(x)
Array(6, dtype=int32)
>>> add.accumulate(x)
Array([0, 1, 3, 6], dtype=int32)
>>> add.at(x, 1, 10, inplace=False)
Array([ 0, 11,  2,  3], dtype=int32)