nip.utils.torch.flatten_batch_dims
- nip.utils.torch.flatten_batch_dims(x: Tensor, num_batch_dims: int) → Tensor
Return a new view of a tensor with the batch dimensions flattened.
- Parameters:
x (Tensor) – The input tensor. Has shape (B1, B2, ..., Bn, D1, D2, ..., Dm), where n = num_batch_dims is the number of batch dimensions.
num_batch_dims (int) – The number of batch dimensions to flatten.
- Returns:
x_flattened (Tensor) – The input tensor with the batch dimensions flattened. Has shape (B, D1, D2, ..., Dm), where B = B1 * B2 * ... * Bn.
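A minimal sketch of the documented behavior, assuming the batch dimensions are the leading num_batch_dims dimensions and that the result must be a true view. The name flatten_batch_dims_sketch is hypothetical and this is not the library's own implementation. It uses Tensor.view, which shares storage with the input and raises a RuntimeError when the batch dimensions cannot be merged without copying (for example, after a transpose).

```python
import math

import torch
from torch import Tensor


def flatten_batch_dims_sketch(x: Tensor, num_batch_dims: int) -> Tensor:
    """Hypothetical sketch: merge the leading `num_batch_dims` dims of `x` into one."""
    if not 1 <= num_batch_dims <= x.dim():
        raise ValueError(
            f"num_batch_dims must be in [1, {x.dim()}], got {num_batch_dims}"
        )
    # B = B1 * B2 * ... * Bn, computed explicitly so empty batches also work
    batch_size = math.prod(x.shape[:num_batch_dims])
    # Trailing (non-batch) dimensions D1, ..., Dm are kept unchanged
    feature_shape = x.shape[num_batch_dims:]
    # `view` never copies: it returns a tensor sharing x's storage, or raises
    return x.view(batch_size, *feature_shape)


x = torch.randn(2, 3, 4, 5)
x_flattened = flatten_batch_dims_sketch(x, num_batch_dims=2)
print(x_flattened.shape)  # torch.Size([6, 4, 5])
```

With a contiguous input, the batch dimensions are merged in row-major order, so x[i, j] becomes x_flattened[i * 3 + j] in the example above. Swapping view for reshape would accept non-contiguous inputs, but it may silently copy them, so the result would no longer be guaranteed to be a view.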