nip.trainers.trainer_base.Trainer
- class nip.trainers.trainer_base.Trainer(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)
Base class for all trainers.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
scenario_instance (ScenarioInstance) – The components of the experiment.
settings (ExperimentSettings) – The instance-specific settings of the experiment, like device, logging, etc.
Methods Summary
__init__(hyper_params, scenario_instance, ...)
_add_file_to_wandb_artifact(artifact_name, ...) – Add a file to a W&B artifact, creating the artifact if it doesn't exist.
_get_metadata() – Get metadata relevant to the current experiment, for saving the checkpoint.
_initialise_state() – Initialise the state of the experiment.
_load_state_dict_from_active_run_checkpoint() – Load the experiment state from a checkpoint for the active run, if available.
_load_state_dict_from_base_run_checkpoint([version]) – Load the experiment state from a base run checkpoint.
get_checkpoint_base_dir_from_run_id(run_id) – Get the base directory for a checkpoint from a run ID.
get_total_num_iterations() – Get the total number of iterations that the trainer will run for.
load_and_set_state_from_checkpoint([from_base_run, version]) – Load and set the experiment state from a checkpoint, if available.
save_checkpoint([log]) – Save the state of the experiment to a checkpoint.
train() – Train the agents.
Attributes
agent_names – The names of the agents in the scenario.
checkpoint_base_dir – The path to the directory containing the checkpoint.
checkpoint_metadata_path – The path to the checkpoint metadata file.
checkpoint_params_path – The path to the parameters file for the checkpoint.
checkpoint_state_path – The path to the checkpoint state file.
max_message_rounds – The maximum number of message rounds in the protocol.
num_agents – The number of agents in the scenario.
protocol_handler – The protocol handler for the experiment.
trainer_type – The type of trainer this is.
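The attribute list only states that checkpoint_base_dir and the three checkpoint_* file paths exist. A minimal sketch of how such attributes could be exposed as properties, assuming all three files sit directly inside checkpoint_base_dir; the class CheckpointPathsSketch and the file names are hypothetical, not the names nip actually uses:

```python
from pathlib import Path


class CheckpointPathsSketch:
    """Hypothetical stand-in for the checkpoint_* attributes of Trainer."""

    # Hypothetical file names; the names nip actually uses are not documented here.
    STATE_FILENAME = "state.pt"
    PARAMS_FILENAME = "hyper_params.json"
    METADATA_FILENAME = "metadata.json"

    def __init__(self, checkpoint_base_dir: Path) -> None:
        self._checkpoint_base_dir = Path(checkpoint_base_dir)

    @property
    def checkpoint_base_dir(self) -> Path:
        """Directory that holds every file belonging to the checkpoint."""
        return self._checkpoint_base_dir

    @property
    def checkpoint_state_path(self) -> Path:
        return self.checkpoint_base_dir / self.STATE_FILENAME

    @property
    def checkpoint_params_path(self) -> Path:
        return self.checkpoint_base_dir / self.PARAMS_FILENAME

    @property
    def checkpoint_metadata_path(self) -> Path:
        return self.checkpoint_base_dir / self.METADATA_FILENAME


# Usage: all three file paths are derived from one base directory.
paths = CheckpointPathsSketch(Path("checkpoints") / "run-123")
state_path = paths.checkpoint_state_path
```

Deriving every path from a single base directory means relocating a run's checkpoint only requires changing one value.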
Methods
- __init__(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)
- _add_file_to_wandb_artifact(artifact_name: str, artifact_type: str, file_path: Path)
Add a file to a W&B artifact, creating the artifact if it doesn’t exist.
If the artifact already exists, we add the file to the existing artifact, creating a new version.
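The versioning rule can be modelled without W&B. The sketch below is a toy, in-memory stand-in (InMemoryArtifactStore is hypothetical and is not the W&B API), assuming that each new version keeps the files of the previous version and adds the new file; whether nip's method carries files over in exactly this way is not stated here:

```python
from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class InMemoryArtifactStore:
    """Toy model of artifact versioning; this is not the W&B API."""

    # artifact name -> list of versions; each version is a tuple of file names.
    versions: dict[str, list[tuple[str, ...]]] = field(default_factory=dict)
    # artifact name -> artifact type recorded when the artifact was created.
    types: dict[str, str] = field(default_factory=dict)

    def add_file(self, artifact_name: str, artifact_type: str, file_path: Path) -> str:
        """Add a file, creating the artifact on first use; return a version reference."""
        history = self.versions.setdefault(artifact_name, [])
        self.types.setdefault(artifact_name, artifact_type)
        file_name = Path(file_path).name
        previous = history[-1] if history else ()
        # Assumption: the new version carries over the previous files, replacing
        # any file with the same name, then adds the new file.
        new_version = tuple(f for f in previous if f != file_name) + (file_name,)
        history.append(new_version)
        # Version labels follow the v0, v1, ... style used by W&B.
        return f"{artifact_name}:v{len(history) - 1}"


# Usage: the first call creates the artifact, the second creates a new version.
store = InMemoryArtifactStore()
first = store.add_file("checkpoint", "model", Path("state.pt"))
second = store.add_file("checkpoint", "model", Path("metadata.json"))
```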
- _get_metadata() → dict
Get metadata relevant to the current experiment, for saving the checkpoint.
- Returns:
metadata (dict) – The metadata to record
- _initialise_state()
Initialise the state of the experiment.
This method should be implemented by subclasses to initialise the state of the experiment. This is called at the beginning of training when starting from scratch (i.e. not restoring from a checkpoint).
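This is an instance of the template-method pattern: the base class decides when initialisation happens and subclasses decide what it does. A minimal sketch, assuming the restore-or-initialise decision is made inside train; TrainerSketch, CounterTrainer and the checkpoint_state argument are hypothetical and do not mirror nip's real signatures:

```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class TrainerSketch(ABC):
    """Hypothetical base class showing where _initialise_state is called."""

    def __init__(self) -> None:
        self.state: Optional[dict[str, Any]] = None

    @abstractmethod
    def _initialise_state(self) -> None:
        """Build a fresh experiment state; implemented by subclasses."""

    def train(self, checkpoint_state: Optional[dict[str, Any]] = None) -> None:
        if checkpoint_state is not None:
            # Restoring from a checkpoint: reuse the saved state, skip initialisation.
            self.state = dict(checkpoint_state)
        else:
            # Starting from scratch: the subclass sets up its own state.
            self._initialise_state()
        # The training loop itself is omitted from this sketch.


class CounterTrainer(TrainerSketch):
    def _initialise_state(self) -> None:
        self.state = {"iteration": 0}


# Usage: one trainer starts from scratch, the other resumes.
fresh = CounterTrainer()
fresh.train()
resumed = CounterTrainer()
resumed.train({"iteration": 7})
```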
- _load_state_dict_from_active_run_checkpoint() → dict
Load the experiment state from a checkpoint for the active run, if available.
- Returns:
state_dict (dict) – The state as a dictionary, loaded from W&B
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
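The real method fetches the checkpoint from W&B. A local-filesystem sketch of the same contract (return the state as a dict, raise if there is no checkpoint), assuming the state is stored as JSON in a file named state.json; the function load_state_dict_from_dir and the file name are hypothetical, and CheckPointNotFoundError is redefined locally only so the example runs on its own:

```python
import json
import tempfile
from pathlib import Path


class CheckPointNotFoundError(Exception):
    """Local stand-in for nip's exception of the same name."""


def load_state_dict_from_dir(checkpoint_base_dir: Path, state_filename: str = "state.json") -> dict:
    """Return the saved state as a dict, or raise if the state file is missing."""
    state_path = Path(checkpoint_base_dir) / state_filename
    if not state_path.is_file():
        raise CheckPointNotFoundError(f"No checkpoint state file at {state_path}")
    with state_path.open("r", encoding="utf-8") as f:
        return json.load(f)


# Usage: a missing file raises, a present file is loaded.
with tempfile.TemporaryDirectory() as tmp:
    try:
        load_state_dict_from_dir(Path(tmp))
        missing_raised = False
    except CheckPointNotFoundError:
        missing_raised = True
    (Path(tmp) / "state.json").write_text(json.dumps({"iteration": 3}), encoding="utf-8")
    loaded = load_state_dict_from_dir(Path(tmp))
```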
- _load_state_dict_from_base_run_checkpoint(version: str = 'latest') → dict
Load the experiment state from a base run checkpoint.
- Parameters:
version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B.
- Returns:
state_dict (dict) – The state as a dictionary, loaded from W&B
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
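In W&B, a specific artifact version is usually addressed as name:alias, where the alias is either latest or a version label such as v3. A sketch of turning the version argument into such a reference, assuming a hypothetical artifact name derived from the base run's ID and assuming explicit versions are given in vN form (the docstring does not say which form nip expects):

```python
import re

# Hypothetical naming scheme: one checkpoint artifact per run, named after the run ID.
ARTIFACT_NAME_TEMPLATE = "checkpoint_{run_id}"
_VERSION_PATTERN = re.compile(r"v\d+")


def checkpoint_artifact_ref(base_run_id: str, version: str = "latest") -> str:
    """Build a "name:alias" reference for a base run's checkpoint artifact."""
    if version != "latest" and not _VERSION_PATTERN.fullmatch(version):
        raise ValueError(f"Expected 'latest' or a version like 'v3', got {version!r}")
    artifact_name = ARTIFACT_NAME_TEMPLATE.format(run_id=base_run_id)
    return f"{artifact_name}:{version}"


# Usage: default and explicit versions.
latest_ref = checkpoint_artifact_ref("abc123")
pinned_ref = checkpoint_artifact_ref("abc123", "v4")
```

With the W&B public API, a reference of this form (typically prefixed by entity/project/) would then be resolved with wandb.Api().artifact(...).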
- classmethod get_checkpoint_base_dir_from_run_id(run_id: str) → Path
Get the base directory for a checkpoint from a run ID.
- Parameters:
run_id (str) – The run ID.
- Returns:
checkpoint_base_dir (Path) – The path to the base directory for the checkpoint.
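Because it is a classmethod, the lookup works without constructing a trainer, which is what is needed to locate another run's checkpoint (for instance, a base run). A sketch assuming each run's checkpoint lives in a subdirectory named after the run ID under a fixed root; CHECKPOINT_ROOT, TrainerPathsSketch, ScratchTrainerPaths and the validation rule are all hypothetical:

```python
from pathlib import Path


class TrainerPathsSketch:
    """Hypothetical class showing a run-ID-to-directory mapping."""

    # Hypothetical root under which each run gets its own subdirectory.
    CHECKPOINT_ROOT = Path("checkpoints")

    @classmethod
    def get_checkpoint_base_dir_from_run_id(cls, run_id: str) -> Path:
        # Reject IDs that would escape the root or collapse onto it.
        if not run_id or run_id in {".", ".."} or "/" in run_id or "\\" in run_id:
            raise ValueError(f"Invalid run ID: {run_id!r}")
        return cls.CHECKPOINT_ROOT / run_id


class ScratchTrainerPaths(TrainerPathsSketch):
    # A subclass can relocate checkpoints by overriding the root; cls picks it up.
    CHECKPOINT_ROOT = Path("scratch") / "checkpoints"


# Usage: no instance is needed.
base_dir = TrainerPathsSketch.get_checkpoint_base_dir_from_run_id("abc")
scratch_dir = ScratchTrainerPaths.get_checkpoint_base_dir_from_run_id("abc")
```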
- get_total_num_iterations() → int
Get the total number of iterations that the trainer will run for.
This is the sum of the number of iterations declared by methods decorated with attach_progress_bar.
- Returns:
total_iterations (int) – The total number of iterations.
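The docstring does not show attach_progress_bar's signature. A minimal sketch of the summing logic, assuming the decorator takes a callable that computes a method's iteration count from the trainer instance and stores it on the method; the decorator body, the _num_iterations attribute and ProgressTrainerSketch are hypothetical:

```python
from typing import Any, Callable


def attach_progress_bar(num_iterations: Callable[[Any], int]) -> Callable:
    """Hypothetical decorator: record how many iterations a method declares."""

    def decorator(method: Callable) -> Callable:
        method._num_iterations = num_iterations
        return method

    return decorator


class ProgressTrainerSketch:
    def __init__(self, num_epochs: int, num_eval_rounds: int) -> None:
        self.num_epochs = num_epochs
        self.num_eval_rounds = num_eval_rounds

    @attach_progress_bar(lambda self: self.num_epochs)
    def _train_loop(self) -> None:
        pass

    @attach_progress_bar(lambda self: self.num_eval_rounds)
    def _evaluate(self) -> None:
        pass

    def get_total_num_iterations(self) -> int:
        total = 0
        # dir() also lists inherited attributes; each name is visited once,
        # so a subclass override replaces the base method's declaration.
        for name in dir(type(self)):
            counter = getattr(getattr(type(self), name), "_num_iterations", None)
            if counter is not None:
                total += counter(self)
        return total


# Usage: 10 training iterations plus 3 evaluation rounds.
trainer = ProgressTrainerSketch(num_epochs=10, num_eval_rounds=3)
total = trainer.get_total_num_iterations()
```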
- load_and_set_state_from_checkpoint(from_base_run: bool = False, version: str = 'latest')
Load and set the experiment state from a checkpoint, if available.
- Parameters:
from_base_run (bool, default=False) – Whether to load the checkpoint artifact from the base run. By default, the artifact is loaded from the current active run.
version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B. This must be “latest” if from_base_run is False.
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
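The method combines the two private loaders documented above and enforces the rule that version must be “latest” for the active run. A sketch of that dispatch, with checkpoints held in plain dicts instead of W&B; LoadingTrainerSketch is hypothetical, raising ValueError for the version rule is an assumption (the docstring does not say how the rule is enforced), and CheckPointNotFoundError is redefined locally so the example is self-contained:

```python
from typing import Any, Optional


class CheckPointNotFoundError(Exception):
    """Local stand-in for nip's exception of the same name."""


class LoadingTrainerSketch:
    """Hypothetical trainer whose checkpoints live in memory instead of W&B."""

    def __init__(
        self,
        active_run_state: Optional[dict] = None,
        base_run_checkpoints: Optional[dict[str, dict]] = None,
    ) -> None:
        self._active_run_state = active_run_state
        # Maps a version label ("latest", "v0", ...) to a saved state.
        self._base_run_checkpoints = base_run_checkpoints or {}
        self.state: Optional[dict[str, Any]] = None

    def _load_state_dict_from_active_run_checkpoint(self) -> dict:
        if self._active_run_state is None:
            raise CheckPointNotFoundError("No checkpoint for the active run")
        return self._active_run_state

    def _load_state_dict_from_base_run_checkpoint(self, version: str = "latest") -> dict:
        try:
            return self._base_run_checkpoints[version]
        except KeyError:
            raise CheckPointNotFoundError(f"No base run checkpoint {version!r}") from None

    def load_and_set_state_from_checkpoint(
        self, from_base_run: bool = False, version: str = "latest"
    ) -> None:
        # Assumption: the documented constraint is enforced with a ValueError.
        if not from_base_run and version != "latest":
            raise ValueError("version must be 'latest' when from_base_run is False")
        if from_base_run:
            state_dict = self._load_state_dict_from_base_run_checkpoint(version)
        else:
            state_dict = self._load_state_dict_from_active_run_checkpoint()
        # "Set" step: copy the loaded state onto the trainer.
        self.state = dict(state_dict)


# Usage: load from the active run, then pin a base run version.
trainer = LoadingTrainerSketch(
    active_run_state={"iteration": 5},
    base_run_checkpoints={"latest": {"iteration": 9}, "v0": {"iteration": 1}},
)
trainer.load_and_set_state_from_checkpoint()
active_state = trainer.state
trainer.load_and_set_state_from_checkpoint(from_base_run=True, version="v0")
base_state = trainer.state
```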