nip.utils.torch.flatten_batch_dims

nip.utils.torch.flatten_batch_dims(x: Tensor, num_batch_dims: int) → Tensor

Return a new view of a tensor with the batch dimensions flattened.

Parameters:
  • x (Tensor) – The input tensor, of shape (B1, B2, ..., Bn, D1, D2, ..., Dm), where n = num_batch_dims.

  • num_batch_dims (int) – The number of leading dimensions of x to treat as batch dimensions and flatten.

Returns:

x_flattened (Tensor) – The input tensor with the batch dimensions flattened. Has shape (B, D1, D2, ..., Dm), where B = B1 * B2 * ... * Bn.
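The shape contract above can be illustrated with a minimal PyTorch sketch. This is not the nip source: the name flatten_batch_dims_sketch is hypothetical, the range check on num_batch_dims is an assumption of the sketch, and it assumes the batch dimensions are the leading num_batch_dims dimensions, as the shapes above indicate. It uses Tensor.reshape, which returns a view when the memory layout allows it and a copy otherwise.

```python
import math

import torch


def flatten_batch_dims_sketch(x: torch.Tensor, num_batch_dims: int) -> torch.Tensor:
    """Hypothetical sketch: merge the first num_batch_dims dimensions of x into one.

    Not the nip implementation; it only mirrors the documented shape contract.
    """
    # Assumption of this sketch: reject out-of-range values instead of guessing.
    if not 0 <= num_batch_dims <= x.dim():
        raise ValueError(
            f"num_batch_dims must be between 0 and {x.dim()}, got {num_batch_dims}"
        )
    batch_shape = x.shape[:num_batch_dims]    # (B1, ..., Bn)
    feature_shape = x.shape[num_batch_dims:]  # (D1, ..., Dm)
    # B = B1 * B2 * ... * Bn; the empty product (n = 0) is 1.
    batch_size = math.prod(batch_shape)
    # reshape returns a view when possible, otherwise a copy.
    return x.reshape(batch_size, *feature_shape)


x = torch.randn(2, 3, 4, 5)
x_flat = flatten_batch_dims_sketch(x, num_batch_dims=2)
print(tuple(x_flat.shape))  # (6, 4, 5)
```

Swapping reshape for Tensor.view would enforce the "view" guarantee strictly: view raises an error for tensors whose layout cannot be flattened without copying, rather than silently returning a copy.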