nip.utils.plotting.rollouts.get_last_timestep_mask

nip.utils.plotting.rollouts.get_last_timestep_mask(rollouts: NestedArrayDict) → ndarray[Any, dtype[_ScalarType_co]]

Compute a mask for the last timestep of each rollout.

The last timestep is defined as the timestep where the next done or next terminated flag is set to True, and the padding flag is not set.

Shapes

rollouts is a nested array dict with the following keys and shapes:

  • (“next”, “done”) : (… round)

  • (“next”, “terminated”) : (… round)

  • (“padding”) : (… round)

The output last_timestep_mask is a boolean array with the same shape as the inputs: (… round)

Parameters:

rollouts (NestedArrayDict) – The rollouts to be analysed.

Returns:

last_timestep_mask (NDArray) – A boolean array with the same shape as the inputs, where True indicates the last timestep of each rollout.
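
Example:

The definition above can be illustrated with a minimal NumPy sketch. This is not the library's implementation: it assumes a plain dict keyed by the same names as a stand-in for NestedArrayDict, and the helper name last_timestep_mask_sketch is hypothetical.

```python
import numpy as np


def last_timestep_mask_sketch(rollouts: dict) -> np.ndarray:
    """Illustrative stand-in: (next done OR next terminated) AND NOT padding."""
    done = rollouts[("next", "done")]
    terminated = rollouts[("next", "terminated")]
    padding = rollouts["padding"]
    return (done | terminated) & ~padding


# Two rollouts of four rounds each, so every array has shape (2, 4).
# A plain dict is assumed here in place of a real NestedArrayDict.
rollouts = {
    ("next", "done"): np.array(
        [[False, False, True, False], [False, True, False, True]]
    ),
    ("next", "terminated"): np.array(
        [[False, False, False, False], [False, True, False, False]]
    ),
    "padding": np.array(
        [[False, False, False, True], [False, False, True, True]]
    ),
}

mask = last_timestep_mask_sketch(rollouts)
print(mask)
```

In the second rollout, the done flag in the final round falls on a padded timestep, so it is excluded from the mask; only round 1 is marked as the last timestep.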