nip.trainers.trainer_base.Trainer
- class nip.trainers.trainer_base.Trainer(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)
Base class for all trainers.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
scenario_instance (ScenarioInstance) – The components of the experiment.
settings (ExperimentSettings) – The instance-specific settings of the experiment, such as the device and logging configuration.
Methods Summary
__init__(hyper_params, scenario_instance, ...)
_add_file_to_wandb_artifact(artifact_name, ...)
Add a file to a W&B artifact, creating the artifact if it doesn't exist.
_get_metadata()
Get metadata relevant to the current experiment, for saving the checkpoint.
_initialise_state()
Initialise the state of the experiment.
_load_state_dict_from_active_run_checkpoint()
Load the experiment state from a checkpoint for the active run, if available.
_load_state_dict_from_base_run_checkpoint([version])
Load the experiment state from a base run checkpoint.
get_checkpoint_base_dir_from_run_id(run_id)
Get the base directory for a checkpoint from a run ID.
get_total_num_iterations()
Get the total number of iterations that the trainer will run for.
load_and_set_state_from_checkpoint([from_base_run, version])
Load and set the experiment state from a checkpoint, if available.
save_checkpoint([log])
Save the state of the experiment to a checkpoint.
train()
Train the agents.
Attributes
agent_names
The names of the agents in the scenario.
checkpoint_base_dir
The path to the directory containing the checkpoint.
checkpoint_metadata_path
The path to the checkpoint metadata file.
checkpoint_params_path
The path to the parameters file for the checkpoint.
checkpoint_state_path
The path to the checkpoint state file.
max_message_rounds
The maximum number of message rounds in the protocol.
num_agents
The number of agents in the scenario.
protocol_handler
The protocol handler for the experiment.
Methods
- __init__(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)
- _add_file_to_wandb_artifact(artifact_name: str, artifact_type: str, file_path: Path)
Add a file to a W&B artifact, creating the artifact if it doesn’t exist.
If the artifact already exists, we add the file to the existing artifact, creating a new version.
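To make the versioning behaviour concrete, here is a hypothetical in-memory model in plain Python; it does not call W&B. ToyArtifactStore and its add_file method are invented for illustration, and the assumption that a new version keeps the files of the previous one is ours, not something stated above. In W&B itself, logging changed contents under an existing artifact name records a new version (v0, v1, ...) rather than overwriting the old one.

```python
from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class ToyArtifactStore:
    """Hypothetical in-memory stand-in for a project's artifact storage.

    Each artifact name maps to a list of versions; each version is the set of
    file names it contains. Only the versioning behaviour is modelled.
    """

    versions: dict[str, list[frozenset[str]]] = field(default_factory=dict)

    def add_file(self, artifact_name: str, file_path: Path) -> str:
        # Create the artifact on first use, otherwise extend its history
        history = self.versions.setdefault(artifact_name, [])
        # Assumption: a new version starts from the previous version's files
        previous = history[-1] if history else frozenset()
        history.append(previous | {file_path.name})
        # Return a W&B-style "name:vN" reference to the version just created
        return f"{artifact_name}:v{len(history) - 1}"
```

Calling add_file("checkpoint", Path("state.pt")) on an empty store returns "checkpoint:v0"; a second call with another file returns "checkpoint:v1", and v0 is left unchanged.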
- _get_metadata() → dict
Get metadata relevant to the current experiment, for saving the checkpoint.
- Returns:
metadata (dict) – The metadata to record.
- _initialise_state()
Initialise the state of the experiment.
Subclasses should implement this method to initialise the state of the experiment. It is called at the beginning of training when starting from scratch (i.e. not when restoring from a checkpoint).
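A subclass supplies its own fresh state by overriding this hook. In this hypothetical sketch, ToyTrainerBase stands in for Trainer and uses abc to make the hook mandatory, which the real class may not do; CounterTrainer and its state keys are likewise invented for illustration.

```python
from abc import ABC, abstractmethod


class ToyTrainerBase(ABC):
    """Hypothetical stand-in for Trainer, reduced to the _initialise_state hook."""

    def __init__(self) -> None:
        self.state: dict = {}

    @abstractmethod
    def _initialise_state(self) -> None:
        """Create fresh experiment state when not restoring from a checkpoint."""


class CounterTrainer(ToyTrainerBase):
    """Hypothetical subclass whose state is an iteration counter and best score."""

    def _initialise_state(self) -> None:
        # Fresh state: no iterations run yet, no score recorded
        self.state = {"iteration": 0, "best_score": float("-inf")}
```

Because the hook is declared abstract here, instantiating ToyTrainerBase directly raises TypeError, so a missing override is caught at construction time rather than at the start of training.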
- _load_state_dict_from_active_run_checkpoint() → dict
Load the experiment state from a checkpoint for the active run, if available.
- Returns:
state_dict (dict) – The state as a dictionary, loaded from W&B.
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
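The documented exception makes a "resume if possible, otherwise start fresh" start-up path straightforward. The sketch below is a hypothetical, stdlib-only illustration: ToyTrainer, its start method and the dict standing in for W&B storage are invented for this example, CheckPointNotFoundError is a local stand-in for the real exception, and Trainer does not necessarily wire these calls together this way.

```python
class CheckPointNotFoundError(Exception):
    """Local stand-in for the exception documented above."""


class ToyTrainer:
    """Hypothetical trainer illustrating a load-or-initialise start-up path."""

    def __init__(self, run_id: str, saved_checkpoints: dict[str, dict]) -> None:
        self.run_id = run_id
        # Maps run ID -> saved state dict; stands in for W&B artifact storage
        self.saved_checkpoints = saved_checkpoints
        self.state: dict = {}

    def _initialise_state(self) -> None:
        self.state = {"iteration": 0}

    def _load_state_dict_from_active_run_checkpoint(self) -> dict:
        try:
            return dict(self.saved_checkpoints[self.run_id])
        except KeyError as error:
            # Translate the storage miss into the documented exception type
            raise CheckPointNotFoundError(self.run_id) from error

    def start(self) -> str:
        """Resume from the active run's checkpoint if one exists, else start fresh."""
        try:
            self.state = self._load_state_dict_from_active_run_checkpoint()
        except CheckPointNotFoundError:
            self._initialise_state()
            return "fresh"
        return "resumed"
```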
- _load_state_dict_from_base_run_checkpoint(version: str = 'latest') → dict
Load the experiment state from a base run checkpoint.
- Parameters:
version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B.
- Returns:
state_dict (dict) – The state as a dictionary, loaded from W&B.
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
- classmethod get_checkpoint_base_dir_from_run_id(run_id: str) → Path
Get the base directory for a checkpoint from a run ID.
- Parameters:
run_id (str) – The run ID.
- Returns:
checkpoint_base_dir (Path) – The path to the base directory for the checkpoint.
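The sketch below shows one plausible way the checkpoint path attributes could relate to this classmethod. CHECKPOINT_ROOT, ToyTrainer and the file names state.pt, metadata.json and params.json are all hypothetical; the real locations and names are defined inside nip and may differ.

```python
from pathlib import Path

# Hypothetical root directory; nip defines its own location
CHECKPOINT_ROOT = Path("checkpoints")


class ToyTrainer:
    """Hypothetical stand-in showing how the checkpoint_* attributes could relate."""

    def __init__(self, run_id: str) -> None:
        self.run_id = run_id  # ID of the active run

    @classmethod
    def get_checkpoint_base_dir_from_run_id(cls, run_id: str) -> Path:
        # One sub-directory per run ID under the shared root
        return CHECKPOINT_ROOT / run_id

    @property
    def checkpoint_base_dir(self) -> Path:
        return self.get_checkpoint_base_dir_from_run_id(self.run_id)

    @property
    def checkpoint_state_path(self) -> Path:
        return self.checkpoint_base_dir / "state.pt"  # hypothetical file name

    @property
    def checkpoint_metadata_path(self) -> Path:
        return self.checkpoint_base_dir / "metadata.json"  # hypothetical file name

    @property
    def checkpoint_params_path(self) -> Path:
        return self.checkpoint_base_dir / "params.json"  # hypothetical file name
```

Making the lookup a classmethod means the directory for any run, such as a base run, can be computed without constructing a trainer: ToyTrainer.get_checkpoint_base_dir_from_run_id("base-run").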
- get_total_num_iterations() → int
Get the total number of iterations that the trainer will run for.
This is the sum of the number of iterations declared by methods decorated with attach_progress_bar.
- Returns:
total_iterations (int) – The total number of iterations.
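Here is one way the iteration-counting mechanism could work. attach_progress_bar is named above but its signature is not documented here, so this version, which takes a callable mapping the trainer to an iteration count and simply tags the method with it, is hypothetical, as are ToyTrainer and its num_epochs and num_eval_rounds settings. A real decorator would presumably also drive the progress bar itself.

```python
from typing import Any, Callable


def attach_progress_bar(num_iterations: Callable[[Any], int]) -> Callable:
    """Hypothetical decorator factory: tag a method with its iteration count.

    num_iterations receives the trainer instance, so the count can depend on
    the trainer's configuration.
    """

    def decorator(method: Callable) -> Callable:
        method.num_iterations = num_iterations  # marker read by the trainer
        return method

    return decorator


class ToyTrainer:
    def __init__(self, num_epochs: int, num_eval_rounds: int) -> None:
        self.num_epochs = num_epochs
        self.num_eval_rounds = num_eval_rounds

    @attach_progress_bar(lambda self: self.num_epochs)
    def _train_loop(self) -> None:
        pass

    @attach_progress_bar(lambda self: self.num_eval_rounds)
    def _evaluate(self) -> None:
        pass

    def get_total_num_iterations(self) -> int:
        # Sum the declared counts of every tagged method on the class
        total = 0
        for name in dir(type(self)):
            counter = getattr(getattr(type(self), name), "num_iterations", None)
            if callable(counter):
                total += counter(self)
        return total
```

With this sketch, ToyTrainer(10, 3).get_total_num_iterations() returns 13, the 10 training epochs plus the 3 evaluation rounds.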
- load_and_set_state_from_checkpoint(from_base_run: bool = False, version: str = 'latest')
Load and set the experiment state from a checkpoint, if available.
- Parameters:
from_base_run (bool, default=False) – Whether to load the checkpoint artifact from the base run. By default, the artifact is loaded from the current active run.
version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B. This must be “latest” if from_base_run is False.
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
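The argument rules for loading can be sketched as a small validation function. resolve_checkpoint_reference, its available_versions argument and the choice to raise ValueError for a non-"latest" version on the active run are hypothetical; the documentation above only states that this combination is not allowed. CheckPointNotFoundError is a local stand-in for the documented exception. The returned string follows W&B's name:alias convention for artifact references (for example name:latest or name:v3).

```python
class CheckPointNotFoundError(Exception):
    """Local stand-in for the exception documented above."""


def resolve_checkpoint_reference(
    artifact_name: str,
    available_versions: set[str],
    from_base_run: bool = False,
    version: str = "latest",
) -> str:
    """Hypothetical helper mirroring the argument rules described above."""
    # Only the base run may be loaded at a pinned version
    if not from_base_run and version != "latest":
        raise ValueError('version must be "latest" when from_base_run is False')
    # No checkpoint at all, or the requested version does not exist
    if not available_versions:
        raise CheckPointNotFoundError(artifact_name)
    if version != "latest" and version not in available_versions:
        raise CheckPointNotFoundError(f"{artifact_name}:{version}")
    # W&B artifact references have the form "<name>:<alias or version>"
    return f"{artifact_name}:{version}"
```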