jax.custom_batching.sequential_vmap#
- jax.custom_batching.sequential_vmap(f)[源代码]#
使用循环的
custom_vmap
的一个特殊情况。用
sequential_vmap
修饰的函数在批量处理时将在循环中按顺序调用。这对于本身不支持批量维度的函数很有用。例如
>>> @jax.custom_batching.sequential_vmap ... def f(x): ... jax.debug.print("{}", x) ... return x + 1 ... >>> jax.vmap(f)(jnp.arange(3)) 0 1 2 Array([1, 2, 3], dtype=int32)
其中 print 语句演示了此
vmap()
是使用循环生成的。有关更多详细信息,请参阅
custom_vmap
的文档。