jax.nn 模块

jax.nn 模块#

神经网络库的常用函数。

激活函数#

relu

线性整流单元激活函数。

relu6

线性整流单元 6 激活函数。

sigmoid(x)

Sigmoid 激活函数。

softplus(x)

Softplus 激活函数。

sparse_plus(x)

稀疏加函数。

sparse_sigmoid(x)

稀疏 sigmoid 激活函数。

soft_sign(x)

软符号激活函数。

silu(x)

SiLU(又名 swish)激活函数。

swish(x)

SiLU(又名 swish)激活函数。

log_sigmoid(x)

对数 sigmoid 激活函数。

leaky_relu(x[, negative_slope])

泄漏整流线性单元激活函数。

hard_sigmoid(x)

硬 sigmoid 激活函数。

hard_silu(x)

硬 SiLU (swish) 激活函数

hard_swish(x)

硬 SiLU (swish) 激活函数

hard_tanh(x)

\(\mathrm{tanh}\) 激活函数。

elu(x[, alpha])

指数线性单元激活函数。

celu(x[, alpha])

连续可微指数线性单元激活。

selu(x)

缩放指数线性单元激活。

gelu(x[, approximate])

高斯误差线性单元激活函数。

glu(x[, axis])

门控线性单元激活函数。

squareplus(x[, b])

Squareplus 激活函数。

mish(x)

Mish 激活函数。

其他函数#

softmax(x[, axis, where, initial])

Softmax 函数。

log_softmax(x[, axis, where, initial])

对数 Softmax 函数。

logsumexp()

对数求和指数归约。

standardize(x[, axis, mean, variance, ...])

通过减去 mean 并除以 \(\sqrt{\mathrm{variance}}\) 来标准化数组。

one_hot(x, num_classes, *[, dtype, axis])

对给定的索引进行独热编码。

dot_product_attention(query, key, value[, ...])

缩放点积注意力函数。