nip.utils.tensordict.get_key_batch_size
- nip.utils.tensordict.get_key_batch_size(td: TensorDictBase, key: str | Tuple[str, ...]) → Size
Get the batch size of the innermost tensordict that contains the key.
For instance, if the key is ("next", "agents", "done"), this function returns the batch size of the "agents" sub-tensordict of the "next" sub-tensordict. For a non-nested key, this is simply the batch size of td itself.
- Parameters:
td (TensorDictBase) – The tensordict containing the key
key (NestedKey) – The key for which to get the batch size
- Returns:
batch_size (torch.Size) – The batch size of the inner-most tensordict containing the key