nip.utils.torch.SimulateBatchDimsMixin

class nip.utils.torch.SimulateBatchDimsMixin

A mixin for simulating multiple batch dimensions.

Use this mixin with modules that do not themselves support multiple batch dimensions. Support is simulated by flattening the batch dimensions into one, applying the module, and then unflattening the batch dimensions of the result.

Classes that use this mixin should implement the feature_dims property.
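The flatten, apply, unflatten pattern described above can be sketched as follows. This is a minimal illustration, not the library's implementation: the class names `FlattenBatchDimsSketch` and `MultiBatchConv2d` are hypothetical, and the sketch assumes that the feature dimensions are the trailing dimensions of the input and that the mixin is listed before the wrapped module class in the bases.

```python
import torch
from torch import Tensor, nn


class FlattenBatchDimsSketch:
    """Hypothetical stand-in for the mixin (not nip's actual code)."""

    # Number of non-batch dimensions; assumed here to be the trailing ones.
    feature_dims: int

    def forward(self, x: Tensor) -> Tensor:
        # Split the shape into leading batch dims and trailing feature dims.
        n_batch = x.dim() - self.feature_dims
        batch_shape = x.shape[:n_batch]
        # Collapse all batch dims into a single one.
        flat = x.reshape(-1, *x.shape[n_batch:])
        # Delegate to the next forward in the MRO (the wrapped module).
        out = super().forward(flat)
        # Restore the original batch dims in front of the output features.
        return out.reshape(*batch_shape, *out.shape[1:])


class MultiBatchConv2d(FlattenBatchDimsSketch, nn.Conv2d):
    # A 2D convolution has three feature dims: (channels, height, width).
    feature_dims = 3


conv = MultiBatchConv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(2, 5, 3, 16, 16)  # two batch dims: (2, 5)
out = conv(x)
out_shape = tuple(out.shape)
```

In this sketch the two leading batch dimensions are merged into one of size 10 before the convolution runs and are split back into (2, 5) afterwards, so the output keeps the input's batch layout.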

Methods Summary

forward(x)

Apply the module to the input tensor, simulating multiple batch dimensions.

Attributes

feature_dims

The number of non-batch dimensions.

Methods

forward(x: Tensor) → Tensor

Apply the module to the input tensor, simulating multiple batch dimensions.

Parameters:

x (Tensor) – The input tensor.

Returns:

out (Tensor) – The output tensor after applying the module.