nip.trainers.vanilla_ppo.VanillaPpoTrainer#

class nip.trainers.vanilla_ppo.VanillaPpoTrainer(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#

Vanilla Proximal Policy Optimization trainer.

Implements a multi-agent PPO algorithm, specifically IPPO, since the value estimator is not shared between agents.

Parameters:
  • hyper_params (HyperParameters) – The parameters of the experiment.

  • scenario_instance (ScenarioInstance) – The components of the experiment.

  • settings (ExperimentSettings) – The instance-specific settings of the experiment, such as the torch device to use for training.
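
A minimal usage sketch is shown below. It assumes that hyper_params, scenario_instance and settings have already been constructed by the experiment's setup code; only the trainer construction and the train() call come from the documented interface.

```python
from nip.trainers.vanilla_ppo import VanillaPpoTrainer

# Assumes `hyper_params`, `scenario_instance` and `settings` were built elsewhere
# (e.g. by the experiment's setup code).
trainer = VanillaPpoTrainer(
    hyper_params=hyper_params,            # HyperParameters of the experiment
    scenario_instance=scenario_instance,  # components: agents, environments, protocol
    settings=settings,                    # ExperimentSettings (device, logging, ...)
)
trainer.train()  # train the agents
```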

Methods Summary

__init__(hyper_params, scenario_instance, ...)

_add_file_to_wandb_artifact(artifact_name, ...)

Add a file to a W&B artifact, creating the artifact if it doesn't exist.

_add_normalization_transforms()

Add observation normalization transforms to the environments.

_build_operators()

Get the policy and value operators for the agents.

_build_test_context(stack)

Build the context manager ExitStack for testing.

_build_train_context(stack)

Build the context manager ExitStack for training.

_get_data_collectors()

Construct the data collectors, which generate rollouts from the environment.

_get_log_stats(rollouts[, mean_loss_vals, train])

Compute the statistics to log during training or testing.

_get_loss_module_and_gae()

Construct the loss module and the generalized advantage estimator.

_get_metadata()

Get metadata relevant to the current experiment, for saving the checkpoint.

_get_optimizer_and_param_freezer(loss_module)

Construct the optimizer for the loss module and the model parameter freezer.

_get_replay_buffer([transform])

Construct the replay buffer, which will store the rollouts.

_initialise_state()

Initialise the state of the trainer.

_load_state_dict_from_active_run_checkpoint()

Load the experiment state from a checkpoint for the active run, if available.

_load_state_dict_from_base_run_checkpoint([...])

Load the experiment state from a base run checkpoint.

_pretrain_agents()

Pretrain the agent bodies in isolation.

_run_test_loop(iteration_context)

Run the test loop.

_run_train_loop(iteration_context)

Run the training loop.

_train_and_test()

Run generic RL training and test loops.

_train_on_replay_buffer()

Train the agents on data in the replay buffer.

get_checkpoint_base_dir_from_run_id(run_id)

Get the base directory for a checkpoint from a run ID.

get_total_num_iterations()

Get the total number of iterations that the trainer will run for.

load_and_set_state_from_checkpoint([...])

Load and set the experiment state from a checkpoint, if available.

save_checkpoint([log])

Save the state of the experiment to a checkpoint.

train()

Train the agents.

Attributes

PRETRAINED_MODEL_CACHE_PARAM_KEYS

agent_names

The names of the agents in the scenario.

checkpoint_base_dir

The path to the directory containing the checkpoint.

checkpoint_metadata_path

The path to the checkpoint metadata file.

checkpoint_params_path

The path to the parameters file for the checkpoint.

checkpoint_state_path

The path to the checkpoint state file.

max_message_rounds

The maximum number of message rounds in the protocol.

num_agents

The number of agents in the scenario.

protocol_handler

The protocol handler for the experiment.

policy_operator

value_operator

Methods

__init__(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#

_add_file_to_wandb_artifact(artifact_name: str, artifact_type: str, file_path: Path)[source]#

Add a file to a W&B artifact, creating the artifact if it doesn’t exist.

If the artifact already exists, we add the file to the existing artifact, creating a new version.

Parameters:
  • artifact_name (str) – The name of the artifact to add the file to. This should not contain an alias or version, as we always add to the latest version.

  • artifact_type (str) – The type of the artifact.

  • file_path (Path) – The path to the file to add to the artifact.
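
For orientation, the underlying W&B calls look roughly like the following sketch. It is a simplified stand-in, not the trainer's actual implementation: it logs a new artifact version containing the file, and assumes wandb.init() has already been called.

```python
import wandb
from pathlib import Path


def add_file_to_wandb_artifact(artifact_name: str, artifact_type: str, file_path: Path) -> None:
    """Sketch: log `file_path` as (a new version of) the named W&B artifact."""
    artifact = wandb.Artifact(artifact_name, type=artifact_type)
    artifact.add_file(str(file_path))
    # Logging an artifact under an existing name creates a new version of it.
    wandb.log_artifact(artifact)
```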

_add_normalization_transforms()[source]#

Add observation normalization transforms to the environments.

_build_operators()[source]#

Get the policy and value operators for the agents.

_build_test_context(stack: ExitStack) list[ContextManager][source]#

Build the context manager ExitStack for testing.

Takes as input an ExitStack and adds the appropriate context managers to it, then returns the context managers.

Parameters:

stack (ExitStack) – The ExitStack to add the context managers to. Note that this is modified in-place.

Returns:

context_managers (list[ContextManager]) – The target context managers to be used in the testing loop.

_build_train_context(stack: ExitStack) list[ContextManager][source]#

Build the context manager ExitStack for training.

Takes as input an ExitStack and adds the appropriate context managers to it, then returns the context managers.

Parameters:

stack (ExitStack) – The ExitStack to add the context managers to. Note that this is modified in-place.

Returns:

context_managers (list[ContextManager]) – The target context managers to be used in the training loop.
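The ExitStack pattern used by _build_train_context and _build_test_context typically looks like the following self-contained sketch; the context managers here are illustrative placeholders, not the trainer's actual ones.

```python
from contextlib import ExitStack, contextmanager


@contextmanager
def example_context(name: str):
    print(f"entering {name}")
    yield name
    print(f"exiting {name}")


def build_context(stack: ExitStack) -> list:
    # Enter each context manager on the caller's stack and return the targets,
    # mirroring what _build_train_context / _build_test_context do.
    return [stack.enter_context(example_context(name)) for name in ("eval-mode", "no-grad")]


with ExitStack() as stack:
    targets = build_context(stack)
    # ... run the training or testing loop here ...
# Every registered context manager is exited when the `with` block closes.
```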

_get_data_collectors() tuple[SyncDataCollector, SyncDataCollector][source]#

Construct the data collectors, which generate rollouts from the environment.

Constructs a collector for both the train and the test environment.

Returns:

  • train_collector (SyncDataCollector) – The train data collector.

  • test_collector (SyncDataCollector) – The test data collector.
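
Constructing a TorchRL collector typically looks like the sketch below. The environment, policy and frame counts are placeholders; the trainer derives its actual values from the hyper-parameters and the scenario instance.

```python
from torchrl.collectors import SyncDataCollector

# `train_env` and `policy_operator` are assumed to exist; sizes are placeholders.
train_collector = SyncDataCollector(
    train_env,
    policy_operator,
    frames_per_batch=1_000,   # rollout frames collected per training iteration
    total_frames=1_000_000,   # total frames over the whole run
)

for rollouts in train_collector:
    ...  # each `rollouts` is a TensorDict of trajectories
```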

_get_log_stats(rollouts: TensorDictBase, mean_loss_vals: TensorDictBase | None = None, *, train=True) dict[str, float][source]#

Compute the statistics to log during training or testing.

Parameters:
  • rollouts (TensorDict) – The data sampled from the data collector.

  • mean_loss_vals (TensorDict, optional) – The average loss values.

  • train (bool, default=True) – Whether the statistics are for training or testing.

Returns:

log_stats (dict[str, float]) – The statistics to log.
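As an illustration, the returned dictionary collects scalar statistics derived from the rollouts and the loss values, along the lines of the sketch below. The TensorDict keys used here are assumptions, not the trainer's actual keys.

```python
# Illustrative only: assemble scalar statistics from a rollouts TensorDict.
log_stats = {"mean_reward": rollouts["next", "agents", "reward"].mean().item()}
if mean_loss_vals is not None:
    # Flatten the averaged loss terms into scalar entries.
    log_stats.update({key: value.item() for key, value in mean_loss_vals.items()})
```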

_get_loss_module_and_gae() tuple[ClipPPOLossImproved, GAE][source]#

Construct the loss module and the generalized advantage estimator.

Returns:

  • loss_module (ClipPPOLossImproved) – The loss module.

  • gae (GAE) – The generalized advantage estimator.
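
The construction follows the standard TorchRL pattern sketched below, using the stock ClipPPOLoss as a stand-in for the trainer's ClipPPOLossImproved subclass; the coefficient values are placeholders rather than the experiment's hyper-parameters.

```python
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

loss_module = ClipPPOLoss(
    actor_network=policy_operator,   # joint policy operator (assumed to exist)
    critic_network=value_operator,   # per-agent value operator (assumed to exist)
    clip_epsilon=0.2,
    entropy_coef=0.01,
)
gae = GAE(
    gamma=0.99,
    lmbda=0.95,
    value_network=value_operator,
    average_gae=True,
)
```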

_get_metadata() dict[source]#

Get metadata relevant to the current experiment, for saving the checkpoint.

Returns:

metadata (dict) – The metadata to record.

_get_optimizer_and_param_freezer(loss_module: Objective) tuple[Adam, ParamGroupFreezer][source]#

Construct the optimizer for the loss module and the model parameter freezer.

Parameters:

loss_module (Objective) – The loss module.

Returns:

  • optimizer (torch.optim.Optimizer) – The optimizer.

  • param_group_freezer (ParamGroupFreezer) – The parameter freezer, which holds the parameter groups for each agent so that they can be frozen independently.
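
A rough sketch of the optimizer construction is shown below: one parameter group per agent, so that groups can later be frozen independently. The per_agent_parameters mapping and the learning rate are assumptions.

```python
from torch.optim import Adam

# `per_agent_parameters` is an assumed mapping from agent name to its parameters.
param_groups = [
    {"params": list(params), "lr": 3e-4}
    for params in per_agent_parameters.values()
]
optimizer = Adam(param_groups)
```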

_get_replay_buffer(transform: Transform | None = None) ReplayBuffer[source]#

Construct the replay buffer, which will store the rollouts.

Parameters:

transform (Transform, optional) – The transform to apply to the data before storing it in the replay buffer.

Returns:

ReplayBuffer – The replay buffer.
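
A TorchRL replay buffer for on-policy PPO data is typically built as in the sketch below; the storage and batch sizes are placeholders, and the transform is the optional argument documented above.

```python
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.data.replay_buffers import SamplerWithoutReplacement

replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(10_000),    # placeholder capacity
    sampler=SamplerWithoutReplacement(),  # iterate over each rollout exactly once
    batch_size=64,                        # placeholder minibatch size
    transform=None,                       # optionally, a Transform applied on sampling
)
```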

_initialise_state()[source]#

Initialise the state of the trainer.

_load_state_dict_from_active_run_checkpoint() dict[source]#

Load the experiment state from a checkpoint for the active run, if available.

Returns:

state_dict (dict) – The state as a dictionary, loaded from W&B.

Raises:

CheckPointNotFoundError – If the checkpoint file is not found.

_load_state_dict_from_base_run_checkpoint(version: str = 'latest') dict[source]#

Load the experiment state from a base run checkpoint.

Parameters:

version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B.

Returns:

state_dict (dict) – The state as a dictionary, loaded from W&B.

Raises:

CheckPointNotFoundError – If the checkpoint file is not found.

_pretrain_agents()[source]#

Pretrain the agent bodies in isolation.

This just uses the SoloAgentTrainer class.

_run_test_loop(iteration_context: IterationContext)[source]#

Run the test loop.

Parameters:

iteration_context (IterationContext) – The context used during testing. This controls the progress bar.

_run_train_loop(iteration_context: IterationContext)[source]#

Run the training loop.

Parameters:

iteration_context (IterationContext) – The context used during training. This controls the progress bar.

_train_and_test()[source]#

Run generic RL training and test loops.

_train_on_replay_buffer() TensorDict[source]#

Train the agents on data in the replay buffer.

Returns:

mean_loss_vals (TensorDict) – The mean loss values over the training iterations.
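
The PPO update over the buffer follows the usual TorchRL pattern, roughly as in the sketch below. The number of epochs and the key names are assumptions; loss_module, replay_buffer and optimizer are the objects constructed by the methods above.

```python
# Illustrative PPO update loop over the replay buffer.
for _ in range(num_epochs):             # assumed hyper-parameter
    for batch in replay_buffer:
        loss_vals = loss_module(batch)  # TensorDict of loss terms
        loss = sum(
            value for key, value in loss_vals.items() if key.startswith("loss_")
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```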

classmethod get_checkpoint_base_dir_from_run_id(run_id: str) Path[source]#

Get the base directory for a checkpoint from a run ID.

Parameters:

run_id (str) – The run ID.

Returns:

checkpoint_base_dir (Path) – The path to the base directory for the checkpoint.
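
Usage sketch (the run ID is a placeholder):

```python
checkpoint_dir = VanillaPpoTrainer.get_checkpoint_base_dir_from_run_id("example-run-id")
```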

get_total_num_iterations() int[source]#

Get the total number of iterations that the trainer will run for.

This is the sum of the number of iterations declared by methods decorated with attach_progress_bar.

Returns:

total_iterations (int) – The total number of iterations.

load_and_set_state_from_checkpoint(from_base_run: bool = False, version: str = 'latest')[source]#

Load and set the experiment state from a checkpoint, if available.

Parameters:
  • from_base_run (bool, default=False) – Whether to load the checkpoint artifact from the base run. By default, the artifact is loaded from the currently active run.

  • version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B. This must be “latest” if from_base_run is False.

Raises:

CheckPointNotFoundError – If the checkpoint file is not found.

save_checkpoint(log: bool = True)[source]#

Save the state of the experiment to a checkpoint.

Parameters:

log (bool, default=True) – Whether to log the checkpointing.
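
A typical resume-and-checkpoint pattern combining these methods is sketched below; the import path for CheckPointNotFoundError is an assumption.

```python
from nip.trainers import CheckPointNotFoundError  # assumed import path

try:
    # Restore the latest checkpoint of the active run, if one exists.
    trainer.load_and_set_state_from_checkpoint()
except CheckPointNotFoundError:
    pass  # no checkpoint yet: start from the freshly initialised state

trainer.save_checkpoint(log=True)  # persist the current experiment state
```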

train()[source]#

Train the agents.