nip.utils.torch.NormalizeOneHotMessageHistory
- class nip.utils.torch.NormalizeOneHotMessageHistory(*args, **kwargs)
Normalize the history of one-hot message exchanges.
Normalizes each component to have zero mean and unit variance, giving each possible message-history length the same weight.
The input is assumed to have some number of batch dimensions, followed by some number of structure dimensions, followed by the round dimension (when round_dim_last is False, the round dimension instead comes just before the structure dimensions). The ‘structure’ dimensions are those that specify the structure of a data point, e.g. the height and width of an image. The input is assumed to be one-hot encoded across all the structure dimensions for each round in which a message has been exchanged.
Shapes
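Takes as input a TensorDict with key:
- “x” (the key given by message_in_key) with shape one of:
  - Float["... structure_dim_1 ... structure_dim_k round"] when round_dim_last is True
  - Float["... round structure_dim_1 ... structure_dim_k"] when round_dim_last is False
where k is num_structure_dims.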
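As a concrete illustration of the layout described above, the sketch below builds a small round-last message history with PyTorch for a 3×3 structure (think of a tiny image) and four rounds, then permutes it into the layout used when round_dim_last is False. The sizes are arbitrary, and leaving rounds without a message as all-zero slices is our assumption about the encoding rather than something stated above.

```python
import torch

# Illustrative sizes (assumptions, not library defaults)
batch_size, height, width, max_rounds = 2, 3, 3, 4

# Round-last layout (round_dim_last=True): (batch, height, width, round)
x = torch.zeros(batch_size, height, width, max_rounds)

# First batch item: messages at position (0, 1) in round 0 and (2, 2) in round 1.
# Assumption: rounds without a message (here rounds 2 and 3) stay all-zero.
x[0, 0, 1, 0] = 1.0
x[0, 2, 2, 1] = 1.0

# One-hot across the structure dims: each round sums to 1 (message) or 0 (none)
per_round_counts = x.sum(dim=(1, 2))  # shape (batch, round)
print(per_round_counts[0])  # tensor([1., 1., 0., 0.])

# Layout for round_dim_last=False: (batch, round, height, width)
x_round_first = x.permute(0, 3, 1, 2)
print(x_round_first.shape)  # torch.Size([2, 4, 3, 3])
```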
- Parameters:
max_message_rounds (int) – The maximum length of the message history.
message_in_key (NestedKey, default="x") – The key containing the message history.
message_out_key (NestedKey, default="x_normalized") – The key to store the normalized message history.
num_structure_dims (int, default=1) – The number of structure dimensions to normalize over (see above).
round_dim_last (bool, default=True) – Whether the round dimension is the last dimension (True) or is located just before the structure dimensions (False).
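The sketch below spells out which dimension indices play which role for a given num_structure_dims and round_dim_last, following the shape description above. The helpers structure_dims and round_dim are hypothetical illustrations, not part of nip; every dimension before the ones they return is treated as a batch dimension.

```python
def structure_dims(ndim: int, num_structure_dims: int, round_dim_last: bool = True) -> tuple[int, ...]:
    """Indices of the structure dimensions (hypothetical helper)."""
    if ndim < num_structure_dims + 1:
        raise ValueError("need at least num_structure_dims + 1 dimensions")
    if round_dim_last:
        # Layout: ... structure_dim_1 ... structure_dim_k round
        first = ndim - 1 - num_structure_dims
    else:
        # Layout: ... round structure_dim_1 ... structure_dim_k
        first = ndim - num_structure_dims
    return tuple(range(first, first + num_structure_dims))


def round_dim(ndim: int, num_structure_dims: int, round_dim_last: bool = True) -> int:
    """Index of the round dimension (hypothetical helper)."""
    if ndim < num_structure_dims + 1:
        raise ValueError("need at least num_structure_dims + 1 dimensions")
    return ndim - 1 if round_dim_last else ndim - 1 - num_structure_dims


# Example: (batch, height, width, round) with two structure dimensions
print(structure_dims(4, 2, round_dim_last=True))  # (1, 2)
print(round_dim(4, 2, round_dim_last=True))  # 3
```

With round_dim_last=False the same four-dimensional input is read as (batch, round, height, width), so the structure dimensions become (2, 3) and the round dimension is 1.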
Methods Summary
__init__(max_message_rounds[, ...]) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
_get_mean_and_std(x) – Get the mean and standard deviation for the structure shape of x.
forward(tensordict) – Normalize the message history.
to(*args, **kwargs) – Move the module to a new device or dtype.
Attributes
T_destination
call_super_init
dump_patches
in_keys – The keys of the input TensorDict.
out_keys – The keys of the output TensorDict.
out_keys_source
training
Methods
- __init__(max_message_rounds: int, message_in_key: str | Tuple[str, ...] = 'x', message_out_key: str | Tuple[str, ...] = 'x_normalized', num_structure_dims: int = 1, round_dim_last: bool = True)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- _get_mean_and_std(x: Tensor) → tuple[Tensor, Tensor]
Get the mean and standard deviation for the structure shape of x.
These are computed based only on the shape of the structure dimensions, so they can be cached and reused for tensors with the same structure shape.
Let n be the total size of the structure dimensions (the product of their sizes) and m be the maximum number of message rounds. Then the mean and standard deviation are computed as follows, where each vector has one entry per round and the square root is taken elementwise:
\[
\begin{aligned}
\text{mean} &= \frac{1}{n m} \left(m - 1,\ m - 2,\ \ldots,\ 0\right) \\
\text{std} &= \frac{1}{n m} \sqrt{\left((m - 1)(n m - m + 1),\ (m - 2)(n m - m + 2),\ \ldots,\ 0\right)}
\end{aligned}
\]
- Parameters:
x (Tensor) – The input tensor.
- Returns:
mean (Tensor) – The mean for message histories with the structure shape of x.
std (Tensor) – The standard deviation for message histories with the structure shape of x.
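One way to read these formulas is as the exact statistics of a simple model of message histories: the history length is uniform on {0, ..., m - 1} (so each length gets the same weight), and each message lands on one of the n structure positions uniformly at random. Under that reading, the component in round i (counting from 0) is 1 with probability \(p_i = (m - 1 - i)/(n m)\), so its standard deviation is \(\sqrt{p_i (1 - p_i)}\), which expands to the entries above. The sketch below checks this by brute-force enumeration for small n and m; the model is our interpretation of the formulas, and both function names are hypothetical.

```python
import itertools

import torch


def closed_form_mean_std(n: int, m: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-round mean and std from the formulas above (hypothetical helper)."""
    i = torch.arange(m, dtype=torch.float64)  # round index 0, ..., m - 1
    mean = (m - 1 - i) / (n * m)
    std = torch.sqrt((m - 1 - i) * (n * m - m + 1 + i)) / (n * m)
    return mean, std


def enumerated_mean_std(n: int, m: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Exact per-component statistics by enumerating every history.

    Assumed model: length L uniform on {0, ..., m - 1}; each of the L messages
    lands on one of the n (flattened) structure positions uniformly.
    """
    expected = torch.zeros(n, m, dtype=torch.float64)  # E[x] per (position, round)
    for length in range(m):
        weight = 1.0 / (m * n**length)  # probability of each individual history
        for positions in itertools.product(range(n), repeat=length):
            for round_idx, pos in enumerate(positions):
                expected[pos, round_idx] += weight
    # x is 0/1, so E[x^2] = E[x] and Var[x] = E[x] - E[x]^2
    std = (expected - expected**2).clamp(min=0.0).sqrt()
    return expected, std


mean_cf, std_cf = closed_form_mean_std(n=3, m=4)
mean_en, std_en = enumerated_mean_std(n=3, m=4)
# Every structure position shares the same per-round statistics
print(torch.allclose(mean_en, mean_cf.expand(3, 4)), torch.allclose(std_en, std_cf.expand(3, 4)))  # True True
```

Note that the standard deviation of the final round (i = m - 1) is zero, so any division by std has to treat that round specially.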
- forward(tensordict: TensorDictBase) → TensorDictBase
Normalize the message history.
- Parameters:
tensordict (TensorDictBase) – The input tensordict.
- Returns:
normalized_tensordict (TensorDictBase) – The input tensordict, with the normalized message history stored under message_out_key.
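The core of forward presumably subtracts the per-round mean and divides by the per-round standard deviation. The sketch below does this on a plain tensor instead of a TensorDict, assumes the round dimension has size max_message_rounds, and (as our own assumption, since the handling is not documented here) divides by 1 wherever the standard deviation is zero. normalize_history is a hypothetical name, not the nip implementation.

```python
import math

import torch


def normalize_history(
    x: torch.Tensor,
    max_message_rounds: int,
    num_structure_dims: int = 1,
    round_dim_last: bool = True,
) -> torch.Tensor:
    """Standardize a one-hot message history (hypothetical sketch of forward).

    Assumes the round dimension of ``x`` has size ``max_message_rounds``.
    """
    ndim = x.dim()
    if round_dim_last:
        structure_shape = x.shape[ndim - 1 - num_structure_dims : ndim - 1]
    else:
        structure_shape = x.shape[ndim - num_structure_dims :]
    n = math.prod(structure_shape)  # total size of the structure dimensions
    m = max_message_rounds

    i = torch.arange(m, dtype=x.dtype, device=x.device)
    mean = (m - 1 - i) / (n * m)
    std = torch.sqrt((m - 1 - i) * (n * m - m + 1 + i)) / (n * m)
    # Assumption: divide by 1 where std is 0 (the final round), instead of by 0
    std = torch.where(std > 0, std, torch.ones_like(std))

    if not round_dim_last:
        # Reshape to (m, 1, ..., 1) so the stats broadcast over the trailing structure dims
        trailing = (1,) * num_structure_dims
        mean = mean.view(m, *trailing)
        std = std.view(m, *trailing)
    return (x - mean) / std
```

Applied to histories whose length is uniform on {0, ..., m - 1} and whose messages land uniformly on the structure positions, every round except the last comes out with approximately zero mean and unit variance.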
- to(*args, **kwargs) → NormalizeOneHotMessageHistory
Move the module to a new device or dtype.
- Parameters:
*args – Positional arguments to pass to the to method.
**kwargs – Keyword arguments to pass to the to method.
- Returns:
self (NormalizeOneHotMessageHistory) – The module, moved to the new device or dtype.
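The description of _get_mean_and_std notes that the statistics depend only on the structure shape and can be cached, and the class overrides to. A plausible reason for such an override is that tensors kept in an ordinary Python dict are not registered buffers, so nn.Module.to would not move them. The sketch below shows that pattern; CachedNormalizerSketch, get_stats, _cache and _ref are hypothetical names, and this is not necessarily how nip implements it.

```python
import math

import torch
from torch import nn


class CachedNormalizerSketch(nn.Module):
    """Hypothetical module caching per-structure-shape statistics."""

    def __init__(self, max_message_rounds: int):
        super().__init__()
        self.max_message_rounds = max_message_rounds
        # Plain dict: its tensors are NOT registered buffers, so nn.Module.to ignores them
        self._cache: dict[torch.Size, tuple[torch.Tensor, torch.Tensor]] = {}
        # Non-persistent buffer that tracks the module's current device and dtype
        self.register_buffer("_ref", torch.zeros(()), persistent=False)

    def get_stats(self, structure_shape: torch.Size) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (mean, std) for a structure shape, computing them at most once."""
        if structure_shape not in self._cache:
            n = math.prod(structure_shape)
            m = self.max_message_rounds
            i = torch.arange(m, dtype=self._ref.dtype, device=self._ref.device)
            mean = (m - 1 - i) / (n * m)
            std = torch.sqrt((m - 1 - i) * (n * m - m + 1 + i)) / (n * m)
            self._cache[structure_shape] = (mean, std)
        return self._cache[structure_shape]

    def to(self, *args, **kwargs) -> "CachedNormalizerSketch":
        """Move parameters/buffers as usual, then move the cached tensors too."""
        super().to(*args, **kwargs)
        self._cache = {
            shape: (mean.to(*args, **kwargs), std.to(*args, **kwargs))
            for shape, (mean, std) in self._cache.items()
        }
        return self
```

Because the _ref buffer follows the module through to, statistics first computed after a device or dtype change are also created in the right place, not only the ones already in the cache.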