nip.rl_objectives.PPOLossImproved
- class nip.rl_objectives.PPOLossImproved(*args, **kwargs)
Base PPO loss class which allows multiple action keys and normalises advantages.
See torchrl.objectives.PPOLoss for more details.
Methods Summary
_get_advantage(tensordict) – Get the advantage for a tensordict, normalising it if required.
_log_weight(sample) – Compute the log weight for the given TensorDict sample.
_loss_critic(tensordict) – Get the critic loss without the clip fraction.
_set_entropy_and_critic_losses(tensordict, ...) – Set the entropy and critic losses in the output TensorDict.
backward(loss_vals) – Perform the backward pass for the loss.
set_keys(**kwargs) – Set the keys of the input TensorDict that are used by this loss.
Attributes
SEP
TARGET_NET_WARNING
T_destination
_cached_critic_network_params_detached
action_keys
call_super_init
default_keys
default_value_estimator
deterministic_sampling_mode
dump_patches
functional – Whether the module is functional.
in_keys
out_keys
out_keys_source
tensor_keys
value_estimator – The value function blends in the reward and value estimate(s) from upcoming state(s)/state-action pair(s) into a target value estimate for the value network.
vmap_randomness – Vmap random mode.
actor_network
critic_network
actor_network_params
critic_network_params
target_actor_network_params
target_critic_network_params
training
Methods
- _get_advantage(tensordict: TensorDictBase) → Tensor
Get the advantage for a tensordict, normalising it if required.
- Parameters:
tensordict (TensorDictBase) – The input TensorDict.
- Returns:
advantage (torch.Tensor) – The advantage, normalised if required.
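As a rough illustration of what "normalising" usually means here, the sketch below standardises a batch of advantages to zero mean and unit standard deviation. It is framework-free for brevity; the helper name `normalise_advantage`, the use of the population standard deviation and the `eps` value are assumptions, not details taken from nip or TorchRL.

```python
from statistics import fmean, pstdev
from typing import List


def normalise_advantage(advantages: List[float], eps: float = 1e-8) -> List[float]:
    # Hypothetical helper: z-score the advantages over the whole batch.
    # eps guards against division by zero when all advantages are equal.
    mean = fmean(advantages)
    std = pstdev(advantages)
    return [(a - mean) / (std + eps) for a in advantages]


advantages = [1.0, 2.0, 3.0, 4.0]
normalised = normalise_advantage(advantages)
print([round(a, 4) for a in normalised])
```

Normalising keeps the scale of the policy gradient roughly constant across batches, whatever the scale of the rewards.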
- _log_weight(sample: TensorDictBase) → tuple[Tensor, Tensor, Distribution]
Compute the log weight for the given TensorDict sample.
- Parameters:
sample (TensorDictBase) – The sample TensorDict.
- Returns:
log_prob (torch.Tensor) – The log probabilities of the sample.
log_weight (torch.Tensor) – The log weight of the sample.
dist (torch.distributions.Distribution) – The distribution used to compute the log weight.
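A minimal sketch of the log weight for a sample with several action keys, assuming the action components are independent so that their log-probabilities add. The log weight is the PPO log importance ratio, log pi_new(a|s) - log pi_old(a|s); the action key names `message` and `decision` and both helper functions are hypothetical.

```python
import math
from typing import Dict


def joint_log_prob(log_probs_per_key: Dict[str, float]) -> float:
    # Assuming the action components are sampled independently, the joint
    # log-probability is the sum of the per-key log-probabilities.
    return sum(log_probs_per_key.values())


def log_weight(new_log_probs: Dict[str, float], old_log_probs: Dict[str, float]) -> float:
    # PPO importance weight in log space: log pi_new(a|s) - log pi_old(a|s).
    return joint_log_prob(new_log_probs) - joint_log_prob(old_log_probs)


# Hypothetical action keys and log-probabilities for a single sample.
new = {"message": math.log(0.5), "decision": math.log(0.8)}
old = {"message": math.log(0.4), "decision": math.log(0.8)}

lw = log_weight(new, old)
ratio = math.exp(lw)  # probability ratio used by the clipped PPO objective
print(round(ratio, 4))  # → 1.25
```

Working in log space avoids multiplying many small probabilities; the ratio is only exponentiated at the end.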
- _loss_critic(tensordict: TensorDictBase) → Tensor
Get the critic loss without the clip fraction.
TorchRL’s loss_critic method returns a tuple with the critic loss and the clip fraction. This method returns only the critic loss.
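The pattern is a thin wrapper that unpacks the tuple and drops the second element. Below is a self-contained sketch using hypothetical stand-in classes (`BaseLoss`, `ImprovedLoss`), a plain dict in place of a TensorDict and a squared-error critic loss; none of these are the actual nip or TorchRL implementations.

```python
from typing import Dict, Optional, Tuple


class BaseLoss:
    """Hypothetical stand-in whose loss_critic returns (loss, clip_fraction)."""

    def loss_critic(self, tensordict: Dict[str, float]) -> Tuple[float, Optional[float]]:
        loss = (tensordict["value"] - tensordict["value_target"]) ** 2
        clip_fraction = None  # this stand-in does no value clipping
        return loss, clip_fraction


class ImprovedLoss(BaseLoss):
    def _loss_critic(self, tensordict: Dict[str, float]) -> float:
        # Unpack the tuple and discard the clip fraction.
        loss, _clip_fraction = self.loss_critic(tensordict)
        return loss


loss = ImprovedLoss()._loss_critic({"value": 1.5, "value_target": 1.0})
print(loss)  # → 0.25
```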
- _set_entropy_and_critic_losses(tensordict: TensorDictBase, td_out: TensorDictBase, dist: CompositeCategoricalDistribution)
Set the entropy and critic losses in the output TensorDict.
- Parameters:
tensordict (TensorDictBase) – The input TensorDict.
td_out (TensorDictBase) – The output TensorDict, which will be modified in place.
dist (CompositeCategoricalDistribution) – The distribution used to compute the log weight.
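A simplified sketch of what this method records, assuming the composite distribution factorises into independent categorical components (so its entropy is the sum of the per-key entropies) and assuming TorchRL-style output keys `entropy`, `loss_entropy` and `loss_critic`. The function, the `entropy_coef` default and the action key names are hypothetical, and a plain dict stands in for the output TensorDict.

```python
import math
from typing import Dict, List


def categorical_entropy(probs: List[float]) -> float:
    # H(p) = -sum_i p_i log p_i, skipping zero-probability outcomes.
    return -sum(p * math.log(p) for p in probs if p > 0)


def set_entropy_and_critic_losses(
    td_out: Dict[str, float],
    probs_per_key: Dict[str, List[float]],
    critic_loss: float,
    entropy_coef: float = 0.01,
) -> None:
    # Assuming independent components, the entropy of the composite
    # distribution is the sum of the per-key categorical entropies.
    entropy = sum(categorical_entropy(p) for p in probs_per_key.values())
    # Modify td_out in place.
    td_out["entropy"] = entropy
    td_out["loss_entropy"] = -entropy_coef * entropy  # encourages exploration
    td_out["loss_critic"] = critic_loss


td_out: Dict[str, float] = {}
set_entropy_and_critic_losses(
    td_out,
    {"message": [0.5, 0.5], "decision": [1.0, 0.0]},
    critic_loss=0.25,
)
print(td_out)
```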
- backward(loss_vals: TensorDictBase)
Perform the backward pass for the loss.
- Parameters:
loss_vals (TensorDictBase) – The loss values.
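TorchRL losses conventionally return several terms whose keys start with `loss_`, alongside diagnostic entries. A plausible sketch of the aggregation a backward pass performs, assuming nip follows that convention (this page does not confirm it), is shown below with plain floats; the entry names and values are illustrative.

```python
from typing import Dict


def total_loss(loss_vals: Dict[str, float]) -> float:
    # Sum every entry whose key starts with "loss_"; diagnostic entries such
    # as "entropy" are excluded from the objective.
    return sum(v for k, v in loss_vals.items() if k.startswith("loss_"))


loss_vals = {
    "loss_objective": -0.1,
    "loss_critic": 0.25,
    "loss_entropy": -0.007,
    "entropy": 0.7,  # logged only, not optimised
}
print(round(total_loss(loss_vals), 3))  # → 0.143
```

With PyTorch tensors instead of floats, calling `.backward()` on the summed total would then propagate gradients through every loss term.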
- set_keys(**kwargs)
Set the keys of the input TensorDict that are used by this loss.
The keyword argument ‘action’ is treated specially. This should be an iterable of action keys. These are not validated against the set of accepted keys for this class. Instead, each is added to the set of accepted keys.
All other keyword arguments should match the fields of self._AcceptedKeys.
- Parameters:
**kwargs – The keyword arguments to set.
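To illustrate the special handling of `action`, here is a hypothetical stand-in class (`KeyedLoss`) that adds each action key to its accepted keys without validation, while rejecting any other keyword that is not already accepted. The accepted key names are illustrative, and the real method maps fields to TensorDict keys rather than storing them in a plain dict.

```python
from typing import Any, Dict, Iterable, List, Set


class KeyedLoss:
    """Hypothetical stand-in illustrating the set_keys behaviour described above."""

    def __init__(self) -> None:
        # Accepted keys other than actions; these names are illustrative only.
        self.accepted_keys: Set[str] = {"advantage", "value_target", "sample_log_prob"}
        self.keys: Dict[str, Any] = {}
        self.action_keys: List[str] = []

    def set_keys(self, **kwargs: Any) -> None:
        # 'action' is an iterable of action keys; each one is added to the
        # accepted keys without being validated first.
        action_keys: Iterable[str] = kwargs.pop("action", ())
        for key in action_keys:
            self.accepted_keys.add(key)
            self.action_keys.append(key)
        # Every other keyword must name an already-accepted key.
        for name, value in kwargs.items():
            if name not in self.accepted_keys:
                raise ValueError(f"{name!r} is not an accepted key")
            self.keys[name] = value


loss = KeyedLoss()
loss.set_keys(action=["message", "decision"], advantage="my_advantage")
print(sorted(loss.accepted_keys))
```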