nip.utils.torch.ParallelTensorDictModule
- class nip.utils.torch.ParallelTensorDictModule(*args, **kwargs)
Apply a module separately to each of the specified keys of a TensorDict.
- Parameters:
module (nn.Module) – The module to apply.
in_keys (NestedKey | Iterable[NestedKey]) – The keys to apply the module to.
out_keys (NestedKey | Iterable[NestedKey]) – The keys to store the output in.
Methods Summary
__init__(module, in_keys, out_keys) – Initialize the wrapper with the module to apply and its input and output keys.
forward(tensordict) – Apply the module to each key of the input TensorDict.
Attributes
T_destination
call_super_init
dump_patches
in_keys
out_keys
out_keys_source
training
Methods
- __init__(module: Module, in_keys: str | Tuple[str, ...] | Iterable[str | Tuple[str, ...]], out_keys: str | Tuple[str, ...] | Iterable[str | Tuple[str, ...]])
Initialize the wrapper with the module to apply and its input and output keys.
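The signature accepts either a single nested key (a `str`, or a tuple of strings giving a path into nested TensorDicts) or an iterable of such keys. The sketch below shows one plausible way to normalize that argument into a list. It is an assumption modeled on common tensordict practice, not this class's confirmed code, and `normalize_keys` is a hypothetical helper. The key subtlety is that a bare tuple must be wrapped as one key, since a tuple is itself a nested key rather than a collection of keys.

```python
from typing import Iterable, List, Tuple, Union

# A nested key is a plain string or a tuple of strings forming a path,
# e.g. ("agents", "obs") addresses tensordict["agents"]["obs"].
NestedKey = Union[str, Tuple[str, ...]]


def normalize_keys(keys: Union[NestedKey, Iterable[NestedKey]]) -> List[NestedKey]:
    """Hypothetical helper: turn one nested key or many into a list of keys.

    A bare str or tuple is treated as a SINGLE key; anything else is
    assumed to be an iterable of keys and is materialized into a list.
    """
    if isinstance(keys, (str, tuple)):
        return [keys]
    return list(keys)


print(normalize_keys("obs"))                       # ['obs']
print(normalize_keys(("agents", "obs")))           # [('agents', 'obs')]
print(normalize_keys(["obs", ("agents", "obs")]))  # ['obs', ('agents', 'obs')]
```

Checking `tuple` before falling back to iteration is the design choice that matters: iterating `("agents", "obs")` directly would wrongly yield two top-level keys.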
- forward(tensordict: TensorDictBase) → TensorDictBase
Apply the module to the entry under each of the in_keys of the input TensorDict, storing the results under out_keys.
- Parameters:
tensordict (TensorDictBase) – The input TensorDict.
- Returns:
transformed_tensordict (TensorDictBase) – The input TensorDict with the module applied to each key.
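As a rough illustration of what "apply the module to each key" means, the following sketch uses plain nested dicts as a stand-in for a TensorDict and a plain callable as a stand-in for the `nn.Module`. It assumes that the i-th input key maps to the i-th output key and that the input is updated in place and returned. Neither assumption is confirmed by this page, and `parallel_forward`, `get_nested`, and `set_nested` are hypothetical helpers, not part of `nip`.

```python
from typing import Any, Callable, Dict, List, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


def _as_path(key: NestedKey) -> Tuple[str, ...]:
    # A plain string addresses a top-level entry; a tuple is already a path.
    return (key,) if isinstance(key, str) else key


def get_nested(data: Dict[str, Any], key: NestedKey) -> Any:
    # Walk down the nested dicts following the key path.
    for part in _as_path(key):
        data = data[part]
    return data


def set_nested(data: Dict[str, Any], key: NestedKey, value: Any) -> None:
    # Create intermediate dicts as needed, then write the leaf value.
    *parents, leaf = _as_path(key)
    for part in parents:
        data = data.setdefault(part, {})
    data[leaf] = value


def parallel_forward(
    module: Callable[[Any], Any],
    data: Dict[str, Any],
    in_keys: List[NestedKey],
    out_keys: List[NestedKey],
) -> Dict[str, Any]:
    """Apply the same `module` separately to the entry under each in_key.

    Assumption: the output for in_keys[i] is stored under out_keys[i],
    and `data` is modified in place and returned.
    """
    if len(in_keys) != len(out_keys):
        raise ValueError("in_keys and out_keys must have the same length")
    for in_key, out_key in zip(in_keys, out_keys):
        set_nested(data, out_key, module(get_nested(data, in_key)))
    return data


data = {"a": 1, "agents": {"b": 2}}
out = parallel_forward(
    lambda x: x * 10, data, ["a", ("agents", "b")], ["a_out", ("agents", "b_out")]
)
print(out)  # {'a': 1, 'agents': {'b': 2, 'b_out': 20}, 'a_out': 10}
```

Because the same callable is reused for every key, any parameters it holds would be shared across all keys. This is the main difference from keeping one module per key.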