nip.trainers.trainer_base.Trainer#

class nip.trainers.trainer_base.Trainer(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#

Base class for all trainers.

Parameters:
  • hyper_params (HyperParameters) – The parameters of the experiment.

  • scenario_instance (ScenarioInstance) – The components of the experiment.

  • settings (ExperimentSettings) – The instance-specific settings of the experiment, like device, logging, etc.

Methods Summary

__init__(hyper_params, scenario_instance, ...)

_add_file_to_wandb_artifact(artifact_name, ...)

Add a file to a W&B artifact, creating the artifact if it doesn't exist.

_get_metadata()

Get metadata relevant to the current experiment, for saving the checkpoint.

_initialise_state()

Initialise the state of the experiment.

_load_state_dict_from_active_run_checkpoint()

Load the experiment state from a checkpoint for the active run, if available.

_load_state_dict_from_base_run_checkpoint([...])

Load the experiment state from a base run checkpoint.

get_checkpoint_base_dir_from_run_id(run_id)

Get the base directory for a checkpoint from a run ID.

get_total_num_iterations()

Get the total number of iterations that the trainer will run for.

load_and_set_state_from_checkpoint([...])

Load and set the experiment state from a checkpoint, if available.

save_checkpoint([log])

Save the state of the experiment to a checkpoint.

train()

Train the agents.

Attributes

agent_names

The names of the agents in the scenario.

checkpoint_base_dir

The path to the directory containing the checkpoint.

checkpoint_metadata_path

The path to the checkpoint metadata file.

checkpoint_params_path

The path to the parameters file for the checkpoint.

checkpoint_state_path

The path to the checkpoint state file.

max_message_rounds

The maximum number of message rounds in the protocol.

num_agents

The number of agents in the scenario.

protocol_handler

The protocol handler for the experiment.

Methods

__init__(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#
_add_file_to_wandb_artifact(artifact_name: str, artifact_type: str, file_path: Path)[source]#

Add a file to a W&B artifact, creating the artifact if it doesn’t exist.

If the artifact already exists, we add the file to the existing artifact, creating a new version.

Parameters:
  • artifact_name (str) – The name of the artifact to add the file to. This should not contain an alias or version, as we always add to the latest version.

  • artifact_type (str) – The type of the artifact.

  • file_path (Path) – The path to the file to add to the artifact.

_get_metadata() dict[source]#

Get metadata relevant to the current experiment, for saving the checkpoint.

Returns:

metadata (dict) – The metadata to record

_initialise_state()[source]#

Initialise the state of the experiment.

This method should be implemented by subclasses to initialise the state of the experiment. This is called at the beginning of training when starting from scratch (i.e. not restoring from a checkpoint).
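As a hedged illustration of this subclassing contract (the minimal base class and the `iteration` state key below are stand-ins for illustration, not the real `nip` implementation):

```python
from abc import ABC, abstractmethod


class TrainerSketch(ABC):
    """Minimal stand-in for the Trainer base class (illustration only)."""

    def __init__(self):
        self.state = None

    @abstractmethod
    def _initialise_state(self):
        """Initialise the experiment state when starting from scratch."""

    @abstractmethod
    def train(self):
        """Train the agents."""


class ExampleTrainer(TrainerSketch):
    def _initialise_state(self):
        # A subclass decides what the state contains; 'iteration' is a
        # hypothetical key used here for illustration.
        self.state = {"iteration": 0}

    def train(self):
        if self.state is None:
            # Starting from scratch, i.e. not restoring from a checkpoint.
            self._initialise_state()
        for _ in range(3):
            self.state["iteration"] += 1
```

A subclass that restores from a checkpoint would instead populate `self.state` from the loaded state dict, so `_initialise_state` is only reached on a fresh start.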

_load_state_dict_from_active_run_checkpoint() dict[source]#

Load the experiment state from a checkpoint for the active run, if available.

Returns:

state_dict (dict) – The state as a dictionary, loaded from W&B

Raises:

CheckPointNotFoundError – If the checkpoint file is not found.

_load_state_dict_from_base_run_checkpoint(version: str = 'latest') dict[source]#

Load the experiment state from a base run checkpoint.

Parameters:

version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B.

Returns:

state_dict (dict) – The state as a dictionary, loaded from W&B

Raises:

CheckPointNotFoundError – If the checkpoint file is not found.

classmethod get_checkpoint_base_dir_from_run_id(run_id: str) Path[source]#

Get the base directory for a checkpoint from a run ID.

Parameters:

run_id (str) – The run ID.

Returns:

checkpoint_base_dir (Path) – The path to the base directory for the checkpoint.
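The actual directory layout is not specified in this reference; as a purely illustrative sketch, such a classmethod might map each run ID to its own subdirectory under a hypothetical checkpoints root:

```python
from pathlib import Path


class CheckpointPathsSketch:
    """Illustrative stand-in; the real Trainer's directory layout may differ."""

    # Hypothetical root directory for all checkpoints (assumption, not the
    # real default used by nip).
    checkpoints_root = Path("checkpoints")

    @classmethod
    def get_checkpoint_base_dir_from_run_id(cls, run_id: str) -> Path:
        # One subdirectory per run, named by its run ID.
        return cls.checkpoints_root / run_id
```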

get_total_num_iterations() int[source]#

Get the total number of iterations that the trainer will run for.

This is the sum of the number of iterations declared by methods decorated with attach_progress_bar.

Returns:

total_iterations (int) – The total number of iterations.
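The `attach_progress_bar` decorator itself is not documented on this page; the following is a speculative sketch, with a hypothetical decorator of that name, purely to illustrate how per-method iteration counts could be declared and summed:

```python
def attach_progress_bar(num_iterations_fn):
    """Hypothetical decorator: record a callable giving the iteration count."""

    def decorator(method):
        method._num_iterations_fn = num_iterations_fn
        return method

    return decorator


class SketchTrainer:
    num_rl_iterations = 100
    num_eval_iterations = 10

    @attach_progress_bar(lambda self: self.num_rl_iterations)
    def _run_rl(self):
        ...

    @attach_progress_bar(lambda self: self.num_eval_iterations)
    def _run_eval(self):
        ...

    def get_total_num_iterations(self) -> int:
        # Sum the counts declared by every decorated method.
        total = 0
        for name in dir(type(self)):
            attr = getattr(type(self), name)
            fn = getattr(attr, "_num_iterations_fn", None)
            if fn is not None:
                total += fn(self)
        return total
```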

load_and_set_state_from_checkpoint(from_base_run: bool = False, version: str = 'latest')[source]#

Load and set the experiment state from a checkpoint, if available.

Parameters:
  • from_base_run (bool, default=False) – Whether to load the checkpoint artifact from the base run. By default, the artifact is loaded from the current active run.

  • version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B. This must be “latest” if from_base_run is False.

Raises:

CheckPointNotFoundError – If the checkpoint file is not found.
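A common resume pattern combines this method with `_initialise_state`: try to restore, and fall back to a fresh start if no checkpoint exists. Sketched here with hypothetical stand-ins (the real `CheckPointNotFoundError` lives in `nip`, and the in-memory "checkpoint" below simulates W&B storage):

```python
class CheckPointNotFoundError(Exception):
    """Stand-in for the exception the real Trainer raises (illustration only)."""


class ResumableTrainerSketch:
    def __init__(self, saved_state=None):
        # saved_state simulates whether a checkpoint exists for this run.
        self._saved_state = saved_state
        self.state = None

    def load_and_set_state_from_checkpoint(self):
        if self._saved_state is None:
            raise CheckPointNotFoundError("no checkpoint for this run")
        self.state = dict(self._saved_state)

    def _initialise_state(self):
        self.state = {"iteration": 0}

    def prepare(self):
        # Resume if a checkpoint exists, otherwise start from scratch.
        try:
            self.load_and_set_state_from_checkpoint()
        except CheckPointNotFoundError:
            self._initialise_state()
```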

save_checkpoint(log: bool = True)[source]#

Save the state of the experiment to a checkpoint.

Parameters:

log (bool, default=True) – Whether to log the checkpointing.

abstract train()[source]#

Train the agents.