jax.custom_batching.sequential_vmap#

jax.custom_batching.sequential_vmap(f)[源代码]#

custom_vmap 的一个特例,它使用循环。

使用 sequential_vmap 修饰的函数在批处理时将在循环中按顺序调用。 这对于本身不支持批处理维度的函数很有用。

例如

>>> @jax.custom_batching.sequential_vmap
... def f(x):
...   jax.debug.print("{}", x)
...   return x + 1
...
>>> jax.vmap(f)(jnp.arange(3))
0
1
2
Array([1, 2, 3], dtype=int32)

其中的 print 语句演示了此 vmap() 是使用循环生成的。

有关更多详细信息,请参阅 custom_vmap 的文档。