nip.trainers.solo_agent.SoloAgentTrainer#

class nip.trainers.solo_agent.SoloAgentTrainer(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#

Trainer for training tensordict agents in isolation.

Parameters:
  • hyper_params (HyperParameters) – The parameters of the experiment.

  • scenario_instance (ScenarioInstance) – The components of the experiment.

  • settings (ExperimentSettings) – The instance-specific settings of the experiment, like device, logging, etc.

Methods Summary

__init__(hyper_params, scenario_instance, ...)

_add_file_to_wandb_artifact(artifact_name, ...)

Add a file to a W&B artifact, creating the artifact if it doesn't exist.

_build_test_context(stack)

Build the context manager ExitStack for testing.

_build_train_context(stack)

Build the context manager ExitStack for training.

_get_metadata()

Get metadata relevant to the current experiment, for saving the checkpoint.

_initialise_state()

Initialise the state of the experiment.

_load_state_dict_from_active_run_checkpoint()

Load the experiment state from a checkpoint for the active run, if available.

_load_state_dict_from_base_run_checkpoint([...])

Load the experiment state from a base run checkpoint.

_run_test_loop(test_dataset, agents_params, ...)

Run the testing loop.

_run_train_loop(train_dataset, ...)

Run the training loop.

get_checkpoint_base_dir_from_run_id(run_id)

Get the base directory for a checkpoint from a run ID.

get_total_num_iterations()

Get the total number of iterations that the trainer will run for.

load_and_set_state_from_checkpoint([...])

Load and set the experiment state from a checkpoint, if available.

save_checkpoint([log])

Save the state of the experiment to a checkpoint.

train([as_pretraining])

Train the agents.

Attributes

agent_names

The names of the agents in the scenario.

checkpoint_base_dir

The path to the directory containing the checkpoint.

checkpoint_metadata_path

The path to the checkpoint metadata file.

checkpoint_params_path

The path to the parameters file for the checkpoint.

checkpoint_state_path

The path to the checkpoint state file.

max_message_rounds

The maximum number of message rounds in the protocol.

num_agents

The number of agents in the scenario.

protocol_handler

The protocol handler for the experiment.

Methods

__init__(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#
_add_file_to_wandb_artifact(artifact_name: str, artifact_type: str, file_path: Path)[source]#

Add a file to a W&B artifact, creating the artifact if it doesn’t exist.

If the artifact already exists, we add the file to the existing artifact, creating a new version.

Parameters:
  • artifact_name (str) – The name of the artifact to add the file to. This should not contain an alias or version, as we always add to the latest version.

  • artifact_type (str) – The type of the artifact.

  • file_path (Path) – The path to the file to add to the artifact.
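
The create-or-extend semantics can be illustrated without the real W&B API. The sketch below models artifact versioning with an in-memory store; `ArtifactStore` and its methods are hypothetical stand-ins, not W&B calls.

```python
# Illustrative sketch (not the real W&B API): the create-or-extend semantics
# of _add_file_to_wandb_artifact, modelled with an in-memory artifact store.
from pathlib import Path


class ArtifactStore:
    def __init__(self):
        # artifact_name -> list of versions; each version is a set of file names
        self._artifacts: dict[str, list[set[str]]] = {}

    def add_file(self, artifact_name: str, artifact_type: str, file_path: Path):
        versions = self._artifacts.setdefault(artifact_name, [])
        if versions:
            # Artifact exists: copy the latest version and add the file,
            # creating a new version.
            new_version = set(versions[-1])
        else:
            # Artifact does not exist yet: create it with a first version.
            new_version = set()
        new_version.add(file_path.name)
        versions.append(new_version)

    def latest(self, artifact_name: str) -> set[str]:
        return self._artifacts[artifact_name][-1]


store = ArtifactStore()
store.add_file("checkpoint", "model", Path("state.pt"))
store.add_file("checkpoint", "model", Path("metadata.json"))
print(store.latest("checkpoint"))  # both files appear in the latest version
```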

_build_test_context(stack: ExitStack) list[ContextManager][source]#

Build the context manager ExitStack for testing.

Adds the appropriate context managers to the given ExitStack and returns them.

Parameters:

stack (ExitStack) – The ExitStack to add the context managers to. Note that this is modified in-place.

Returns:

context_managers (list[ContextManager]) – The target context managers to be used in the testing loop.

_build_train_context(stack: ExitStack) list[ContextManager][source]#

Build the context manager ExitStack for training.

Adds the appropriate context managers to the given ExitStack and returns them.

Parameters:

stack (ExitStack) – The ExitStack to add the context managers to. Note that this is modified in-place.

Returns:

context_managers (list[ContextManager]) – The target context managers to be used in the training loop.
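
A minimal sketch of this pattern: the caller owns the ExitStack, the builder registers context managers on it in place and returns the entered targets for use in the loop. The `tracked` manager and the names "autocast" and "profiler" are illustrative assumptions, not the managers the trainer actually uses.

```python
# Sketch of the _build_train_context / _build_test_context pattern: the stack
# is modified in place, and the entered targets are returned to the caller.
from contextlib import ExitStack, contextmanager

events = []


@contextmanager
def tracked(name):
    # Record enter/exit order so the stack's unwinding behaviour is visible.
    events.append(f"enter {name}")
    yield name
    events.append(f"exit {name}")


def build_train_context(stack: ExitStack):
    return [
        stack.enter_context(tracked("autocast")),
        stack.enter_context(tracked("profiler")),
    ]


with ExitStack() as stack:
    targets = build_train_context(stack)
    events.append(f"training with {targets}")
# Leaving the with-block exits the managers in reverse registration order.
print(events)
```

Because the stack is passed in rather than created by the builder, the caller controls exactly when all registered managers are torn down.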

_get_metadata() dict[source]#

Get metadata relevant to the current experiment, for saving the checkpoint.

Returns:

metadata (dict) – The metadata to record
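
A sketch of the kind of dictionary such a method might return. The keys shown here are illustrative assumptions, not the actual fields `_get_metadata` records; the point is that the metadata is a plain, serialisable dict saved alongside the checkpoint state.

```python
# Hypothetical metadata sketch: the keys are assumptions for illustration.
import json
import platform
from datetime import datetime, timezone


def get_metadata() -> dict:
    return {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "python_version": platform.python_version(),
        "iteration": 0,
    }


metadata = get_metadata()
# Metadata should be JSON-serialisable so it can be written next to the state.
serialised = json.dumps(metadata)
assert json.loads(serialised) == metadata
```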

_initialise_state()[source]#

Initialise the state of the experiment.

This method should be implemented by subclasses to initialise the state of the experiment. This is called at the beginning of training when starting from scratch (i.e. not restoring from a checkpoint).

_load_state_dict_from_active_run_checkpoint() dict[source]#

Load the experiment state from a checkpoint for the active run, if available.

Returns:

state_dict (dict) – The state as a dictionary, loaded from W&B

Raises:

CheckPointNotFoundError – If the checkpoint file is not found.

_load_state_dict_from_base_run_checkpoint(version: str = 'latest') dict[source]#

Load the experiment state from a base run checkpoint.

Parameters:

version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B.

Returns:

state_dict (dict) – The state as a dictionary, loaded from W&B

Raises:

CheckPointNotFoundError – If the checkpoint file is not found.

_run_test_loop(test_dataset: Dataset, agents_params: AgentsParameters, agents: dict[str, Agent], agent_models: dict[str, TensorDictSequential], as_pretraining: bool, logger: Logger)[source]#

Run the testing loop.

Parameters:
  • test_dataset (Dataset) – The dataset to test on.

  • agents_params (AgentsParameters) – The parameters of the agents.

  • agents (dict[str, Agent]) – A dictionary of the classes which hold the agent components.

  • agent_models (dict[str, TensorDictSequential]) – A dictionary of the actual models we’re testing.

  • as_pretraining (bool) – Whether we’re testing the agents as a pretraining step.

  • logger (logging.Logger) – The logger to use.

_run_train_loop(train_dataset: Dataset, agents_params: AgentsParameters, agents: dict[str, Agent], agent_models: dict[str, TensorDictSequential], as_pretraining: bool, torch_generator: Generator, iteration_context: IterationContext)[source]#

Run the training loop.

Parameters:
  • train_dataset (Dataset) – The dataset to train on.

  • agents_params (AgentsParameters) – The parameters of the agents.

  • agents (dict[str, Agent]) – A dictionary of the classes which hold the agent components.

  • agent_models (dict[str, TensorDictSequential]) – A dictionary of the actual models we’re training.

  • as_pretraining (bool) – Whether we’re training the agents as a pretraining step.

  • torch_generator (torch.Generator) – The random number generator to use.

  • iteration_context (IterationContext) – The context to use for the training loop, which handles the progress bar.

classmethod get_checkpoint_base_dir_from_run_id(run_id: str) Path[source]#

Get the base directory for a checkpoint from a run ID.

Parameters:

run_id (str) – The run ID.

Returns:

checkpoint_base_dir (Path) – The path to the base directory for the checkpoint.
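
The mapping from run ID to directory can be sketched with pathlib. The "checkpoints" root and the file names below are an assumed layout, not the one the trainer actually uses; compare the checkpoint_metadata_path and checkpoint_state_path attributes above.

```python
# Sketch of mapping a run ID to a checkpoint directory (layout is assumed).
from pathlib import Path

CHECKPOINT_ROOT = Path("checkpoints")


def get_checkpoint_base_dir_from_run_id(run_id: str) -> Path:
    return CHECKPOINT_ROOT / run_id


base_dir = get_checkpoint_base_dir_from_run_id("run_abc123")
metadata_path = base_dir / "metadata.json"  # cf. checkpoint_metadata_path
state_path = base_dir / "state.pt"          # cf. checkpoint_state_path
print(base_dir)
```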

get_total_num_iterations() int[source]#

Get the total number of iterations that the trainer will run for.

This is the sum of the number of iterations declared by methods decorated with attach_progress_bar.

Returns:

total_iterations (int) – The total number of iterations.
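
One way such a total can be computed is sketched below: a decorator tags each method with its declared iteration count, and the trainer sums the tags. The `attach_progress_bar` here is a simplified stand-in for illustration, not nip's own decorator.

```python
# Sketch: sum iteration counts declared by decorated methods. The decorator
# is a hypothetical stand-in for nip's attach_progress_bar.
def attach_progress_bar(num_iterations: int):
    def decorator(method):
        method.num_iterations = num_iterations
        return method
    return decorator


class SketchTrainer:
    @attach_progress_bar(100)
    def _run_train_loop(self):
        ...

    @attach_progress_bar(10)
    def _run_test_loop(self):
        ...

    def get_total_num_iterations(self) -> int:
        # Sum the counts attached to this class's decorated methods;
        # undecorated attributes contribute 0.
        return sum(
            getattr(attr, "num_iterations", 0)
            for attr in vars(type(self)).values()
        )


print(SketchTrainer().get_total_num_iterations())  # 110
```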

load_and_set_state_from_checkpoint(from_base_run: bool = False, version: str = 'latest')[source]#

Load and set the experiment state from a checkpoint, if available.

Parameters:
  • from_base_run (bool, default=False) – Whether to load the checkpoint artifact from the base run. By default, the artifact is loaded from the current active run.

  • version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B. This must be “latest” if from_base_run is False.

Raises:

CheckPointNotFoundError – If the checkpoint file is not found.
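
A caller can treat the missing-checkpoint case as "start from scratch". The sketch below assumes a JSON state file at a hypothetical path, and defines CheckPointNotFoundError locally for illustration rather than importing nip's own exception.

```python
# Sketch of load-if-available behaviour; the state.json path is an assumption.
import json
from pathlib import Path


class CheckPointNotFoundError(Exception):
    """Raised when no checkpoint file exists for the run."""


def load_state_dict(checkpoint_dir: Path) -> dict:
    state_path = checkpoint_dir / "state.json"
    if not state_path.exists():
        raise CheckPointNotFoundError(f"No checkpoint at {state_path}")
    return json.loads(state_path.read_text())


# A missing checkpoint falls back to fresh initialisation (cf. _initialise_state):
try:
    state = load_state_dict(Path("nonexistent_run"))
except CheckPointNotFoundError:
    state = None  # start from scratch instead
print(state)  # None
```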

save_checkpoint(log: bool = True)[source]#

Save the state of the experiment to a checkpoint.

Parameters:

log (bool, default=True) – Whether to log the checkpointing.

train(as_pretraining: bool = False)[source]#

Train the agents.

Parameters:

as_pretraining (bool, default=False) – Whether we’re training the agents as a pretraining step. This affects the output and what we log to W&B.