nip.utils.torch.TensorDictCat
- class nip.utils.torch.TensorDictCat(*args, **kwargs)
Concatenate the tensors stored under several keys of a TensorDict into a single tensor stored under a new key.
- Parameters:
in_keys (Iterable[NestedKey]) – The keys to concatenate.
out_key (NestedKey) – The key of the concatenated tensor.
dim (int, default=0) – The dimension to concatenate over.
Methods Summary
__init__
(in_keys, out_key[, dim]) Initialize the module with the keys to concatenate, the output key, and the concatenation dimension.
forward
(tensordict) Concatenate the keys of the input TensorDict.
Attributes
T_destination
call_super_init
dump_patches
in_keys
out_keys
out_keys_source
training
Methods
- __init__(in_keys: Iterable[str | Tuple[str, ...]], out_key: str | Tuple[str, ...], dim=0)
Initialize the module with the keys to concatenate (in_keys), the key under which to store the result (out_key), and the dimension along which to concatenate (dim).
- forward(tensordict: TensorDictBase) → TensorDictBase
Concatenate the keys of the input TensorDict.
- Parameters:
tensordict (TensorDictBase) – The input TensorDict.
- Returns:
concatenated_tensordict (TensorDictBase) – The input TensorDict, with the tensors under in_keys concatenated along dim and stored under out_key.