nip.rl_objectives.ClipPPOLossImproved#

class nip.rl_objectives.ClipPPOLossImproved(*args, **kwargs)[source]#

Clipped PPO loss which allows multiple action keys and normalises advantages.

See torchrl.objectives.ClipPPOLoss for more details.
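
Since the constructor signature is (*args, **kwargs), construction presumably mirrors torchrl.objectives.ClipPPOLoss. The following is a minimal sketch under that assumption; the actor and critic modules are illustrative placeholders, not part of nip.

```python
from torch import nn
from tensordict.nn import TensorDictModule
from torchrl.modules import NormalParamExtractor, ProbabilisticActor, TanhNormal, ValueOperator

from nip.rl_objectives import ClipPPOLossImproved

n_obs, n_act = 4, 2

# Illustrative actor: maps observations to the parameters of a TanhNormal policy.
actor_net = nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor())
actor = ProbabilisticActor(
    TensorDictModule(actor_net, in_keys=["observation"], out_keys=["loc", "scale"]),
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)

# Illustrative critic: maps observations to a scalar state value.
critic = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])

# Assumed to accept the same arguments as torchrl's ClipPPOLoss.
loss_module = ClipPPOLossImproved(actor, critic, clip_epsilon=0.2, entropy_coef=0.01)
```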

Methods Summary

_get_advantage(tensordict)

Get the advantage for a tensordict, normalising it if required.

_log_weight(sample)

Compute the log weight for the given TensorDict sample.

_loss_critic(tensordict)

Get the critic loss without the clip fraction.

_set_entropy_and_critic_losses(tensordict, ...)

Set the entropy and critic losses in the output TensorDict.

_set_ess(num_batch_dims, td_out, log_weight)

Set the ESS in the output TensorDict, for logging.

_set_in_keys()

backward(loss_vals)

Perform the backward pass for the loss.

forward(tensordict)

Compute the loss for the PPO algorithm with clipping.

set_keys(**kwargs)

Set the keys of the input TensorDict that are used by this loss.

Attributes

SEP

TARGET_NET_WARNING

T_destination

_cached_critic_network_params_detached

_clip_bounds

action_keys

call_super_init

default_keys

default_value_estimator

dump_patches

functional

in_keys

out_keys

out_keys_source

tensor_keys

value_estimator

The value function blends the reward and value estimate(s) from upcoming state(s)/state-action pair(s) into a target value estimate for the value network.

vmap_randomness

training

Methods

_get_advantage(tensordict: TensorDictBase) → Tensor[source]#

Get the advantage for a tensordict, normalising it if required.

Parameters:

tensordict (TensorDictBase) – The input TensorDict.

Returns:

advantage (torch.Tensor) – The normalised advantage.
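
Normalisation here presumably means the standard PPO standardisation of the advantage over the batch; a minimal sketch of that computation (the function name and epsilon are illustrative):

```python
import torch


def normalise_advantage(advantage: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Standardise the advantage to zero mean and unit variance over the batch."""
    return (advantage - advantage.mean()) / advantage.std().clamp_min(eps)
```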

_log_weight(sample: TensorDictBase) → tuple[Tensor, Tensor, Distribution][source]#

Compute the log weight for the given TensorDict sample.

Parameters:

sample (TensorDictBase) – The sample TensorDict.

Returns:

  • log_prob (torch.Tensor) – The log probabilities of the sample.

  • log_weight (torch.Tensor) – The log weight of the sample.

  • dist (torch.distributions.Distribution) – The distribution used to compute the log weight.
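
A rough sketch of the underlying idea, assuming the standard PPO importance weight: the log weight is the current policy's log-probability of the stored actions minus the log-probability recorded at collection time, with per-key log-probabilities summed when there are multiple action keys. The helper below is illustrative, not part of the class.

```python
import torch


def log_importance_weight(
    new_log_probs: list[torch.Tensor], old_log_probs: list[torch.Tensor]
) -> torch.Tensor:
    # Sum the per-action-key log-probabilities, then take the difference
    # between the current policy and the behaviour (collection-time) policy.
    log_prob = torch.stack(new_log_probs).sum(dim=0)
    prev_log_prob = torch.stack(old_log_probs).sum(dim=0)
    return log_prob - prev_log_prob
```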

_loss_critic(tensordict: TensorDictBase) → Tensor[source]#

Get the critic loss without the clip fraction.

TorchRL’s loss_critic method returns a tuple with the critic loss and the clip fraction. This method returns only the critic loss.
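
A minimal sketch of what this override amounts to, assuming the parent class’s loss_critic returns (critic_loss, clip_fraction) as described above:

```python
def _loss_critic(self, tensordict):
    # Discard the clip fraction and keep only the critic loss.
    critic_loss, _clip_fraction = super().loss_critic(tensordict)
    return critic_loss
```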

_set_entropy_and_critic_losses(tensordict: TensorDictBase, td_out: TensorDictBase, dist: CompositeCategoricalDistribution)[source]#

Set the entropy and critic losses in the output TensorDict.

Parameters:
  • tensordict (TensorDictBase) – The input TensorDict.

  • td_out (TensorDictBase) – The output TensorDict, which will be modified in place.

  • dist (CompositeCategoricalDistribution) – The distribution used to compute the log weight.

_set_ess(num_batch_dims: int, td_out: TensorDictBase, log_weight: Tensor)[source]#

Set the ESS in the output TensorDict, for logging.

Parameters:
  • num_batch_dims (int) – The number of batch dimensions.

  • td_out (TensorDictBase) – The output TensorDict, which will be modified in place.

  • log_weight (Tensor) – The log weights.
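
The effective sample size (ESS) is conventionally (Σ w)² / Σ w², computed stably in log space; a sketch of that computation over the batch dimensions (the function name is illustrative):

```python
import torch


def effective_sample_size(log_weight: torch.Tensor, num_batch_dims: int) -> torch.Tensor:
    # ESS = (sum w)^2 / sum(w^2), evaluated in log space for numerical stability.
    dims = tuple(range(num_batch_dims))
    return (2 * log_weight.logsumexp(dims) - (2 * log_weight).logsumexp(dims)).exp()
```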

_set_in_keys()[source]#

backward(loss_vals: TensorDictBase)[source]#

Perform the backward pass for the loss.

Parameters:

loss_vals (TensorDictBase) – The loss values.
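
A sketch of a typical backward pass over a TensorDict of loss values, assuming the losses are stored under keys beginning with "loss_" (the TorchRL convention):

```python
total_loss = sum(
    value for key, value in loss_vals.items() if key.startswith("loss_")
)
total_loss.backward()
```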

forward(tensordict: TensorDictBase) → TensorDictBase[source]#

Compute the loss for the PPO algorithm with clipping.

Parameters:

tensordict (TensorDictBase) – The input TensorDict.

Returns:

td_out (TensorDictBase) – The output TensorDict containing the losses.
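
An illustrative training step, assuming loss_module was constructed as in the sketch above, tensordict holds a collected batch with the keys this loss expects, and optimizer is a torch.optim optimiser over the actor and critic parameters; the loss keys shown are the standard TorchRL ones:

```python
td_out = loss_module(tensordict)
loss = td_out["loss_objective"] + td_out["loss_critic"] + td_out["loss_entropy"]
loss.backward()
optimizer.step()
optimizer.zero_grad()
```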

set_keys(**kwargs)[source]#

Set the keys of the input TensorDict that are used by this loss.

The keyword argument ‘action’ is treated specially: it should be an iterable of action keys. These keys are not validated against the set of accepted keys for this class; instead, each is added to that set.

All other keyword arguments should match self._AcceptedKeys.

Parameters:

**kwargs – The keyword arguments to set.
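
An illustrative call, assuming two hypothetical per-agent action keys; any other keyword arguments must name members of self._AcceptedKeys:

```python
loss_module.set_keys(
    action=[("agent_0", "action"), ("agent_1", "action")],  # hypothetical action keys
    advantage="advantage",  # must match an accepted key
)
```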