nip.utils.torch.TensorDictCat


class nip.utils.torch.TensorDictCat(in_keys, out_key, dim=0)

Concatenate the tensors stored under several keys of a TensorDict into a single tensor.

Parameters:
  • in_keys (Iterable[NestedKey]) – The keys whose tensors are concatenated.

  • out_key (NestedKey) – The key under which the concatenated tensor is stored.

  • dim (int, default=0) – The dimension along which to concatenate.
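Assuming `TensorDictCat` follows `torch.cat` semantics for `dim`, every input must have the same size in all dimensions except `dim`, and the output's size along `dim` is the sum of the inputs' sizes. The sketch below illustrates that rule without PyTorch, using rectangular nested Python lists as a stand-in for tensors. The `cat` helper is hypothetical and is not part of `nip`; like `torch.cat`, it accepts negative `dim` values counted from the last dimension.

```python
from typing import Any, List, Sequence, Tuple


def _shape(t: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list, read off its first elements."""
    shape: List[int] = []
    while isinstance(t, list):
        shape.append(len(t))
        t = t[0] if t else None
    return tuple(shape)


def _cat_unchecked(tensors: Sequence[list], dim: int) -> list:
    if dim == 0:
        # Along the leading dimension, concatenation is plain list joining.
        out: list = []
        for t in tensors:
            out.extend(t)
        return out
    # Deeper dimension: concatenate the i-th sub-lists of every input,
    # one dimension further down.
    return [_cat_unchecked([t[i] for t in tensors], dim - 1)
            for i in range(len(tensors[0]))]


def cat(tensors: Sequence[list], dim: int = 0) -> list:
    """Hypothetical torch.cat-like concatenation of nested lists along dim."""
    if not tensors:
        raise ValueError("expected at least one input")
    shapes = [_shape(t) for t in tensors]
    ndim = len(shapes[0])
    if dim < 0:
        dim += ndim  # negative dims count from the end, as in PyTorch
    if not 0 <= dim < ndim:
        raise ValueError(f"dim out of range for {ndim}-dimensional inputs")

    def without_dim(s: Tuple[int, ...]) -> Tuple[int, ...]:
        return s[:dim] + s[dim + 1:]

    # All sizes except the one along `dim` must agree.
    for s in shapes:
        if len(s) != ndim or without_dim(s) != without_dim(shapes[0]):
            raise ValueError("sizes must match except along the concatenation dim")
    return _cat_unchecked(tensors, dim)
```

For example, with `a = [[1, 2], [3, 4]]` (shape (2, 2)) and `c = [[7], [8]]` (shape (2, 1)), `cat([a, c], dim=1)` yields `[[1, 2, 7], [3, 4, 8]]`, of shape (2, 3). `cat([a, c], dim=0)` raises `ValueError`, because the inputs differ in a dimension other than `dim`.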

Methods Summary

__init__(in_keys, out_key[, dim])

Initialize the module with the input keys, the output key, and the concatenation dimension.

forward(tensordict)

Concatenate the tensors at in_keys in the input TensorDict and store the result under out_key.

Attributes

T_destination

call_super_init

dump_patches

in_keys

out_keys

out_keys_source

training

Methods

__init__(in_keys: Iterable[str | Tuple[str, ...]], out_key: str | Tuple[str, ...], dim=0)

Initialize the module with the input keys, the output key, and the concatenation dimension.

forward(tensordict: TensorDictBase) → TensorDictBase

Concatenate the tensors at in_keys in the input TensorDict and store the result under out_key.

Parameters:

tensordict (TensorDictBase) – The input TensorDict.

Returns:

concatenated_tensordict (TensorDictBase) – The input TensorDict, with the concatenated tensor stored under out_key.
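The documented contract of `forward` (read each entry named in `in_keys`, concatenate them, store the result under `out_key`, and return the input TensorDict) can be sketched with plain nested dictionaries standing in for a `TensorDictBase`. Here a nested key is either a string or a tuple of strings that walks into sub-dictionaries, matching the `str | Tuple[str, ...]` annotation on `__init__`. `DictCat`, `get_nested`, and `set_nested` are hypothetical names, not part of `nip` or `tensordict`. To stay short, the sketch only concatenates along the first dimension (list joining) and leaves the input keys in place; the documentation does not say whether the real class removes them.

```python
from typing import Any, Dict, Iterable, List, Tuple, Union

# A nested key: a top-level name, or a path of names into sub-dictionaries.
NestedKey = Union[str, Tuple[str, ...]]


def _path(key: NestedKey) -> Tuple[str, ...]:
    return (key,) if isinstance(key, str) else tuple(key)


def get_nested(data: Dict[str, Any], key: NestedKey) -> Any:
    """Follow a nested key down through sub-dictionaries."""
    node: Any = data
    for part in _path(key):
        node = node[part]
    return node


def set_nested(data: Dict[str, Any], key: NestedKey, value: Any) -> None:
    """Store value at a nested key, creating sub-dictionaries as needed."""
    *parents, leaf = _path(key)
    node = data
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value


class DictCat:
    """Hypothetical dict-based stand-in for TensorDictCat (first dimension only)."""

    def __init__(self, in_keys: Iterable[NestedKey], out_key: NestedKey) -> None:
        self.in_keys: List[NestedKey] = list(in_keys)
        self.out_key = out_key

    def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
        joined: List[Any] = []
        for key in self.in_keys:
            joined.extend(get_nested(data, key))  # join along the first dimension
        set_nested(data, self.out_key, joined)
        return data  # the same mapping that was passed in, now with out_key set
```

For instance, with `data = {"obs": [1, 2], "agent": {"msg": [3]}}`, calling `DictCat(["obs", ("agent", "msg")], ("agent", "all")).forward(data)` returns `data` itself, now with `data["agent"]["all"] == [1, 2, 3]`.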