jax.numpy.frompyfunc#
- jax.numpy.frompyfunc(func, /, nin, nout, *, identity=None)[源代码]#
从任意 JAX 兼容的标量函数创建一个 JAX ufunc。
- 参数:
- 返回:
func 的 jax.numpy.ufunc 包装器。
- 返回类型:
wrapped
示例
这里有一个创建类似于
jax.numpy.add
的 ufunc 的例子。>>> import operator >>> add = frompyfunc(operator.add, nin=2, nout=1, identity=0)
现在所有标准的
jax.numpy.ufunc
方法都是可用的。>>> x = jnp.arange(4) >>> add(x, 10) Array([10, 11, 12, 13], dtype=int32) >>> add.outer(x, x) Array([[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6]], dtype=int32) >>> add.reduce(x) Array(6, dtype=int32) >>> add.accumulate(x) Array([0, 1, 3, 6], dtype=int32) >>> add.at(x, 1, 10, inplace=False) Array([ 0, 11, 2, 3], dtype=int32)