jax.flatten_util 模块

内容

jax.flatten_util 模块#

函数列表#

ravel_pytree(pytree)

将 pytree 中的数组展平(扁平化)为一个一维数组。