jax.numpy.ufunc

jax.numpy.ufunc#

class jax.numpy.ufunc(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)#

对数组进行逐元素操作的通用函数。

JAX 对 numpy.ufunc 的实现。

这是 JAX 支持的 NumPy ufunc API 实现的类。大多数用户永远不需要实例化 ufunc,而是将使用 jax.numpy 中预定义的 ufunc。

有关构建您自己的 ufunc 的方法,请参阅 jax.numpy.frompyfunc()

示例

通用函数是作用于广播数组的元素级函数,但它们也带有一些额外的属性和方法。

例如,考虑函数 jax.numpy.add。该对象充当一个函数,以元素级的方式将加法应用于广播数组。

>>> x = jnp.array([1, 2, 3, 4, 5])
>>> jnp.add(x, 1)
Array([2, 3, 4, 5, 6], dtype=int32)

每个 ufunc 对象包含一些描述其行为的属性。

>>> jnp.add.nin  # number of inputs
2
>>> jnp.add.nout  # number of outputs
1
>>> jnp.add.identity  # identity value, or None if no identity exists
0

jax.numpy.add 这样的二元 ufunc 包含一些方法,用于以不同的方式将函数应用于数组。

The outer() 方法将函数应用于输入数组值的成对外积。

>>> jnp.add.outer(x, x)
Array([[ 2,  3,  4,  5,  6],
       [ 3,  4,  5,  6,  7],
       [ 4,  5,  6,  7,  8],
       [ 5,  6,  7,  8,  9],
       [ 6,  7,  8,  9, 10]], dtype=int32)

The ufunc.reduce() 方法对数组执行归约。例如,jnp.add.reduce() 等同于 jnp.sum

>>> jnp.add.reduce(x)
Array(15, dtype=int32)

The ufunc.accumulate() 方法对数组执行累积归约。例如,jnp.add.accumulate() 等同于 jax.numpy.cumulative_sum()

>>> jnp.add.accumulate(x)
Array([ 1,  3,  6, 10, 15], dtype=int32)

The ufunc.at() 方法将函数应用于数组中的特定索引;对于 jnp.add,计算类似于 jax.lax.scatter_add()

>>> jnp.add.at(x, 0, 100, inplace=False)
Array([101,   2,   3,   4,   5], dtype=int32)

ufunc.reduceat() 方法在数组的指定索引之间执行多个 reduce 操作;对于 jnp.add,该操作类似于 jax.ops.segment_sum()

>>> jnp.add.reduceat(x, jnp.array([0, 2]))
Array([ 3, 12], dtype=int32)

在这种情况下,第一个元素是 x[0:2].sum(),第二个元素是 x[2:].sum()

参数:
  • func (Callable[..., Any])

  • nin (int)

  • nout (int)

  • name (str | None)

  • nargs (int | None)

  • identity (Any)

  • call (Callable[..., Any] | None)

  • reduce (Callable[..., Any] | None)

  • accumulate (Callable[..., Any] | None)

  • at (Callable[..., Any] | None)

  • reduceat (Callable[..., Any] | None)

__init__(func, /, nin, nout, *, name=None, nargs=None, identity=None, call=None, reduce=None, accumulate=None, at=None, reduceat=None)[source]#
参数:
  • func (Callable[..., Any])

  • nin (int)

  • nout (int)

  • name (str | None | None)

  • nargs (int | None | None)

  • identity (Any | None)

  • call (Callable[..., Any] | None | None)

  • reduce (Callable[..., Any] | None | None)

  • accumulate (Callable[..., Any] | None | None)

  • at (Callable[..., Any] | None | None)

  • reduceat (Callable[..., Any] | None | None)

方法

__init__(func, /, nin, nout, *[, name, ...])

accumulate(a[, axis, dtype, out])

从二元 ufunc 派生的累积操作。

at(a, indices[, b, inplace])

通过指定的单元或二元 ufunc 更新数组的元素。

outer(A, B, /)

将函数应用于 AB 中所有值的对。

reduce(a[, axis, dtype, out, keepdims, ...])

从二元函数派生的归约操作。

reduceat(a, indices[, axis, dtype, out])

通过二元 ufunc 归约指定索引之间的数组。

属性

identity

nargs

nin

nout