jax.numpy.frompyfunc

内容

jax.numpy.frompyfunc#

jax.numpy.frompyfunc(func, /, nin, nout, *, identity=None)[source]#

从任意兼容 JAX 的标量函数创建 JAX ufunc。

参数:
  • func (Callable[..., Any]) – 一个可调用对象,它接受 nin 个标量参数并返回 nout 个输出。

  • nin (int) – 指定标量输入数量的整数。

  • nout (int) – 指定标量输出数量的整数。

  • identity (Any | None) – (可选) 指定操作的标识的标量,如果有的话。

返回值:

func 的 jax.numpy.ufunc 包装器。

返回类型:

已包装