jax.numpy.piecewise#
- jax.numpy.piecewise(x, condlist, funclist, *args, **kw)[源代码]#
在域上分段求值函数。
JAX 的
numpy.piecewise()
实现,使用jax.lax.switch()
实现。注意
与
numpy.piecewise()
不同,jax.numpy.piecewise()
要求funclist
中的函数可被 JAX 跟踪,因为它通过jax.lax.switch()
实现。- 参数:
x (ArrayLike) – 输入值的数组。
condlist (
Array
| Sequence[ArrayLike]) – 布尔数组或布尔数组序列,对应于funclist
中的函数。如果是一个数组序列,则每个数组的长度必须与x
的长度匹配。funclist (
list
[ArrayLike | Callable[..., Array]]) – 数组或函数的列表;长度必须与condlist
相同,或者长度为len(condlist) + 1
,在这种情况下,最后一个条目是当所有条件都不为 True 时应用的默认值。或者,funclist
的条目可以是数值,在这种情况下,它们表示一个常数函数。args – 传递给
funclist
中每个函数的附加参数。kwargs – 传递给
funclist
中每个函数的附加关键字参数。
- 返回值:
一个数组,它是根据指定条件在
x
上评估函数的结果。- 返回类型:
另请参阅
jax.lax.switch()
: 基于索引在 *N* 个函数之间选择。jax.lax.cond()
: 基于布尔条件在两个函数之间选择。jax.numpy.where()
: 基于布尔掩码在两个结果之间选择。jax.lax.select()
: 基于布尔掩码在两个结果之间选择。jax.lax.select_n()
: 基于布尔掩码在 *N* 个结果之间选择。
示例
以下是一个函数的示例,该函数对于负值返回零,对于正值返回线性值
>>> x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
>>> condlist = [x < 0, x >= 0] >>> funclist = [lambda x: 0 * x, lambda x: x] >>> jnp.piecewise(x, condlist, funclist) Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)
funclist
也可以包含一个简单的标量值作为常数函数>>> condlist = [x < 0, x >= 0] >>> funclist = [0, lambda x: x] >>> jnp.piecewise(x, condlist, funclist) Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)
您可以通过在
funclist
中附加一个额外的条件来指定默认值>>> condlist = [x < -1, x > 1] >>> funclist = [lambda x: 1 + x, lambda x: x - 1, 0] >>> jnp.piecewise(x, condlist, funclist) Array([-3, -2, -1, 0, 0, 0, 1, 2, 3], dtype=int32)
condlist
也可以是一个简单的标量条件数组,在这种情况下,关联的函数将应用于整个范围>>> condlist = jnp.array([False, True, False]) >>> funclist = [lambda x: x * 0, lambda x: x * 10, lambda x: x * 100] >>> jnp.piecewise(x, condlist, funclist) Array([-40, -30, -20, -10, 0, 10, 20, 30, 40], dtype=int32)