nip.utils.tensordict.get_key_batch_size


nip.utils.tensordict.get_key_batch_size(td: TensorDictBase, key: str | Tuple[str, ...]) → Size

Get the batch size of the innermost tensordict containing the key.

For instance, if the key is ("next", "agents", "done"), this function returns the batch size of the "agents" sub-tensordict of the "next" sub-tensordict. If the key is a plain string, the innermost tensordict containing it is td itself, so the result is td.batch_size.
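The lookup described above can be sketched as follows. This is a minimal illustration under stated assumptions, not nip's actual source. The names get_key_batch_size_sketch and Node are hypothetical. Node is a stdlib-only stand-in for a tensordict, so the example runs without torch. The sketch also assumes that indexing with a tuple of names returns the nested sub-tensordict, which is how TensorDict treats nested keys such as td["next", "agents"].

```python
from typing import Any, Tuple, Union

# Mirrors the documented key type: a single name or a tuple of names.
NestedKey = Union[str, Tuple[str, ...]]


def get_key_batch_size_sketch(td: Any, key: NestedKey) -> Any:
    """Return the batch size of the innermost container holding `key` (sketch only)."""
    # A plain string or a one-element tuple names an entry of `td` itself.
    if isinstance(key, str) or len(key) <= 1:
        return td.batch_size
    # Drop the leaf name and index with the remaining prefix,
    # e.g. ("next", "agents", "done") -> td["next", "agents"].
    return td[key[:-1]].batch_size


class Node:
    """Hypothetical stdlib stand-in for a tensordict: a batch size plus named entries."""

    def __init__(self, batch_size: Tuple[int, ...], **entries: Any) -> None:
        self.batch_size = tuple(batch_size)
        self.entries = entries

    def __getitem__(self, key: NestedKey) -> Any:
        # Walk one level per name, imitating nested-key indexing on a TensorDict.
        names = (key,) if isinstance(key, str) else key
        node = self
        for name in names:
            node = node.entries[name]
        return node


# Root with batch size (4,), a "next" child with the same batch size, and an
# "agents" grandchild that adds an agent dimension of size 2.
root = Node((4,), next=Node((4,), done=None, agents=Node((4, 2), done=None)))

print(get_key_batch_size_sketch(root, ("next", "agents", "done")))  # (4, 2)
print(get_key_batch_size_sketch(root, ("next", "done")))  # (4,)
print(get_key_batch_size_sketch(root, "next"))  # (4,)
```

Dropping the last element of the key and indexing with the remaining prefix reaches the sub-tensordict that directly holds the entry. A one-element key therefore resolves to td itself.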

Parameters:
  • td (TensorDictBase) – The tensordict containing the key.

  • key (NestedKey) – The key for which to get the batch size: either a single string or a tuple of strings naming a nested entry.

Returns:
  batch_size (torch.Size) – The batch size of the innermost tensordict containing the key.