nip.rl_objectives.Objective
- class nip.rl_objectives.Objective(*args, **kwargs)
Base class for all RL objectives.
Extends the LossModule class from TorchRL to allow multiple action keys and to normalise advantages.
The implementation is a bit of a hack: the _AcceptedKeys class is changed dynamically so that it can hold multiple action keys.
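Extending _AcceptedKeys at runtime can be illustrated with the standard library's dataclasses.make_dataclass. This is a minimal sketch under stated assumptions, not the nip implementation: BaseAcceptedKeys and extend_accepted_keys are hypothetical names, the base fields are placeholders, and action keys are plain strings (TorchRL also allows nested tuple keys, which cannot be dataclass field names).

```python
from dataclasses import dataclass, field, fields, make_dataclass


@dataclass
class BaseAcceptedKeys:
    """Hypothetical stand-in for a loss module's _AcceptedKeys dataclass."""

    advantage: str = "advantage"
    sample_log_prob: str = "sample_log_prob"


def extend_accepted_keys(base: type, action_keys: list[str]) -> type:
    """Build a subclass of ``base`` with one extra field per action key.

    Each new field defaults to the key itself, mirroring how each accepted
    key defaults to the name of the TensorDict entry it refers to.
    """
    existing = {f.name for f in fields(base)}
    extra = [
        (key, str, field(default=key))
        for key in action_keys
        if key not in existing  # never shadow a field that already exists
    ]
    return make_dataclass("_AcceptedKeys", extra, bases=(base,))


# Example: two agents, each with its own action key.
AcceptedKeys = extend_accepted_keys(BaseAcceptedKeys, ["action_a", "action_b"])
keys = AcceptedKeys()
print([f.name for f in fields(keys)])
# ['advantage', 'sample_log_prob', 'action_a', 'action_b']
```

Creating a fresh subclass rather than mutating the base class keeps the defaults of other loss modules untouched.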
See torchrl.objectives.LossModule for more details.
Methods Summary
_get_advantage(tensordict): Get the advantage for a tensordict, normalising it if required.
_log_weight(sample): Compute the log weight for the given TensorDict sample.
backward(loss_vals): Perform the backward pass for the loss.
set_keys(**kwargs): Set the keys of the input TensorDict that are used by this loss.
Attributes
SEP
TARGET_NET_WARNING
T_destination
action_keys
call_super_init
default_value_estimator
dump_patches
in_keys
out_keys
out_keys_source
tensor_keys
value_estimator: The value function blends in the reward and value estimate(s) from upcoming state(s)/state-action pair(s) into a target value estimate for the value network.
vmap_randomness
training
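The value_estimator attribute is inherited from TorchRL, which provides several estimators of this kind (for example TD(0), TD(λ) and GAE). As a framework-free illustration of what "blending" upcoming value estimates means, the sketch below computes TD(λ) targets for a single trajectory with plain Python lists. td_lambda_targets is a hypothetical helper, not a TorchRL function, and next_values[t] is assumed to hold the value estimate of the state reached after step t.

```python
def td_lambda_targets(
    rewards: list[float],
    next_values: list[float],
    dones: list[bool],
    gamma: float = 0.99,
    lmbda: float = 0.95,
) -> list[float]:
    """Compute TD(lambda) value targets by recursing backwards in time.

    G_t = r_t + gamma * (1 - done_t) * ((1 - lmbda) * V(s_{t+1}) + lmbda * G_{t+1}),
    bootstrapping the final step from V(s_T).
    """
    targets = [0.0] * len(rewards)
    next_return = next_values[-1]  # bootstrap value after the last step
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # Blend the one-step bootstrap with the longer-horizon return.
        blended = (1.0 - lmbda) * next_values[t] + lmbda * next_return
        targets[t] = rewards[t] + gamma * not_done * blended
        next_return = targets[t]
    return targets


# lmbda=1 gives discounted returns that stop at episode ends.
print(td_lambda_targets([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [False, False, True], gamma=0.5, lmbda=1.0))
# [1.75, 1.5, 1.0]
```

Setting lmbda=0 recovers one-step TD targets r + gamma * V(s'), so λ controls how far ahead the target looks.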
Methods
- _get_advantage(tensordict: TensorDictBase) → Tensor
Get the advantage for a tensordict, normalising it if required.
- Parameters:
tensordict (TensorDictBase) – The input TensorDict.
- Returns:
advantage (torch.Tensor) – The advantage, normalised if required.
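In PPO-style losses, normalising advantages usually means standardising the batch to zero mean and unit standard deviation. The following is a minimal sketch with plain Python lists, assuming that is the normalisation applied here. normalise_advantage, the normalise flag and eps are hypothetical names chosen for illustration, and the population standard deviation is used (torch.Tensor.std defaults to the unbiased estimate).

```python
import math


def normalise_advantage(
    advantage: list[float], normalise: bool = True, eps: float = 1e-8
) -> list[float]:
    """Return the advantages, standardised if ``normalise`` is True."""
    if not normalise:
        return list(advantage)
    n = len(advantage)
    mean = sum(advantage) / n
    # Population standard deviation of the batch.
    std = math.sqrt(sum((a - mean) ** 2 for a in advantage) / n)
    # eps keeps the division finite when every advantage is identical.
    return [(a - mean) / (std + eps) for a in advantage]


print(normalise_advantage([1.0, 2.0, 3.0]))
# approximately [-1.2247, 0.0, 1.2247]
```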
- _log_weight(sample: TensorDictBase) → tuple[Tensor, Tensor, Distribution]
Compute the log weight for the given TensorDict sample.
- Parameters:
sample (TensorDictBase) – The sample TensorDict.
- Returns:
log_prob (torch.Tensor) – The log probabilities of the sample.
log_weight (torch.Tensor) – The log weight of the sample.
dist (torch.distributions.Distribution) – The distribution used to compute the log weight.
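In PPO-style objectives the log weight is the log importance ratio between the current policy and the policy that collected the data, log π_new(a|s) − log π_old(a|s). With several action keys, one natural choice (assumed here, not taken from nip) is to treat the action components as independent given the observation, so the joint log probability is the sum of the per-key log probabilities. The sketch uses plain lists and returns only (log_prob, log_weight), omitting the distribution; log_weight_sketch and its argument names are hypothetical.

```python
import math


def log_weight_sketch(
    new_log_probs: dict[str, list[float]], old_log_prob: list[float]
) -> tuple[list[float], list[float]]:
    """Return per-sample (log_prob, log_weight).

    new_log_probs maps each action key to per-sample log probabilities under
    the current policy; old_log_prob holds the joint log probability recorded
    when the samples were collected.
    """
    n = len(old_log_prob)
    # Joint log probability: sum over action keys (independence assumption).
    log_prob = [sum(lps[i] for lps in new_log_probs.values()) for i in range(n)]
    log_weight = [lp - old for lp, old in zip(log_prob, old_log_prob)]
    return log_prob, log_weight


new = {
    "action_a": [math.log(0.5), math.log(0.2)],
    "action_b": [math.log(0.4), math.log(0.5)],
}
old = [math.log(0.25), math.log(0.1)]
_, lw = log_weight_sketch(new, old)
print([round(math.exp(w), 6) for w in lw])  # [0.8, 1.0]
```

Exponentiating the log weight gives the probability ratio that a clipped PPO surrogate would consume.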
- abstract backward(loss_vals: TensorDictBase)
Perform the backward pass for the loss.
- Parameters:
loss_vals (TensorDictBase) – The loss values.
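Because backward is abstract, each concrete objective has to supply its own. Below is a sketch of one possible implementation, assuming the loss terms follow the convention, common in TorchRL examples, of prefixing their names with loss_ and that the total loss is their sum. ObjectiveSketch and SumLossesObjective are hypothetical classes, and a plain dict of tensors stands in for the TensorDictBase argument.

```python
from abc import ABC, abstractmethod

import torch


class ObjectiveSketch(ABC):
    """Hypothetical stand-in for an objective with an abstract backward."""

    @abstractmethod
    def backward(self, loss_vals: dict[str, torch.Tensor]) -> None:
        """Perform the backward pass for the loss."""


class SumLossesObjective(ObjectiveSketch):
    """Backpropagate through the sum of every entry named loss_*."""

    def backward(self, loss_vals: dict[str, torch.Tensor]) -> None:
        terms = [v for k, v in loss_vals.items() if k.startswith("loss_")]
        if not terms:
            raise ValueError("no 'loss_' entries to backpropagate")
        # Non-loss entries (e.g. logged diagnostics) are ignored.
        sum(terms).backward()


w = torch.tensor(2.0, requires_grad=True)
loss_vals = {"loss_objective": w * 3.0, "loss_entropy": w**2, "entropy": w.detach()}
SumLossesObjective().backward(loss_vals)
print(w.grad)  # tensor(7.)
```

The gradient is 3 + 2w = 7 at w = 2, confirming that only the loss_ entries contribute.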
- set_keys(**kwargs)
Set the keys of the input TensorDict that are used by this loss.
The keyword argument ‘action’ is treated specially. This should be an iterable of action keys. These are not validated against the set of accepted keys for this class. Instead, each is added to the set of accepted keys.
All other keyword arguments should be names of keys defined in self._AcceptedKeys.
- Parameters:
**kwargs – The keyword arguments to set.
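The special treatment of 'action' can be modelled with a plain dictionary of keys. This is a simplified, hypothetical sketch of the behaviour described above: KeyedLossSketch is an illustrative class, its tensor_keys and action_keys attributes are not nip's internals, and rejecting a bare string for 'action' is an added safeguard, since a string is itself an iterable of characters.

```python
class KeyedLossSketch:
    """Hypothetical, simplified model of Objective.set_keys."""

    def __init__(self) -> None:
        # Accepted keys and their current values (stand-in for _AcceptedKeys).
        self.tensor_keys: dict[str, object] = {
            "advantage": "advantage",
            "sample_log_prob": "sample_log_prob",
        }
        self.action_keys: list[str] = []

    def set_keys(self, **kwargs: object) -> None:
        actions = kwargs.pop("action", None)
        if actions is not None:
            if isinstance(actions, str):
                # Added safeguard: a bare string would be split into characters.
                raise TypeError("'action' must be an iterable of keys, not a string")
            for key in actions:
                # Action keys are not validated; each becomes an accepted key.
                self.tensor_keys[key] = key
                if key not in self.action_keys:
                    self.action_keys.append(key)
        for name, value in kwargs.items():
            # All other keywords must refer to an already-accepted key.
            if name not in self.tensor_keys:
                raise ValueError(f"{name!r} is not an accepted key")
            self.tensor_keys[name] = value


loss = KeyedLossSketch()
loss.set_keys(action=["action_a", "action_b"], advantage="my_advantage")
print(loss.action_keys)  # ['action_a', 'action_b']
print(loss.tensor_keys["advantage"])  # my_advantage
```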