nip.utils.torch.TensorDictPrint


class nip.utils.torch.TensorDictPrint(*args, **kwargs)

Print information about an input tensordict.

Parameters:
  • keys (NestedKey | Iterable[NestedKey]) – The keys to print.

  • name (str | None, default=None) – The name of the tensordict, which will be printed before the keys.

  • print_nan_proportion (bool, default=False) – Whether to print the proportion of NaN values in the tensors.
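The parameters above describe what gets reported for each key. The following is a minimal sketch of that behaviour, not the nip implementation. It assumes a plain nested `dict` of tensors as a stand-in for a tensordict, and the helper names `is_single_key`, `get_nested` and `summarize` are hypothetical. The exact output format of `TensorDictPrint` is not documented here, so the fields shown (shape, dtype, NaN proportion) are an assumption.

```python
from typing import Iterable, List, Mapping, Optional, Tuple, Union

import torch

# A nested key is a single string or a tuple of strings (mirrors NestedKey).
NestedKey = Union[str, Tuple[str, ...]]


def is_single_key(keys) -> bool:
    """Hypothetical: True if `keys` is one (possibly nested) key, not a collection."""
    return isinstance(keys, str) or (
        isinstance(keys, tuple) and all(isinstance(k, str) for k in keys)
    )


def get_nested(data: Mapping, key: NestedKey) -> torch.Tensor:
    """Hypothetical helper: look up a (possibly nested) key in a nested dict."""
    parts = (key,) if isinstance(key, str) else key
    value = data
    for part in parts:
        value = value[part]
    return value


def summarize(
    data: Mapping,
    keys: Union[NestedKey, Iterable[NestedKey]],
    name: Optional[str] = None,
    print_nan_proportion: bool = False,
) -> List[str]:
    """Hypothetical sketch: print (and return) one summary line per key."""
    key_list = [keys] if is_single_key(keys) else list(keys)
    lines = []
    if name is not None:
        # The name, if given, is printed before the keys.
        lines.append(f"{name}:")
    for key in key_list:
        tensor = get_nested(data, key)
        line = f"  {key}: shape={tuple(tensor.shape)}, dtype={tensor.dtype}"
        if print_nan_proportion:
            # Fraction of elements that are NaN, computed over the whole tensor.
            nan_prop = torch.isnan(tensor).float().mean().item()
            line += f", nan_proportion={nan_prop:.3f}"
        lines.append(line)
    print("\n".join(lines))
    return lines


# Example: one top-level key and one nested key, with NaN reporting enabled.
data = {
    "obs": torch.tensor([1.0, float("nan"), 3.0, 4.0]),
    "agent": {"logits": torch.zeros(2, 3)},
}
lines = summarize(data, ["obs", ("agent", "logits")], name="batch", print_nan_proportion=True)
```

Here a tuple of strings is treated as a single nested key, matching the `str | Tuple[str, ...]` form in the `__init__` signature, while a list is treated as a collection of keys.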

Methods Summary

__init__(keys[, name, print_nan_proportion])

Initialize the module with the keys to print, an optional tensordict name, and whether to print the proportion of NaN values.

forward(tensordict)

Print the information about the tensors in the input tensordict.

Attributes

T_destination

call_super_init

dump_patches

in_keys

out_keys

out_keys_source

training

Methods

__init__(keys: str | Tuple[str, ...] | Iterable[str | Tuple[str, ...]], name: str | None = None, print_nan_proportion: bool = False)

Initialize the module with the keys to print, an optional tensordict name, and whether to print the proportion of NaN values.

forward(tensordict: TensorDictBase) → TensorDictBase

Print the information about the tensors in the input tensordict.

Parameters:

tensordict (TensorDictBase) – The input tensordict.

Returns:

tensordict (TensorDictBase) – The input tensordict, unchanged.
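Because `forward` returns the input tensordict unchanged, a print module of this kind can be placed between pipeline stages for debugging without affecting the result. The sketch below illustrates that pass-through property with a hypothetical `DebugPrint` module built on `torch.nn.Module`, using a plain `dict` as a stand-in for a tensordict. It is not the nip class, and `Scale` is a made-up stage for illustration only.

```python
import torch
from torch import nn


class DebugPrint(nn.Module):
    """Hypothetical pass-through module: prints shapes, returns input unchanged."""

    def __init__(self, keys, name=None):
        super().__init__()
        self.keys = list(keys)
        self.name = name

    def forward(self, data: dict) -> dict:
        prefix = f"{self.name}: " if self.name is not None else ""
        for key in self.keys:
            print(f"{prefix}{key} shape={tuple(data[key].shape)}")
        # Return the very same object so downstream stages see identical data.
        return data


class Scale(nn.Module):
    """Hypothetical stage that doubles the "x" entry in a new dict."""

    def forward(self, data: dict) -> dict:
        return {**data, "x": data["x"] * 2}


# nn.Sequential feeds each module's output to the next, so the printers
# observe the data before and after Scale without changing it.
pipeline = nn.Sequential(
    DebugPrint(["x"], name="before"),
    Scale(),
    DebugPrint(["x"], name="after"),
)
inputs = {"x": torch.ones(3)}
outputs = pipeline(inputs)
```

Removing the two `DebugPrint` stages would leave `outputs` identical, which is the point of a module whose return value is its input.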