nip.utils.checkpoints.load_rollouts
- nip.utils.checkpoints.load_rollouts(run_id: str, iterations: int | str, download_from_wandb: bool = True, wandb_project: str | None = None, wandb_entity: str | None = None, wandb_api: Api | None = None) → NestedArrayDict
- nip.utils.checkpoints.load_rollouts(run_id: str, iterations: list[int | str] | None, download_from_wandb: bool = True, wandb_project: str | None = None, wandb_entity: str | None = None, wandb_api: Api | None = None) → list[NestedArrayDict]
Load the rollouts from a checkpoint.
The function can download the checkpoint from W&B, storing it locally in the checkpoint directory.
- Parameters:
run_id (str) – The ID of the run to load the rollouts from.
iterations (int|str | list[int|str] | None, default=None) – The iteration(s) of the rollouts to load. If None, all iterations are loaded. If -1, the last iteration is loaded.
download_from_wandb (bool, default=True) – Whether to download the rollouts from W&B. If False, the function will look for the rollouts in the local checkpoint directory.
wandb_project (str, optional) – The project of the W&B run. Must be provided if download_from_wandb is True.
wandb_entity (str, optional) – The entity of the W&B run. If not provided, the default entity is used.
wandb_api (WandbApi, optional) – The wandb API instance to use. If not provided, a new instance will be created.
- Returns:
rollouts (NestedArrayDict | list[NestedArrayDict]) – The rollouts loaded from the checkpoint. If iterations is None or a list, this is a list of NestedArrayDicts. If iterations is an int or str, this is a single NestedArrayDict.
- Raises:
FileNotFoundError – If the rollouts directory is not found or the requested rollout file is not found.
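The two `FileNotFoundError` cases (missing rollouts directory, missing rollout file) can be illustrated with `pathlib`. This is a sketch under a hypothetical local layout with one `<iteration>.pkl` file per iteration inside a rollouts directory; the actual checkpoint layout and file format are not described on this page, and `find_rollout_files` is not part of the library.

```python
from pathlib import Path


def find_rollout_files(rollouts_dir: Path, iterations: list[int]) -> list[Path]:
    """Map iteration numbers to rollout files, raising the documented errors.

    Assumes a hypothetical layout: ``rollouts_dir/<iteration>.pkl``.
    """
    if not rollouts_dir.is_dir():
        # Case 1: the rollouts directory itself is missing.
        raise FileNotFoundError(f"rollouts directory not found: {rollouts_dir}")
    paths: list[Path] = []
    for iteration in iterations:
        path = rollouts_dir / f"{iteration}.pkl"
        if not path.is_file():
            # Case 2: the directory exists but a requested file does not.
            raise FileNotFoundError(f"rollout file not found: {path}")
        paths.append(path)
    return paths
```

Checking every requested file before loading any of them means the caller gets the error up front rather than after partially reading a checkpoint.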