nip.utils.plotting.rollouts.get_last_timestep_mask
- nip.utils.plotting.rollouts.get_last_timestep_mask(rollouts: NestedArrayDict) → ndarray[Any, dtype[_ScalarType_co]]
Compute a mask for the last timestep of each rollout.
The last timestep is defined as a timestep at which either the (“next”, “done”) flag or the (“next”, “terminated”) flag is True and the (“padding”) flag is False.
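The rule above can be sketched in NumPy as follows. This is an illustrative assumption, not nip's implementation: a plain dict with tuple keys stands in for NestedArrayDict (whose own API is not shown here), and the helper name `last_timestep_mask_sketch` is hypothetical.

```python
import numpy as np


def last_timestep_mask_sketch(rollouts):
    """Hypothetical re-implementation of the documented rule (not nip's code).

    `rollouts` is a plain dict standing in for NestedArrayDict, keyed by
    ("next", "done"), ("next", "terminated") and "padding"; every value is a
    boolean array of shape (..., round).
    """
    next_done = np.asarray(rollouts["next", "done"], dtype=bool)
    next_terminated = np.asarray(rollouts["next", "terminated"], dtype=bool)
    padding = np.asarray(rollouts["padding"], dtype=bool)
    # A timestep closes a rollout if either episode-end flag is set,
    # unless that timestep is padding
    return (next_done | next_terminated) & ~padding


# Two rollouts of up to 4 rounds. Rollout 0 terminates at round 3; rollout 1
# ends at round 1 and is padded afterwards. Its padded rounds are given
# done=True here purely to show that padding is masked out.
rollouts = {
    ("next", "done"): np.array([[False, False, False, True],
                                [False, True, True, True]]),
    ("next", "terminated"): np.array([[False, False, False, True],
                                      [False, False, False, False]]),
    "padding": np.array([[False, False, False, False],
                         [False, False, True, True]]),
}
mask = last_timestep_mask_sketch(rollouts)
# [[False, False, False, True], [False, True, False, False]]
```

The output keeps the (…, round) shape of the inputs, so it can be combined elementwise with any other per-timestep array from the same rollouts.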
Shapes
rollouts is a nested array dict with the following keys and shapes:
(“next”, “done”) : (… round)
(“next”, “terminated”) : (… round)
(“padding”) : (… round)
The output last_timestep_mask is a boolean array with the same shape as the inputs: (… round)
- Parameters:
rollouts (NestedArrayDict) – The rollouts to be analysed.
- Returns:
last_timestep_mask (NDArray) – A boolean array with the same shape as the inputs, where True indicates the last timestep of each rollout.
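Assuming each row along the leading axis holds a single rollout, the returned mask can be used to extract per-rollout quantities. The sketch below hard-codes a mask of the documented (… round) shape and a hypothetical `rewards` array; neither value comes from nip.

```python
import numpy as np

# A mask of the kind get_last_timestep_mask returns, shape (batch, round).
# Illustrative values only; each row is assumed to hold a single rollout.
last_timestep_mask = np.array([[False, False, False, True],
                               [False, True, False, False]])

# Hypothetical per-timestep quantity with the same (batch, round) shape
rewards = np.array([[0.0, 0.0, 0.0, 1.0],
                    [0.0, -1.0, 0.0, 0.0]])

# Boolean indexing picks out the value at each rollout's final timestep,
# flattened in row-major order
final_rewards = rewards[last_timestep_mask]  # [1.0, -1.0]

# With exactly one True per row, argmax gives the index of the last round,
# so the number of rounds actually played is that index plus one
rollout_lengths = last_timestep_mask.argmax(axis=-1) + 1  # [4, 2]
```

Note that argmax returns 0 for a row with no True entry, so an unfinished rollout would be reported as length 1; checking last_timestep_mask.any(axis=-1) first guards against that.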