nip.utils.checkpoints.load_rollouts

nip.utils.checkpoints.load_rollouts(run_id: str, iterations: int | str, download_from_wandb: bool = True, wandb_project: str | None = None, wandb_entity: str | None = None, wandb_api: Api | None = None) → NestedArrayDict
nip.utils.checkpoints.load_rollouts(run_id: str, iterations: list[int | str] | None, download_from_wandb: bool = True, wandb_project: str | None = None, wandb_entity: str | None = None, wandb_api: Api | None = None) → list[NestedArrayDict]

Load the rollouts from a checkpoint.

If download_from_wandb is True, the function first downloads the checkpoint from W&B and stores it locally in the checkpoint directory.

Parameters:
  • run_id (str) – The ID of the run to load the rollouts from.

  • iterations (int|str | list[int|str] | None, default=None) – The iteration(s) of the rollouts to load. If None, all iterations are loaded. If -1, the last iteration is loaded.

  • download_from_wandb (bool, default=True) – Whether to download the rollouts from W&B. If False, the function will look for the rollouts in the local checkpoint directory.

  • wandb_project (str, optional) – The project of the wandb run. Must be provided if download_from_wandb is True.

  • wandb_entity (str, optional) – The entity of the wandb run. If not provided, the default entity will be used.

  • wandb_api (Api, optional) – The wandb API instance to use. If not provided, a new instance will be created.
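The selection rules for iterations can be summarized in a short sketch. This is an illustration only, not the library's implementation: the helper name resolve_iterations, the representation of locally available rollouts as a list of integer iteration numbers, and the reading of a string iteration as a decimal number (for example "3") are all assumptions. A missing iteration raises FileNotFoundError, mirroring the documented error.

```python
from __future__ import annotations

from collections.abc import Sequence


def resolve_iterations(
    requested: int | str | list[int | str] | None,
    available: Sequence[int],
) -> list[int]:
    """Hypothetical helper: turn an ``iterations`` argument into iteration numbers.

    Assumes ``available`` lists the iterations that have a rollout file and that
    a string iteration is a decimal number such as "3".
    """
    ordered = sorted(available)
    if requested is None:
        # None selects every available iteration.
        return ordered

    # Treat a single int/str the same way as a one-element list.
    items = requested if isinstance(requested, list) else [requested]
    resolved: list[int] = []
    for item in items:
        iteration = int(item)  # assumption: "3" means iteration 3
        if iteration == -1:
            # -1 selects the last (highest-numbered) available iteration.
            if not ordered:
                raise FileNotFoundError("No rollouts are available")
            iteration = ordered[-1]
        if iteration not in ordered:
            raise FileNotFoundError(f"No rollout file for iteration {iteration}")
        resolved.append(iteration)
    return resolved


print(resolve_iterations(None, [2, 0, 1]))  # [0, 1, 2]
print(resolve_iterations(-1, [0, 1, 2]))  # [2]
print(resolve_iterations(["0", -1], [0, 1, 2]))  # [0, 2]
```

This sketch also accepts -1 inside a list; the parameter description only specifies -1 as a single value.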

Returns:

rollouts (NestedArrayDict | list[NestedArrayDict]) – The rollouts loaded from the checkpoint. If iterations is None or a list, this is a list of NestedArrayDicts. If iterations is an int or str, this is a single NestedArrayDict.
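The dependence of the return type on the type of iterations is what the two overloaded signatures express. The sketch below shows that pattern with typing.overload; load, _load_single, the placeholder Rollout type (standing in for NestedArrayDict) and the three pretend iterations used when iterations is None are all hypothetical.

```python
from __future__ import annotations

from typing import overload

# Hypothetical stand-in for NestedArrayDict.
Rollout = dict[str, object]


def _load_single(iteration: int | str) -> Rollout:
    # Hypothetical loader: returns a placeholder instead of reading a file.
    return {"iteration": iteration}


@overload
def load(iterations: int | str) -> Rollout: ...


@overload
def load(iterations: list[int | str] | None) -> list[Rollout]: ...


def load(iterations: int | str | list[int | str] | None) -> Rollout | list[Rollout]:
    if iterations is None:
        # Hypothetical: pretend iterations 0, 1 and 2 are available.
        iterations = [0, 1, 2]
    if isinstance(iterations, list):
        # None or a list: always return a list, one entry per iteration.
        return [_load_single(it) for it in iterations]
    # A single int or str: return one rollout, not a list.
    return _load_single(iterations)


print(type(load(3)).__name__)  # dict
print(type(load([3])).__name__)  # list
```

Passing [3] rather than 3 therefore yields a one-element list, which a type checker can also infer from the overloads.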

Raises:

FileNotFoundError – If the rollouts directory or a requested rollout file cannot be found.