jax.numpy.frompyfunc#
- jax.numpy.frompyfunc(func, /, nin, nout, *, identity=None)[源代码]#
从任意 JAX 兼容的标量函数创建 JAX ufunc。
- 参数:
func (Callable[..., Any]) – 一个可调用对象,它接受 nin 个标量参数并返回 nout 个输出。
nin (int) – 指定标量输入数量的整数
nout (int) – 指定标量输出数量的整数
identity (Any | None) – (可选) 一个标量,指定操作的单位元(如果有)。
- 返回:
func 的 jax.numpy.ufunc 包装器。
- 返回类型:
wrapped
示例
这是一个创建类似于
jax.numpy.add
的 ufunc 的示例>>> import operator >>> add = frompyfunc(operator.add, nin=2, nout=1, identity=0)
现在所有标准的
jax.numpy.ufunc
方法都可用>>> x = jnp.arange(4) >>> add(x, 10) Array([10, 11, 12, 13], dtype=int32) >>> add.outer(x, x) Array([[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]], dtype=int32) >>> add.reduce(x) Array(6, dtype=int32) >>> add.accumulate(x) Array([0, 1, 3, 6], dtype=int32) >>> add.at(x, 1, 10, inplace=False) Array([ 0, 11, 2, 3], dtype=int32)