nip.utils.torch.NormalizeOneHotMessageHistory

class nip.utils.torch.NormalizeOneHotMessageHistory(*args, **kwargs)

Normalize the history of one-hot message exchanges.

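Normalizes each component of the message history to have zero mean and unit variance. The statistics are computed as if every possible message-history length were equally likely, so each length receives the same weight.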

The input is assumed to have some number of batch dimensions, followed by some number of structure dimensions, followed by the round dimension (when round_dim_last is False, the round dimension instead comes just before the structure dimensions). The ‘structure’ dimensions are those that specify the structure of a data point, e.g. the height and width of an image. For each round in which a message has been exchanged, the input is assumed to be one-hot encoded jointly across all the structure dimensions.
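To make the assumed layout concrete, the sketch below builds a one-hot message history with one batch dimension, one structure dimension of size n and m rounds (the round_dim_last=True layout). It assumes, for illustration only, that a history with L messages fills rounds 0 to L - 1 and leaves later rounds all-zero; the helper one_hot_history is hypothetical and not part of the library.

```python
import torch


def one_hot_history(positions: list[list[int]], n: int, m: int) -> torch.Tensor:
    """Hypothetical helper: build a (batch, n, m) one-hot message history.

    positions[b] lists the structure index of the message sent in each
    exchanged round of batch item b; later rounds stay all-zero.
    """
    x = torch.zeros(len(positions), n, m)
    for b, rounds in enumerate(positions):
        for r, s in enumerate(rounds):
            x[b, s, r] = 1.0  # one-hot over the structure dim for round r
    return x


# Two histories over n=4 structure positions and m=3 rounds:
# the first exchanged two messages, the second none.
x = one_hot_history([[2, 0], []], n=4, m=3)
print(x.shape)               # torch.Size([2, 4, 3])
print(x.sum(dim=1).tolist())  # [[1.0, 1.0, 0.0], [0.0, 0.0, 0.0]]
```

Summing over the structure dimension gives at most one message per round, which is what "one-hot across the structure dimensions" means here.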

Shapes

Takes as input a TensorDict with key:

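  • “x” (the message_in_key) with shape one of:
      - Float["... structure_dim_1 ... structure_dim_k round"] when round_dim_last is True
      - Float["... round structure_dim_1 ... structure_dim_k"] when round_dim_last is False
    where k is num_structure_dims.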
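The two accepted layouts differ only in where the round dimension sits. As a sketch (not taken from the library), torch.Tensor.movedim converts between them given k = num_structure_dims; the sizes below are arbitrary.

```python
import torch

k = 2  # num_structure_dims, e.g. the height and width of an image
batch, height, width, rounds = 5, 3, 4, 6

# round_dim_last=True layout: (... structure_dim_1 ... structure_dim_k round)
x_last = torch.arange(batch * height * width * rounds, dtype=torch.float32)
x_last = x_last.reshape(batch, height, width, rounds)

# round_dim_last=False layout: (... round structure_dim_1 ... structure_dim_k),
# obtained by moving the last dim to sit just before the k structure dims.
x_first = x_last.movedim(-1, -1 - k)
print(x_first.shape)  # torch.Size([5, 6, 3, 4])

# Moving it back restores the original layout exactly.
assert torch.equal(x_first.movedim(-1 - k, -1), x_last)
```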

Parameters:
  • max_message_rounds (int) – The maximum length of the message history.

  • message_in_key (NestedKey, default="x") – The key containing the message history.

  • message_out_key (NestedKey, default="x_normalized") – The key to store the normalized message history.

  • num_structure_dims (int, default=1) – The number of structure dimensions to normalize over (see above).

  • round_dim_last (bool, default=True) – Whether the round dimension is the last dimension (True) or is located just before the structure dimensions (False).
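How num_structure_dims and round_dim_last pick out the dimensions can be expressed with negative indices, which keeps the result independent of how many batch dimensions precede them. This is a hypothetical helper written for illustration, not part of the library.

```python
def layout_dims(
    num_structure_dims: int = 1, round_dim_last: bool = True
) -> tuple[tuple[int, ...], int]:
    """Return (structure dim indices, round dim index) as negative indices."""
    k = num_structure_dims
    if round_dim_last:
        # (... s_1 ... s_k round): structure dims are the k dims before the last
        return tuple(range(-k - 1, -1)), -1
    # (... round s_1 ... s_k): structure dims are the last k dims
    return tuple(range(-k, 0)), -k - 1


print(layout_dims(1, True))   # ((-2,), -1)
print(layout_dims(2, True))   # ((-3, -2), -1)
print(layout_dims(2, False))  # ((-2, -1), -3)
```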

Methods Summary

  • __init__(max_message_rounds[, ...]) – Initialize internal Module state, shared by both nn.Module and ScriptModule.

  • _get_mean_and_std(x) – Get the mean and standard deviation for the structure shape of x.

  • forward(tensordict) – Normalize the message history.

  • to(*args, **kwargs) – Move the module to a new device or dtype.

Attributes

  • T_destination

  • call_super_init

  • dump_patches

  • in_keys – The keys of the input TensorDict.

  • out_keys – The keys of the output TensorDict.

  • out_keys_source

  • training

Methods

__init__(max_message_rounds: int, message_in_key: str | Tuple[str, ...] = 'x', message_out_key: str | Tuple[str, ...] = 'x_normalized', num_structure_dims: int = 1, round_dim_last: bool = True)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_get_mean_and_std(x: Tensor) → tuple[Tensor, Tensor]

Get the mean and standard deviation for the structure shape of x.

These depend only on the shape of the structure dimensions (and on max_message_rounds, which is fixed for a given module), so they can be cached and reused for tensors with the same structure shape.

Let n be the total size of the structure dimensions (the product of their sizes) and m be the maximum number of message rounds. Then the mean and standard deviation are computed as follows:

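\[
\begin{aligned}
\text{mean} &= \frac{1}{n m} \bigl(m - 1,\; m - 2,\; \ldots,\; 0\bigr), \\
\text{std} &= \frac{1}{n m} \sqrt{\bigl((m - 1)(n m - m + 1),\; (m - 2)(n m - m + 2),\; \ldots,\; 0\bigr)},
\end{aligned}
\]
where the square root is taken element-wise and each vector has one component per message round.
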
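One reading that reproduces every component of the formula (inferred from the formula itself, not stated by the library): the history length L is uniform on {0, 1, ..., m - 1}, rounds 0, ..., L - 1 each carry one message, and each message lands on one of the n structure positions uniformly at random. The entry x_{s,r} for a fixed structure position s and round r is then a Bernoulli variable with success probability p_r:

\[
\begin{aligned}
p_r &= \Pr[L > r] \cdot \frac{1}{n} = \frac{m - 1 - r}{m} \cdot \frac{1}{n} = \frac{m - 1 - r}{n m}, \\
\text{mean}_r &= p_r = \frac{m - 1 - r}{n m}, \\
\text{std}_r &= \sqrt{p_r (1 - p_r)} = \sqrt{\frac{m - 1 - r}{n m} \cdot \frac{n m - m + 1 + r}{n m}} = \frac{1}{n m} \sqrt{(m - 1 - r)(n m - m + 1 + r)}.
\end{aligned}
\]

Taking r = 0, 1, ..., m - 1 in turn gives the two vectors above. In particular the last round has mean 0 and standard deviation 0, so dividing by the standard deviation needs a guard for that component.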
Parameters:

x (Tensor) – The input tensor.

Returns:

  • mean (Tensor) – The mean for message histories with the structure shape of x.

  • std (Tensor) – The standard deviation for message histories with the structure shape of x.
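The closed form can be checked numerically. The sketch below uses hypothetical helper names: it computes the statistics from the formula, caches them per structure shape with functools.lru_cache (mirroring the caching described above, though the library's own cache may work differently), and compares them against a brute-force enumeration that assumes every history length 0, ..., m - 1 and every structure position is equally likely.

```python
import math
from functools import lru_cache


@lru_cache(maxsize=None)
def mean_and_std(
    structure_shape: tuple[int, ...], max_message_rounds: int
) -> tuple[tuple[float, ...], tuple[float, ...]]:
    """Closed-form per-round statistics (hypothetical helper, not the library's).

    The result depends only on the structure shape and m, so it is cached per
    (structure_shape, max_message_rounds) pair.
    """
    n = math.prod(structure_shape)  # total size of the structure dimensions
    m = max_message_rounds
    mean = tuple((m - 1 - r) / (n * m) for r in range(m))
    std = tuple(math.sqrt((m - 1 - r) * (n * m - m + 1 + r)) / (n * m) for r in range(m))
    return mean, std


def brute_force(n: int, m: int) -> tuple[list[float], list[float]]:
    """Population mean/std of one entry under the equal-weight assumption.

    Every (history length L, structure position s of the round-r message) pair
    gets equal weight; the entry for structure position 0 in round r is 1 exactly
    when round r was exchanged (r < L) and its message landed on position 0.
    """
    means, stds = [], []
    for r in range(m):
        values = [float(r < L and s == 0) for L in range(m) for s in range(n)]
        mu = sum(values) / len(values)
        var = sum((v - mu) ** 2 for v in values) / len(values)
        means.append(mu)
        stds.append(math.sqrt(var))
    return means, stds


mean, std = mean_and_std((2, 2), 3)  # n = 4 structure positions, m = 3 rounds
bf_mean, bf_std = brute_force(4, 3)
assert all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(mean, bf_mean))
assert all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(std, bf_std))
```

Returning tuples rather than lists keeps the cached values immutable, so a caller cannot accidentally corrupt the cache.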

forward(tensordict: TensorDictBase) → TensorDictBase

Normalize the message history.

Parameters:

tensordict (TensorDictBase) – The input tensordict.

Returns:

normalized_tensordict (TensorDictBase) – The input tensordict, with the normalized message history stored under message_out_key.
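A minimal sketch of what the forward pass amounts to, using the closed-form statistics given for _get_mean_and_std. This is not the library's implementation: the class name is hypothetical, a plain dict stands in for the TensorDict, the statistics are recomputed on every call instead of cached, and components whose standard deviation is 0 are only centred (an assumption, since how the library guards that case is not stated here).

```python
import math

import torch
from torch import nn


class OneHotHistoryNormalizerSketch(nn.Module):
    """Hypothetical stand-in for NormalizeOneHotMessageHistory (plain dict I/O)."""

    def __init__(
        self,
        max_message_rounds: int,
        message_in_key: str = "x",
        message_out_key: str = "x_normalized",
        num_structure_dims: int = 1,
        round_dim_last: bool = True,
    ):
        super().__init__()
        self.max_message_rounds = max_message_rounds
        self.message_in_key = message_in_key
        self.message_out_key = message_out_key
        self.num_structure_dims = num_structure_dims
        self.round_dim_last = round_dim_last

    def _mean_and_std(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        k, m = self.num_structure_dims, self.max_message_rounds
        # Structure dims sit just before the round dim, or at the very end.
        structure_shape = x.shape[-k - 1 : -1] if self.round_dim_last else x.shape[-k:]
        n = math.prod(structure_shape)  # total size of the structure dimensions
        r = torch.arange(m, dtype=x.dtype, device=x.device)  # round index 0..m-1
        mean = (m - 1 - r) / (n * m)
        std = torch.sqrt((m - 1 - r) * (n * m - m + 1 + r)) / (n * m)
        return mean, std

    def forward(self, data: dict) -> dict:
        x = data[self.message_in_key]
        mean, std = self._mean_and_std(x)
        # Assumption: components with zero std (the last round) are only centred.
        std = torch.where(std > 0, std, torch.ones_like(std))
        if not self.round_dim_last:
            # Round dim precedes the k structure dims: give the stats shape (m, 1, ..., 1).
            shape = (-1,) + (1,) * self.num_structure_dims
            mean, std = mean.view(shape), std.view(shape)
        data[self.message_out_key] = (x - mean) / std
        return data


norm = OneHotHistoryNormalizerSketch(max_message_rounds=3)
x = torch.zeros(2, 4, 3)  # (batch, structure, round): n = 4, m = 3
x[0, 2, 0] = 1.0  # round 0: message at structure position 2
x[0, 0, 1] = 1.0  # round 1: message at structure position 0
out = norm({"x": x})["x_normalized"]
```

Like the documented forward, the sketch writes the result into the input mapping under the output key and returns that same mapping.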

to(*args, **kwargs) → NormalizeOneHotMessageHistory

Move the module to a new device or dtype.

Parameters:
  • *args – Positional arguments to pass to the to method.

  • **kwargs – Keyword arguments to pass to the to method.

Returns:

self (NormalizeOneHotMessageHistory) – The module, moved to the new device or dtype.
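nn.Module.to only moves registered parameters and buffers (and submodules); tensors kept in ordinary attributes, such as a cache of statistics like the one described under _get_mean_and_std, are left where they are. Assuming that is why this class overrides to, the pattern looks roughly like the sketch below; the class name and the _cache attribute are hypothetical.

```python
import torch
from torch import nn


class CachedStatsModule(nn.Module):
    """Hypothetical module that caches tensors in a plain dict (not as buffers),
    so nn.Module.to would not move them on its own."""

    def __init__(self):
        super().__init__()
        self._cache: dict[tuple[int, ...], torch.Tensor] = {}

    def to(self, *args, **kwargs) -> "CachedStatsModule":
        super().to(*args, **kwargs)  # moves parameters and registered buffers
        # Move every cached tensor with the same arguments as well.
        self._cache = {key: t.to(*args, **kwargs) for key, t in self._cache.items()}
        return self


module = CachedStatsModule()
module._cache[(4,)] = torch.zeros(3)
module = module.to(torch.float64)
print(module._cache[(4,)].dtype)  # torch.float64
```

Returning self keeps the override compatible with the usual module = module.to(...) idiom.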