jax.nn 模块#

用于神经网络库的常用函数。

激活函数#

relu

修正线性单元激活函数。

relu6

修正线性单元 6 激活函数。

sigmoid(x)

Sigmoid 激活函数。

softplus(x)

Softplus 激活函数。

sparse_plus(x)

稀疏加函数。

sparse_sigmoid(x)

稀疏 Sigmoid 激活函数。

soft_sign(x)

Soft-sign 激活函数。

silu(x)

SiLU (又名 swish) 激活函数。

swish(x)

SiLU (又名 swish) 激活函数。

log_sigmoid(x)

Log-sigmoid 激活函数。

leaky_relu(x[, negative_slope])

Leaky ReLU (带泄漏的线性整流单元) 激活函数。

hard_sigmoid(x)

Hard Sigmoid 激活函数。

hard_silu(x)

Hard SiLU (swish) 激活函数

hard_swish(x)

Hard SiLU (swish) 激活函数

hard_tanh(x)

Hard \(\mathrm{tanh}\) 激活函数。

elu(x[, alpha])

指数线性单元激活函数。

celu(x[, alpha])

连续可微的指数线性单元激活函数。

selu(x)

缩放的指数线性单元激活函数。

gelu(x[, approximate])

高斯误差线性单元激活函数。

glu(x[, axis])

门控线性单元激活函数。

squareplus(x[, b])

Squareplus 激活函数。

mish(x)

Mish 激活函数。

其他函数#

softmax(x[, axis, where, initial])

Softmax 函数。

log_softmax(x[, axis, where, initial])

Log-Softmax 函数。

logsumexp()

Log-sum-exp 规约。

standardize(x[, axis, mean, variance, ...])

通过减去 mean 并除以 \(\sqrt{\mathrm{variance}}\) 来标准化数组。

one_hot(x, num_classes, *[, dtype, axis])

对给定索引进行 one-hot 编码。

dot_product_attention(query, key, value[, ...])

缩放的点积注意力函数。