jax.experimental.sparse.bcoo_reshape

jax.experimental.sparse.bcoo_reshape(mat, *, new_sizes, dimensions=None)

Sparse implementation of jax.lax.reshape: reshapes a BCOO array by rewriting its data and indices, without converting it to a dense array.
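
A minimal sketch, assuming a recent JAX installation: it reshapes a small unbatched BCOO matrix and checks that densifying the result gives the same values as applying jax.lax.reshape to the dense input. BCOO.fromdense and todense are used here only to build and inspect the example.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

# A small 2x3 matrix stored as BCOO (n_batch=0, n_sparse=2, n_dense=0).
dense = jnp.array([[0., 1., 0.],
                   [2., 0., 3.]])
mat = sparse.BCOO.fromdense(dense)

# Reshape the sparse array; new_sizes is a keyword-only argument.
out = sparse.bcoo_reshape(mat, new_sizes=(3, 2))

# Densifying the result should match jax.lax.reshape on the dense input.
expected = jax.lax.reshape(dense, (3, 2))
print(out.shape)                                 # (3, 2)
print(jnp.array_equal(out.todense(), expected))  # True
```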

Parameters:
  • mat (BCOO) – BCOO array to be reshaped.

  • new_sizes (Sequence[int]) – sequence of integers specifying the resulting shape. The total size of the output must match the size of the input. The new shape must be chosen so that batch, sparse, and dense dimensions do not mix: a leading group of output dimensions must multiply to the input's total batch size, the next group to its total sparse size, and the remaining dimensions form the dense part.

  • dimensions (Sequence[int] | None) – optional sequence of integers specifying the permutation order of the input shape, applied before reshaping. If specified, its length must equal len(mat.shape). The permutation may only reorder dimensions of the same kind: batch, sparse, and dense dimensions cannot be moved across one another.
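
The constraints on new_sizes and dimensions can be illustrated with a sketch using a (2, 3, 4) array stored with one batch dimension (n_batch=1 passed to BCOO.fromdense is a choice made for this example). Merging the two sparse dimensions is accepted; a shape such as (3, 8), in which no leading group of output dimensions has size 2, would mix batch and sparse dimensions and is expected to be rejected (the exception type, ValueError, is assumed from current JAX behavior); and a dimensions argument that swaps only the sparse dimensions is allowed.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# A (2, 3, 4) array stored with one batch dimension: n_batch=1, n_sparse=2.
x = jnp.arange(24.0).reshape(2, 3, 4)
mat = sparse.BCOO.fromdense(x, n_batch=1)

# Allowed: the batch dimension (size 2) is kept and the two sparse
# dimensions (3, 4) are merged into one sparse dimension of size 12.
merged = sparse.bcoo_reshape(mat, new_sizes=(2, 12))
print(merged.n_batch, merged.n_sparse)  # 1 1

# Rejected: no leading group of dimensions in (3, 8) has size 2, so the
# batch dimension would have to mix with the sparse ones.
try:
    sparse.bcoo_reshape(mat, new_sizes=(3, 8))
    rejected = False
except ValueError as err:  # exception type assumed from current JAX
    rejected = True
    print("rejected:", err)

# dimensions may swap the two sparse dimensions (1 and 2) as long as the
# batch dimension (0) stays in place; the permutation is applied first.
permuted = sparse.bcoo_reshape(mat, new_sizes=(2, 12), dimensions=(0, 2, 1))
expected = jnp.transpose(x, (0, 2, 1)).reshape(2, 12)
print(jnp.array_equal(permuted.todense(), expected))  # True
```

Comparing against jnp.transpose followed by a reshape mirrors the meaning of dimensions in jax.lax.reshape: the input is permuted first, then reshaped.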

Returns:

reshaped array.

Return type:

BCOO
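
As a sketch of what the returned BCOO looks like, assuming an input whose last dimension is stored densely (n_dense=1 passed to BCOO.fromdense is a choice made for this example): the split of the output into batch, sparse, and dense dimensions is inferred from new_sizes, so merging the two sparse dimensions leaves the dense dimension of size 4 intact.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# A (2, 3, 4) array whose last dimension is stored densely:
# n_batch=0, n_sparse=2, n_dense=1.
x = jnp.arange(24.0).reshape(2, 3, 4)
mat = sparse.BCOO.fromdense(x, n_dense=1)

# Merge the two sparse dimensions (2, 3) into one of size 6;
# the dense dimension of size 4 is carried through unchanged.
out = sparse.bcoo_reshape(mat, new_sizes=(6, 4))

print(type(out).__name__)                               # BCOO
print(out.n_batch, out.n_sparse, out.n_dense)           # 0 1 1
print(jnp.array_equal(out.todense(), x.reshape(6, 4)))  # True
```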