nip.utils.torch.ParallelTensorDictModule

class nip.utils.torch.ParallelTensorDictModule(*args, **kwargs)

Apply the same module separately to each of several keys of a TensorDict.

Parameters:
  • module (nn.Module) – The module to apply.

  • in_keys (NestedKey | Iterable[NestedKey]) – The keys to apply the module to.

  • out_keys (NestedKey | Iterable[NestedKey]) – The keys to store the output in.
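The sketch below is framework-free and only illustrates how in_keys and out_keys might be paired. It assumes, without confirmation from this page, that the i-th input key maps to the i-th output key and that one module instance is reused for every key. A plain dict stands in for a TensorDict, and any callable stands in for nn.Module. The function name parallel_apply is hypothetical.

```python
from typing import Any, Callable, Dict, Sequence


def parallel_apply(
    module: Callable[[Any], Any],
    data: Dict[str, Any],
    in_keys: Sequence[str],
    out_keys: Sequence[str],
) -> Dict[str, Any]:
    """Hypothetical sketch: apply one shared module to several keys of a dict."""
    # Assumption: keys are paired positionally, so the lengths must match.
    if len(in_keys) != len(out_keys):
        raise ValueError("in_keys and out_keys must have the same length")
    for in_key, out_key in zip(in_keys, out_keys):
        # The same module (hence the same parameters) is applied to every input.
        data[out_key] = module(data[in_key])
    return data


# Usage: double two separate entries with a single shared "module".
td = {"obs_a": 1.0, "obs_b": 2.5}
parallel_apply(lambda x: 2 * x, td, ["obs_a", "obs_b"], ["out_a", "out_b"])
print(td)  # {'obs_a': 1.0, 'obs_b': 2.5, 'out_a': 2.0, 'out_b': 5.0}
```

Reusing a single module across keys means every input is processed with shared weights. This is the usual reason to prefer one parallel wrapper over several independent modules.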

Methods Summary

__init__(module, in_keys, out_keys)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(tensordict)

Apply the module to each key of the input TensorDict.

Attributes

T_destination

call_super_init

dump_patches

in_keys

out_keys

out_keys_source

training

Methods

__init__(module: Module, in_keys: str | Tuple[str, ...] | Iterable[str | Tuple[str, ...]], out_keys: str | Tuple[str, ...] | Iterable[str | Tuple[str, ...]])

Initialize internal Module state, shared by both nn.Module and ScriptModule.
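The signature above accepts either a single key or an iterable of keys. A single key is a str, or a tuple of str that denotes a nested key. The helper below shows one plausible way to normalize these inputs into a list. Its convention is an assumption, not taken from the nip source: a bare str or tuple is treated as one key, and anything else as a collection of keys. The name normalize_keys is hypothetical.

```python
from typing import Iterable, List, Tuple, Union

# A key is either a top-level name or a path of names into nested entries.
NestedKey = Union[str, Tuple[str, ...]]


def normalize_keys(keys: Union[NestedKey, Iterable[NestedKey]]) -> List[NestedKey]:
    """Hypothetical helper: turn the accepted key formats into a list of keys."""
    # A bare string or tuple is a single key; a tuple denotes a nested path.
    if isinstance(keys, (str, tuple)):
        return [keys]
    # Otherwise treat it as an iterable of keys (e.g. a list).
    return list(keys)


print(normalize_keys("obs"))              # ['obs']
print(normalize_keys(("agent", "obs")))   # [('agent', 'obs')]
print(normalize_keys(["a", ("b", "c")]))  # ['a', ('b', 'c')]
```

Under this convention, several top-level keys must be passed as a list. A tuple such as ("a", "b") would be read as the single nested key ("a", "b").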

forward(tensordict: TensorDictBase) → TensorDictBase

Apply the module to each key of the input TensorDict.

Parameters:

tensordict (TensorDictBase) – The input TensorDict.

Returns:

transformed_tensordict (TensorDictBase) – The input TensorDict, with the module's output for each input key stored under the corresponding output key.
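The sketch below suggests what forward might do when keys are nested, using nested plain dicts in place of a TensorDictBase. Tuple keys are resolved as paths through sub-dicts, and the same dict object is returned after being updated. The in-place update is an assumption drawn from the phrase "the input TensorDict" above. This mirrors the described behavior but is not the library code. The names get_nested, set_nested and forward_sketch are hypothetical.

```python
from typing import Any, Callable, Dict, List, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


def _as_path(key: NestedKey) -> Tuple[str, ...]:
    # A plain string is a path of length one.
    return (key,) if isinstance(key, str) else key


def get_nested(data: Dict[str, Any], key: NestedKey) -> Any:
    # Walk down the sub-dicts one path component at a time.
    for part in _as_path(key):
        data = data[part]
    return data


def set_nested(data: Dict[str, Any], key: NestedKey, value: Any) -> None:
    *parents, leaf = _as_path(key)
    for part in parents:
        # Create intermediate sub-dicts as needed.
        data = data.setdefault(part, {})
    data[leaf] = value


def forward_sketch(
    module: Callable[[Any], Any],
    data: Dict[str, Any],
    in_keys: List[NestedKey],
    out_keys: List[NestedKey],
) -> Dict[str, Any]:
    # Read each input, apply the shared module, write to the paired output key.
    for in_key, out_key in zip(in_keys, out_keys):
        set_nested(data, out_key, module(get_nested(data, in_key)))
    return data  # the same object that was passed in, now updated


td = {"agent_0": {"obs": 3}, "agent_1": {"obs": 4}}
forward_sketch(
    lambda x: x + 1,
    td,
    [("agent_0", "obs"), ("agent_1", "obs")],
    [("agent_0", "feat"), ("agent_1", "feat")],
)
print(td)  # {'agent_0': {'obs': 3, 'feat': 4}, 'agent_1': {'obs': 4, 'feat': 5}}
```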