jax.numpy.ufunc#
- class jax.numpy.ufunc(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)#
对数组执行逐元素操作的通用函数。
numpy.ufunc
的 JAX 实现。这是一个用于 JAX 支持的 NumPy ufunc API 实现的类。大多数用户永远不需要实例化
ufunc
,而是使用jax.numpy
中预定义的 ufunc。要构建您自己的 ufunc,请参阅
jax.numpy.frompyfunc()
。示例
通用函数是逐元素应用于广播数组的函数,但它们还具有许多额外的属性和方法。
例如,考虑函数
jax.numpy.add
。该对象充当一个函数,以逐元素的方式将加法应用于广播数组>>> x = jnp.array([1, 2, 3, 4, 5]) >>> jnp.add(x, 1) Array([2, 3, 4, 5, 6], dtype=int32)
每个
ufunc
对象都包含许多描述其行为的属性>>> jnp.add.nin # number of inputs 2 >>> jnp.add.nout # number of outputs 1 >>> jnp.add.identity # identity value, or None if no identity exists 0
像
jax.numpy.add
这样的二元 ufunc 包含许多以不同方式将函数应用于数组的方法。outer()
方法将函数应用于输入数组值的成对外积>>> jnp.add.outer(x, x) Array([[ 2, 3, 4, 5, 6], [ 3, 4, 5, 6, 7], [ 4, 5, 6, 7, 8], [ 5, 6, 7, 8, 9], [ 6, 7, 8, 9, 10]], dtype=int32)
ufunc.reduce()
方法对数组执行归约。例如,jnp.add.reduce()
等同于jnp.sum
>>> jnp.add.reduce(x) Array(15, dtype=int32)
ufunc.accumulate()
方法对数组执行累积归约。例如,jnp.add.accumulate()
等同于jax.numpy.cumulative_sum()
>>> jnp.add.accumulate(x) Array([ 1, 3, 6, 10, 15], dtype=int32)
ufunc.at()
方法在数组的特定索引处应用该函数;对于jnp.add
,计算类似于jax.lax.scatter_add()
>>> jnp.add.at(x, 0, 100, inplace=False) Array([101, 2, 3, 4, 5], dtype=int32)
而
ufunc.reduceat()
方法在数组的指定索引之间执行多个reduce
操作;对于jnp.add
,该操作类似于jax.ops.segment_sum()
>>> jnp.add.reduceat(x, jnp.array([0, 2])) Array([ 3, 12], dtype=int32)
在这种情况下,第一个元素是
x[0:2].sum()
,第二个元素是x[2:].sum()
。- 参数:
- __init__(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)[source]#
- 参数:
func (Callable[..., Any])
nin (int)
nout (int)
name (str | None | None)
nargs (int | None | None)
identity (Any | None)
call (Callable[..., Any] | None | None)
reduce (Callable[..., Any] | None | None)
accumulate (Callable[..., Any] | None | None)
at (Callable[..., Any] | None | None)
reduceat (Callable[..., Any] | None | None)
方法
__init__
(func, /, nin, nout, *[, name, ...])accumulate
(a[, axis, dtype, out])从二元 ufunc 派生的累积运算。
at
(a, indices[, b, inplace])通过指定的单元或二元 ufunc 更新数组的元素。
outer
(A, B, /)将函数应用于
A
和B
中的所有值对。reduce
(a[, axis, dtype, out, keepdims, ...])从二元函数派生的归约运算。
reduceat
(a, indices[, axis, dtype, out])通过二元 ufunc 归约指定索引之间的数组。
属性
identity
nargs
nin
nout