jax.numpy.ufunc#
- class jax.numpy.ufunc(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)#
对数组进行逐元素操作的通用函数。
JAX 对
numpy.ufunc
的实现。这是 JAX 支持的 NumPy ufunc API 实现的类。大多数用户永远不需要实例化
ufunc
,而是将使用jax.numpy
中预定义的 ufunc。有关构建您自己的 ufunc 的方法,请参阅
jax.numpy.frompyfunc()
。示例
通用函数是作用于广播数组的元素级函数,但它们也带有一些额外的属性和方法。
例如,考虑函数
jax.numpy.add
。该对象充当一个函数,以元素级的方式将加法应用于广播数组。>>> x = jnp.array([1, 2, 3, 4, 5]) >>> jnp.add(x, 1) Array([2, 3, 4, 5, 6], dtype=int32)
每个
ufunc
对象包含一些描述其行为的属性。>>> jnp.add.nin # number of inputs 2 >>> jnp.add.nout # number of outputs 1 >>> jnp.add.identity # identity value, or None if no identity exists 0
像
jax.numpy.add
这样的二元 ufunc 包含一些方法,用于以不同的方式将函数应用于数组。The
outer()
方法将函数应用于输入数组值的成对外积。>>> jnp.add.outer(x, x) Array([[ 2, 3, 4, 5, 6], [ 3, 4, 5, 6, 7], [ 4, 5, 6, 7, 8], [ 5, 6, 7, 8, 9], [ 6, 7, 8, 9, 10]], dtype=int32)
The
ufunc.reduce()
方法对数组执行归约。例如,jnp.add.reduce()
等同于jnp.sum
>>> jnp.add.reduce(x) Array(15, dtype=int32)
The
ufunc.accumulate()
方法对数组执行累积归约。例如,jnp.add.accumulate()
等同于jax.numpy.cumulative_sum()
>>> jnp.add.accumulate(x) Array([ 1, 3, 6, 10, 15], dtype=int32)
The
ufunc.at()
方法将函数应用于数组中的特定索引;对于jnp.add
,计算类似于jax.lax.scatter_add()
>>> jnp.add.at(x, 0, 100, inplace=False) Array([101, 2, 3, 4, 5], dtype=int32)
而
ufunc.reduceat()
方法在数组的指定索引之间执行多个reduce
操作;对于jnp.add
,该操作类似于jax.ops.segment_sum()
>>> jnp.add.reduceat(x, jnp.array([0, 2])) Array([ 3, 12], dtype=int32)
在这种情况下,第一个元素是
x[0:2].sum()
,第二个元素是x[2:].sum()
。- 参数:
- __init__(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)[source]#
- 参数:
func (Callable[..., Any])
nin (int)
nout (int)
name (str | None | None)
nargs (int | None | None)
identity (Any | None)
call (Callable[..., Any] | None | None)
reduce (Callable[..., Any] | None | None)
accumulate (Callable[..., Any] | None | None)
at (Callable[..., Any] | None | None)
reduceat (Callable[..., Any] | None | None)
方法
__init__
(func, /, nin, nout, *[, name, ...])accumulate
(a[, axis, dtype, out])从二元 ufunc 派生的累积操作。
at
(a, indices[, b, inplace])通过指定的单元或二元 ufunc 更新数组的元素。
outer
(A, B, /)将函数应用于
A
和B
中所有值的对。reduce
(a[, axis, dtype, out, keepdims, ...])从二元函数派生的归约操作。
reduceat
(a, indices[, axis, dtype, out])通过二元 ufunc 归约指定索引之间的数组。
属性
identity
nargs
nin
nout