nip.utils.distributions.CompositeCategoricalDistribution

class nip.utils.distributions.CompositeCategoricalDistribution(**kwargs)

A composition of categorical distributions.

Allows specifying the parameters of the categorical distributions either as logits or as probabilities.
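The two parameterisations are related by the softmax function: the probabilities are the softmax of the logits. The plain-Python sketch below illustrates this. It assumes the usual convention (as in torch.distributions.Categorical) that logits are unnormalised log-probabilities. The `softmax` helper and the example values are hypothetical and are not part of `nip`.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Convert unnormalised logits to probabilities (numerically stable)."""
    m = max(logits)  # subtract the max to avoid overflow in exp
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Two equivalent ways of parameterising the same 3-way categorical distribution
action_logits = [2.0, 0.0, -1.0]
action_probs = softmax(action_logits)

# Logits are only defined up to an additive constant: shifting them all by
# the same amount leaves the probabilities unchanged
shifted = softmax([x + 5.0 for x in action_logits])
assert all(math.isclose(p, q) for p, q in zip(action_probs, shifted))
```

Because of this shift invariance, passing `action_logits` or `action_probs` describes the same distribution. Only the differences between logits matter.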

The log_prob method is reimplemented with the following changes:

  • The log-probability can be stored under a different key (specified by the “log_prob_key” parameter)

  • It computes and stores only the total log-probability, not the individual ones

  • It doesn’t reduce all non-batch dimensions of the log-probability

In addition, the class provides an entropy method for computing the entropy of the distribution.
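The "total log-probability" can be sketched in plain Python. The sketch assumes that the component categoricals are independent, so the total is the element-wise sum of the component log-probabilities over the batch. Nested lists stand in for tensors. The function names, component names and values are hypothetical and do not reproduce `nip`'s implementation.

```python
import math


def categorical_log_prob(probs: list[float], value: int) -> float:
    """Log-probability of a single categorical sample."""
    return math.log(probs[value])


def total_log_prob(
    params: dict[str, list[list[float]]], sample: dict[str, list[int]]
) -> list[float]:
    """Sum component log-probabilities element-wise over the batch.

    params maps a component name to a batch of probability vectors, and
    sample maps the same name to a batch of sampled category indices.
    Only the total is returned (no per-component values are kept), and
    the batch dimension is preserved rather than reduced.
    """
    batch_size = len(next(iter(sample.values())))
    totals = [0.0] * batch_size
    for name, probs_batch in params.items():
        for i, (probs, value) in enumerate(zip(probs_batch, sample[name])):
            totals[i] += categorical_log_prob(probs, value)
    return totals


# Hypothetical composite of two categoricals, batch size 2
params = {
    "action": [[0.5, 0.5], [0.9, 0.1]],
    "message": [[0.25, 0.75], [0.5, 0.5]],
}
sample = {"action": [0, 1], "message": [1, 0]}
lp = total_log_prob(params, sample)  # one total per batch element
```

Here `lp[0]` equals log(0.5 × 0.75) and `lp[1]` equals log(0.1 × 0.5), so `lp` has one entry per batch element.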

Parameter names must be strings ending in “_logits” or “_probs”. By default, the suffix-stripped name is used as the key of the corresponding categorical distribution, but this can be changed by passing a key transform function or a lookup table. For example, to specify the parameters of a categorical distribution over the key ("agents", "action") using logits, you can pass the following:

>>> CompositeCategoricalDistribution(
...     action_logits=...,
...     key_transform=lambda x: ("agents", x)
... )

Parameters:
  • **categorical_params (dict[str, Tensor]) – The parameters of the categorical distributions. Each key is the name of a categorical distribution followed by “_logits” or “_probs”, and each value is a Tensor containing the logits or probabilities of that categorical distribution.

  • key_transform (callable[[str], NestedKey] | dict[str, NestedKey], optional) – A function that transforms the keys of the categorical parameters. If a dict is given, it is used as a lookup table. If a callable is given, it is applied to each key. Defaults to the identity function.

  • log_prob_key (NestedKey, default="sample_log_prob") – The tensordict key to use for the log-probability of the sample.
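The naming rule and the two forms of key_transform can be sketched as a small standalone helper. `resolve_key` is hypothetical, and `NestedKey` is redefined here as a plain type alias so the example runs without tensordict. The sketch also assumes that a dict lookup table is indexed by the suffix-stripped name and that a missing entry raises `KeyError`. Neither assumption is confirmed by this page.

```python
from typing import Callable, Union

# Stand-in for tensordict's NestedKey: a string or a tuple of strings
NestedKey = Union[str, tuple[str, ...]]

SUFFIXES = ("_logits", "_probs")


def resolve_key(
    param_name: str,
    key_transform: Union[Callable[[str], NestedKey], dict[str, NestedKey], None] = None,
) -> tuple[NestedKey, str]:
    """Hypothetical helper: split a parameter name into (key, parameter type)."""
    for suffix in SUFFIXES:
        if param_name.endswith(suffix):
            base = param_name[: -len(suffix)]  # suffix-stripped name
            kind = suffix[1:]  # "logits" or "probs"
            break
    else:
        raise ValueError(f"{param_name!r} must end in '_logits' or '_probs'")

    if key_transform is None:  # identity by default
        key: NestedKey = base
    elif isinstance(key_transform, dict):
        key = key_transform[base]  # lookup table (KeyError if missing)
    else:
        key = key_transform(base)  # callable
    return key, kind
```

For example, `resolve_key("action_logits", lambda x: ("agents", x))` gives `(("agents", "action"), "logits")`, which matches the key used in the example above. A lookup table such as `{"message": ("agents", "message")}` gives the same kind of result for `"message_probs"`.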

Methods Summary

  • __init__(**kwargs)

  • entropy(batch_size) – Compute the entropy of the composite distribution.

  • log_prob(sample) – Compute the log probability of a sample for the composite distribution.

Attributes

  • arg_constraints – Returns a dictionary from argument names to Constraint objects that should be satisfied by each argument of this distribution.

  • batch_shape – Returns the shape over which parameters are batched.

  • event_shape – Returns the shape of a single sample (without batching).

  • has_enumerate_support

  • has_rsample

  • mean – Returns the mean of the distribution.

  • mode – Returns the mode of the distribution.

  • stddev – Returns the standard deviation of the distribution.

  • support – Returns a Constraint object representing this distribution's support.

  • variance – Returns the variance of the distribution.

  • dists

Methods

__init__(**kwargs)

entropy(batch_size: int | tuple[int, ...]) → Tensor

Compute the entropy of the composite distribution.

Parameters:

batch_size (int | tuple[int, ...]) – The common batch size of the categorical distributions. The output tensor will have this shape.

Returns:

entropy (Tensor) – The entropy of the composite distribution, as a tensor of shape batch_size
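For independent categorical components, the entropy of the composite distribution is the sum of the component entropies. The plain-Python sketch below illustrates this, with independence assumed and nested lists standing in for tensors. Unlike the documented method, the hypothetical `composite_entropy` takes no batch_size argument: it infers the batch from its inputs.

```python
import math


def categorical_entropy(probs: list[float]) -> float:
    """Shannon entropy (in nats) of one categorical distribution."""
    # Terms with p == 0 contribute nothing (p * log p -> 0 as p -> 0)
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def composite_entropy(params: dict[str, list[list[float]]]) -> list[float]:
    """Per batch element, sum the entropies of the independent components."""
    batch_size = len(next(iter(params.values())))
    return [
        sum(categorical_entropy(probs_batch[i]) for probs_batch in params.values())
        for i in range(batch_size)
    ]


# Hypothetical composite of two categoricals, batch size 2
params = {
    "action": [[0.5, 0.5], [1.0, 0.0]],
    "message": [[0.25, 0.25, 0.25, 0.25], [0.5, 0.5, 0.0, 0.0]],
}
ent = composite_entropy(params)
```

In the first batch element, the uniform 2-way and 4-way components give log 2 + log 4 = log 8. In the second, the deterministic action contributes 0 and the message contributes log 2.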

log_prob(sample: TensorDictBase) → TensorDictBase

Compute the log probability of a sample for the composite distribution.

Adapted from tensordict.nn.distributions.CompositeDistribution.log_prob.

The shape of the log-probability tensor is the batch size of the innermost tensordict in which it lives. For example, for the key ("agents", "sample_log_prob") this is the batch size of the “agents” sub-tensordict.

Parameters:

sample (TensorDictBase) – A tensordict containing the sample

Returns:

updated_sample (TensorDictBase) – The sample tensordict updated with the log probability of the sample
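The shape rule can be illustrated with a sketch that writes a log-probability under a nested key. Plain nested dicts and lists stand in for TensorDict and tensors, so the example runs without tensordict. The `set_nested` helper, the batch sizes (2 environments, 3 agents) and the log-probability values are all hypothetical.

```python
from typing import Any, Union

# Stand-in for tensordict's NestedKey: a string or a tuple of strings
NestedKey = Union[str, tuple[str, ...]]


def set_nested(td: dict[str, Any], key: NestedKey, value: Any) -> None:
    """Write value at a (possibly nested) key, creating sub-dicts as needed."""
    path = (key,) if isinstance(key, str) else key
    node = td
    for part in path[:-1]:
        node = node.setdefault(part, {})
    node[path[-1]] = value


# Outer batch size (2,); the "agents" sub-dict has batch size (2, 3)
sample = {"agents": {"action": [[0, 1, 1], [1, 0, 0]]}}
per_agent_log_prob = [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]]

# Stored under ("agents", "sample_log_prob"), the log-probability keeps the
# (2, 3) batch shape of the "agents" sub-dict instead of being summed to (2,)
set_nested(sample, ("agents", "sample_log_prob"), per_agent_log_prob)
```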