jax.custom_batching.custom_vmap.def_vmap#

custom_vmap.def_vmap(vmap_rule)[源代码]#

为此 custom_vmap 函数定义 vmap 规则。

参数:

vmap_rule (Callable[..., tuple[Any, Any]]) – 一个实现 vmap 规则的函数。此函数应接受以下参数:(1) 一个整数 axis_size 作为其第一个参数,(2) 一个与函数输入具有相同结构的布尔值 pytree,指定每个参数是否被批处理,以及 (3) 批处理的参数。它应该返回一个批处理输出的元组和一个与输出具有相同结构的布尔值 pytree,指定每个输出元素是否被批处理。有关一些示例,请参阅 jax.custom_batching.custom_vmap() 的文档。

返回:

此方法传递规则,返回未更改的 vmap_rule

返回类型:

Callable[…, tuple[Any, Any]]