jax.nn.squareplus

内容

jax.nn.squareplus#

jax.nn.squareplus(x, b=4)[source]#

Squareplus 激活函数。

计算逐元素函数

\[\mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}\]

https://arxiv.org/abs/2112.11687 中所述。

参数:
  • x (ArrayLike) – 输入数组

  • b (ArrayLike) – 平滑度参数

返回类型:

Array