nip.utils.distributions.CompositeCategoricalDistribution
- class nip.utils.distributions.CompositeCategoricalDistribution(**kwargs)
A composition of categorical distributions.
Allows specifying the parameters of the categorical distributions either as logits or as probabilities.
The log_prob method is reimplemented with the following changes:
- The log-probability can be stored in a different key (specified by the “log_prob_key” parameter)
- It only computes and stores the total log-probability, not the individual ones
- It doesn’t reduce all non-batch dimensions in the log-probability
- It has a method to compute the entropy of the distribution
Parameter names must be strings ending in “_logits” or “_probs”. By default, the suffix-stripped name is used as the key of the distribution. However, this key can be changed by passing a key transform function or a lookup table. For example, to specify the parameters of a categorical distribution over key ("agents", "action") using logits, you can pass the following:
>>> CompositeCategoricalDistribution(
...     action_logits=...,
...     key_transform=lambda x: ("agents", x)
... )
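The naming rule above can be illustrated with a small pure-Python sketch. It is not the library's implementation: resolve_key is a hypothetical helper introduced only for illustration, and it assumes that both the callable and the lookup table are applied to the suffix-stripped name (the callable case is consistent with the example above; the lookup-table case is an assumption).

```python
from typing import Callable, Optional, Tuple, Union

# Stand-in for tensordict's NestedKey: a string or a tuple of strings.
NestedKey = Union[str, Tuple[str, ...]]

SUFFIXES = ("_logits", "_probs")


def resolve_key(
    param_name: str,
    key_transform: Optional[Union[Callable[[str], NestedKey], dict]] = None,
) -> Tuple[NestedKey, str]:
    """Hypothetical helper: map a parameter name to (key, parameter kind)."""
    for suffix in SUFFIXES:
        if param_name.endswith(suffix):
            base = param_name[: -len(suffix)]  # e.g. "action"
            kind = suffix[1:]  # "logits" or "probs"
            break
    else:
        raise ValueError(
            f"{param_name!r} must end in '_logits' or '_probs'"
        )

    if key_transform is None:
        key = base  # identity transform (the documented default)
    elif callable(key_transform):
        key = key_transform(base)
    else:
        # Assumption: the lookup table is indexed by the suffix-stripped name.
        key = key_transform[base]
    return key, kind


print(resolve_key("action_logits", lambda x: ("agents", x)))
print(resolve_key("action_probs"))
```

Under these assumptions, the first call resolves to the nested key ("agents", "action") with logits, and the second falls back to the plain key "action" with probabilities.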
- Parameters:
**categorical_params (dict[str, Tensor]) – The parameters of the categorical distributions. Each key is the name of a categorical parameter appended with “_logits” or “_probs” and each value is a Tensor containing the logits or probabilities of the categorical distribution.
key_transform (callable[[str], NestedKey] | dict[str, NestedKey], optional) – A function that transforms the keys of the categorical parameters. If a dict is given, it is used as a lookup table. If a callable is given, it is applied to each key. Defaults to the identity function.
log_prob_key (NestedKey, default="sample_log_prob") – The tensordict key to use for the log-probability of the sample
Methods Summary
__init__(**kwargs)
entropy(batch_size) – Compute the entropy of the composite distribution.
log_prob(sample) – Compute the log probability of a sample for the composite distribution.
Attributes
arg_constraints – Returns a dictionary from argument names to Constraint objects that should be satisfied by each argument of this distribution.
batch_shape – Returns the shape over which parameters are batched.
event_shape – Returns the shape of a single sample (without batching).
has_enumerate_support
has_rsample
mean – Returns the mean of the distribution.
mode – Returns the mode of the distribution.
stddev – Returns the standard deviation of the distribution.
support – Returns a Constraint object representing this distribution's support.
variance – Returns the variance of the distribution.
dists
Methods
- entropy(batch_size: int | tuple[int, ...]) → Tensor
Compute the entropy of the composite distribution.
- log_prob(sample: TensorDictBase) → TensorDictBase
Compute the log probability of a sample for the composite distribution.
Adapted from tensordict.nn.distributions.CompositeDistribution.log_prob.
The shape of the log-probability tensor is the batch size of the innermost tensordict in which it lives. E.g. for the key ("agents", "sample_log_prob") this will be the batch size of the “agents” sub-tensordict.
- Parameters:
sample (TensorDictBase) – A tensordict containing the sample
- Returns:
updated_sample (TensorDictBase) – The sample tensordict updated with the log probability of the sample
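
As a rough illustration of what a total log-probability and an entropy mean for a composite of categoricals, the sketch below works on plain Python lists. It assumes the component distributions are independent, so that both quantities are sums over components; the page does not state this explicitly. The helpers log_softmax, total_log_prob and total_entropy are hypothetical names, not part of the library, and the sketch uses scalars only, so it does not model tensors, tensordicts, batch dimensions or the batch_size argument.

```python
import math


def log_softmax(logits):
    """Convert logits to log-probabilities in a numerically stable way."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def total_log_prob(log_probs_by_key, sample):
    """Sum the log-probability of the sampled category of each component.

    Assumption: components are independent, so the joint log-probability
    is the sum of the per-component log-probabilities.
    """
    return sum(log_probs_by_key[key][index] for key, index in sample.items())


def total_entropy(log_probs_by_key):
    """Sum the entropies of the components (same independence assumption)."""
    return sum(
        -sum(math.exp(lp) * lp for lp in log_probs)
        for log_probs in log_probs_by_key.values()
    )


# Two hypothetical components: a uniform choice over 2 actions and
# a uniform choice over 4 messages.
dists = {
    "action": log_softmax([0.0, 0.0]),
    "message": log_softmax([0.0, 0.0, 0.0, 0.0]),
}
sample = {"action": 1, "message": 2}

print(total_log_prob(dists, sample))  # log(1/2) + log(1/4)
print(total_entropy(dists))           # log(2) + log(4)
```

Only the summed value is kept here, mirroring the statement above that the class stores the total log-probability rather than the individual ones.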