- class nip.trainers.spg.SpgTrainer(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)
Stackelberg Policy Gradient [FCR20] trainer.
Implements an n-player version of Stackelberg Policy Gradient / Opponent-Shaping.
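The central idea of Stackelberg Policy Gradient is that the leader follows the total derivative of its loss through the follower's best response, rather than the plain partial derivative. The sketch below illustrates this for a toy two-player game with scalar parameters, using finite differences and the implicit function theorem. The loss functions and all helper names (`stackelberg_gradient`, `train_leader`, and so on) are illustrative assumptions; this is not the `SpgTrainer` or `SpgLoss` implementation, which operates on policy parameters for n players.

```python
from typing import Callable

Loss = Callable[[float, float], float]
H = 1e-3  # finite-difference step (exact for quadratic losses, up to rounding)


def d_x(f: Loss, x: float, y: float) -> float:
    return (f(x + H, y) - f(x - H, y)) / (2 * H)


def d_y(f: Loss, x: float, y: float) -> float:
    return (f(x, y + H) - f(x, y - H)) / (2 * H)


def d_yy(f: Loss, x: float, y: float) -> float:
    return (f(x, y + H) - 2 * f(x, y) + f(x, y - H)) / H**2


def d_xy(f: Loss, x: float, y: float) -> float:
    return (
        f(x + H, y + H) - f(x + H, y - H) - f(x - H, y + H) + f(x - H, y - H)
    ) / (4 * H**2)


def stackelberg_gradient(leader_loss: Loss, follower_loss: Loss, x: float, y: float) -> float:
    """Total derivative of the leader's loss through the follower's best response."""
    # Implicit function theorem applied to d(follower_loss)/dy = 0:
    # dy*/dx = -(d2 f2 / dy2)^-1 * (d2 f2 / dy dx)
    dy_dx = -d_xy(follower_loss, x, y) / d_yy(follower_loss, x, y)
    return d_x(leader_loss, x, y) + dy_dx * d_y(leader_loss, x, y)


# Toy game (illustrative): the follower minimises (y - x)^2, so its best response is y = x.
def leader_loss(x: float, y: float) -> float:
    return (x - 1) ** 2 + x * y


def follower_loss(x: float, y: float) -> float:
    return (y - x) ** 2


def train_leader(steps: int = 200, lr: float = 0.1) -> float:
    x = 0.0
    for _ in range(steps):
        y = x  # the follower plays its exact best response
        x -= lr * stackelberg_gradient(leader_loss, follower_loss, x, y)
    return x
```

Here `train_leader()` converges to x = 0.5, the minimiser of the leader's loss once the follower's response is substituted in, (x - 1)^2 + x^2. Descending only the partial derivative 2(x - 1) + y with y = x would instead settle at x = 2/3.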
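- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
scenario_instance (ScenarioInstance) – The scenario instance for the experiment.
settings (ExperimentSettings) – The settings of the experiment (for example, the device to use for training).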
Methods Summary
__init__(hyper_params, scenario_instance, ...)
_add_file_to_wandb_artifact(artifact_name, ...) – Add a file to a W&B artifact, creating the artifact if it doesn't exist.
_add_normalization_transforms() – Add observation normalization transforms to the environments.
Get the policy and value operators for the agents.
_build_test_context(stack) – Build the context manager ExitStack for testing.
_build_train_context(stack) – Build the context manager ExitStack for training.
_get_data_collectors() – Construct the data collectors, which generate rollouts from the environment.
_get_log_stats(rollouts[, mean_loss_vals, train]) – Compute the statistics to log during training or testing.
_get_loss_module_and_gae() – Construct the loss module and the generalized advantage estimator.
_get_metadata() – Get metadata relevant to the current experiment, for saving the checkpoint.
_get_optimizer_and_param_freezer(loss_module) – Construct the optimizer for the loss module and the model parameter freezer.
_get_replay_buffer([transform]) – Construct the replay buffer, which will store the rollouts.
Initialise the state of the trainer.
_load_state_dict_from_active_run_checkpoint() – Load the experiment state from a checkpoint for the active run, if available.
_load_state_dict_from_base_run_checkpoint([version]) – Load the experiment state from a base run checkpoint.
_pretrain_agents() – Pretrain the agent bodies in isolation.
_run_test_loop(iteration_context) – Run the test loop.
_run_train_loop(iteration_context) – Run the training loop.
Run generic RL training and test loops.
_train_on_replay_buffer() – Train the agents on data in the replay buffer.
get_checkpoint_base_dir_from_run_id(run_id) – Get the base directory for a checkpoint from a run ID.
get_total_num_iterations() – Get the total number of iterations that the trainer will run for.
load_and_set_state_from_checkpoint([from_base_run, version]) – Load and set the experiment state from a checkpoint, if available.
([log]) – Save the state of the experiment to a checkpoint.
() – Train the agents.
Attributes Summary
The names of the agents in the scenario.
The path to the directory containing the checkpoint.
The path to the checkpoint metadata file.
The path to the parameters file for the checkpoint.
The path to the checkpoint state file.
The maximum number of message rounds in the protocol.
The number of agents in the scenario.
The protocol handler for the experiment.
- __init__(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)
- _add_file_to_wandb_artifact(artifact_name: str, artifact_type: str, file_path: Path)
Add a file to a W&B artifact, creating the artifact if it doesn’t exist.
If the artifact already exists, we add the file to the existing artifact, creating a new version.
- _add_normalization_transforms()
Add observation normalization transforms to the environments.
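This entry does not say which normalization statistics the trainer uses. Purely as an illustration, a common approach is to standardise each observation feature with a running mean and standard deviation computed by Welford's online algorithm. The `RunningNormalizer` class below is a hypothetical scalar sketch, not the transform this method adds.

```python
import math


class RunningNormalizer:
    """Hypothetical scalar observation normaliser using Welford's online algorithm."""

    def __init__(self, eps: float = 1e-8) -> None:
        self.count = 0
        self.mean = 0.0
        self._m2 = 0.0  # running sum of squared deviations from the mean
        self.eps = eps

    def update(self, value: float) -> None:
        # Incrementally update the mean and the sum of squared deviations
        self.count += 1
        delta = value - self.mean
        self.mean += delta / self.count
        self._m2 += delta * (value - self.mean)

    @property
    def std(self) -> float:
        # Population standard deviation; use 1 until there is a spread to measure
        if self.count < 2:
            return 1.0
        return math.sqrt(self._m2 / self.count)

    def normalize(self, value: float) -> float:
        return (value - self.mean) / (self.std + self.eps)


norm = RunningNormalizer()
for obs in [1.0, 2.0, 3.0, 4.0, 5.0]:
    norm.update(obs)
```

After these five updates the mean is 3 and the standard deviation is the square root of 2, so `norm.normalize(3.0)` is 0.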
- _build_test_context(stack: ExitStack) -> list[ContextManager]
Build the context manager ExitStack for testing.
Takes as input an ExitStack and adds the appropriate context managers to it, then returns the context managers.
- Parameters:
stack (ExitStack) – The ExitStack to add the context managers to. Note that this is modified in-place.
- Returns:
context_managers (list[ContextManager]) – The target context managers to be used in the testing loop.
- _build_train_context(stack: ExitStack) -> list[ContextManager]
Build the context manager ExitStack for training.
Takes as input an ExitStack and adds the appropriate context managers to it, then returns the context managers.
- Parameters:
stack (ExitStack) – The ExitStack to add the context managers to. Note that this is modified in-place.
- Returns:
context_managers (list[ContextManager]) – The target context managers to be used in the training loop.
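Both context builders follow the standard `contextlib.ExitStack` pattern: enter each context manager on the caller's stack, then return them. A minimal sketch of that pattern, using standard-library calls only, is shown below; `tracked` and `build_test_context` are hypothetical stand-ins (for example for an evaluation mode or a no-grad context), not the trainer's actual context managers.

```python
from contextlib import AbstractContextManager, ExitStack, contextmanager
from typing import Iterator

events: list[str] = []


@contextmanager
def tracked(name: str) -> Iterator[str]:
    """Hypothetical stand-in for a real context, recording when it is entered and exited."""
    events.append(f"enter {name}")
    try:
        yield name
    finally:
        events.append(f"exit {name}")


def build_test_context(stack: ExitStack) -> list[AbstractContextManager]:
    # Create the context managers, then enter each one on the caller's stack
    # (the stack is modified in place, as the docstring describes).
    context_managers = [tracked("eval_mode"), tracked("no_grad")]
    for context_manager in context_managers:
        stack.enter_context(context_manager)
    return context_managers


with ExitStack() as stack:
    active = build_test_context(stack)
    events.append("test loop body")
```

Because the caller's `with` block owns the stack, all contexts are torn down together when the loop ends, in reverse order of entry.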
- _get_data_collectors() -> tuple[SyncDataCollector, SyncDataCollector]
Construct the data collectors, which generate rollouts from the environment.
Constructs a collector for both the train and the test environment.
- Returns:
train_collector (SyncDataCollector) – The train data collector.
test_collector (SyncDataCollector) – The test data collector.
- _get_log_stats(rollouts: TensorDictBase, mean_loss_vals: TensorDictBase | None = None, *, train=True) -> dict[str, float]
Compute the statistics to log during training or testing.
- Parameters:
rollouts (TensorDict) – The data sampled from the data collector.
mean_loss_vals (TensorDict, optional) – The average loss values.
train (bool, default=True) – Whether the statistics are for training or testing.
- Returns:
log_stats (dict[str, float]) – The statistics to log.
- _get_loss_module_and_gae() -> tuple[SpgLoss, GAE]
Construct the loss module and the generalized advantage estimator.
- Returns:
loss_module (SpgLoss) – The loss module.
gae (GAE) – The generalized advantage estimator.
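For reference, generalized advantage estimation computes advantages with a backward recursion over TD errors. The pure-Python sketch below shows the standard recursion for one trajectory. The function name, its arguments, and the convention that `dones[t]` marks the transition out of step t as terminal are assumptions; it does not reproduce the `GAE` object returned here.

```python
def compute_gae(
    rewards: list[float],
    values: list[float],
    next_values: list[float],
    dones: list[bool],
    gamma: float = 0.99,
    lmbda: float = 0.95,
) -> tuple[list[float], list[float]]:
    """Generalized advantage estimation over one trajectory (illustrative sketch)."""
    advantages = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each step can reuse the advantage of the step after it.
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0  # stop bootstrapping at episode ends
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]  # TD error
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    # Value targets: the advantages added back onto the value estimates.
    returns = [adv + v for adv, v in zip(advantages, values)]
    return advantages, returns


advantages, returns = compute_gae(
    rewards=[1.0, 1.0, 1.0],
    values=[0.0, 0.0, 0.0],
    next_values=[0.0, 0.0, 0.0],
    dones=[False, True, False],
    gamma=0.5,
    lmbda=1.0,
)
```

The episode boundary after step 1 stops step 0's advantage from seeing step 2's reward, giving advantages of 1.5, 1.0 and 1.0.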
- _get_metadata() -> dict
Get metadata relevant to the current experiment, for saving the checkpoint.
- Returns:
metadata (dict) – The metadata to record.
- _get_optimizer_and_param_freezer(loss_module: Objective) -> tuple[Adam, ParamGroupFreezer]
Construct the optimizer for the loss module and the model parameter freezer.
- Parameters:
loss_module (Objective) – The loss module.
- Returns:
optimizer (torch.optim.Optimizer) – The optimizer.
param_group_freezer (ParamGroupFreezer) – The parameter dictionaries for each agent.
- _get_replay_buffer(transform: Transform | None = None) -> ReplayBuffer
Construct the replay buffer, which will store the rollouts.
- Parameters:
transform (Transform, optional) – The transform to apply to the data before storing it in the replay buffer.
- Returns:
ReplayBuffer – The replay buffer.
- _load_state_dict_from_active_run_checkpoint() -> dict
Load the experiment state from a checkpoint for the active run, if available.
- Returns:
state_dict (dict) – The state as a dictionary, loaded from W&B.
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
- _load_state_dict_from_base_run_checkpoint(version: str = 'latest') -> dict
Load the experiment state from a base run checkpoint.
- Parameters:
version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B.
- Returns:
state_dict (dict) – The state as a dictionary, loaded from W&B.
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
- _pretrain_agents()
Pretrain the agent bodies in isolation.
This just uses the SoloAgentTrainer class.
- _run_test_loop(iteration_context: IterationContext)
Run the test loop.
- Parameters:
iteration_context (IterationContext) – The context used during testing. This controls the progress bar.
- _run_train_loop(iteration_context: IterationContext)
Run the training loop.
- Parameters:
iteration_context (IterationContext) – The context used during training. This controls the progress bar.
- _train_on_replay_buffer() -> TensorDict
Train the agents on data in the replay buffer.
- Returns:
mean_loss_vals (TensorDict) – The mean loss values over the training iterations.
- classmethod get_checkpoint_base_dir_from_run_id(run_id: str) -> Path
Get the base directory for a checkpoint from a run ID.
- Parameters:
run_id (str) – The run ID.
- Returns:
checkpoint_base_dir (Path) – The path to the base directory for the checkpoint.
- get_total_num_iterations() -> int
Get the total number of iterations that the trainer will run for.
This is the sum of the numbers of iterations declared by the trainer's decorated methods.
- Returns:
total_iterations (int) – The total number of iterations.
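The docstring describes the total as a sum over iteration counts declared by decorated methods; the decorator's name is not given on this page. A hypothetical sketch of that pattern follows: a decorator tags each loop method with the attribute holding its iteration count, and the total walks the class for tagged methods. The decorator `declares_iterations`, the class `ToyTrainer` and the attribute names are invented for illustration and are not the library's API.

```python
from typing import Any, Callable


def declares_iterations(attr_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical decorator: tag a loop method with the attribute holding its length."""

    def decorator(method: Callable[..., Any]) -> Callable[..., Any]:
        method.num_iterations_attr = attr_name  # type: ignore[attr-defined]
        return method

    return decorator


class ToyTrainer:
    def __init__(self, num_train_iterations: int, num_test_iterations: int) -> None:
        self.num_train_iterations = num_train_iterations
        self.num_test_iterations = num_test_iterations

    @declares_iterations("num_train_iterations")
    def run_train_loop(self) -> None:
        for _ in range(self.num_train_iterations):
            pass  # one training iteration would go here

    @declares_iterations("num_test_iterations")
    def run_test_loop(self) -> None:
        for _ in range(self.num_test_iterations):
            pass  # one test iteration would go here

    def get_total_num_iterations(self) -> int:
        # Sum the declared iteration counts of every tagged method on the class
        total = 0
        for name in dir(type(self)):
            attr_name = getattr(getattr(type(self), name), "num_iterations_attr", None)
            if attr_name is not None:
                total += getattr(self, attr_name)
        return total
```

With this design, a new loop method contributes to the progress-bar total simply by being decorated, with no central list to keep in sync.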
- load_and_set_state_from_checkpoint(from_base_run: bool = False, version: str = 'latest')
Load and set the experiment state from a checkpoint, if available.
- Parameters:
from_base_run (bool, default=False) – Whether to load the checkpoint artifact from the base run. By default, the artifact is loaded from the current active run.
version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B. This must be “latest” when loading from the current active run.
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
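Only the base-run loader documented above accepts a `version`, whereas the active-run loader takes none. The hypothetical sketch below illustrates how `from_base_run` and `version` could select between the two loaders. The function `load_state_dict`, its loader arguments, the local `CheckPointNotFoundError` definition, and the choice to raise `ValueError` for an invalid combination are all assumptions, not the trainer's code.

```python
from typing import Callable


class CheckPointNotFoundError(Exception):
    """Local stand-in for the exception named in the docs (hypothetical definition)."""


def load_state_dict(
    from_base_run: bool,
    version: str,
    load_active: Callable[[], dict],
    load_base: Callable[[str], dict],
) -> dict:
    """Hypothetical dispatch between the active-run and base-run checkpoint loaders."""
    # Only the base-run loader takes a version, so the active run must use "latest".
    if not from_base_run and version != "latest":
        raise ValueError('version must be "latest" when loading from the active run')
    if from_base_run:
        return load_base(version)
    # Any CheckPointNotFoundError raised by a loader propagates to the caller.
    return load_active()
```

For example, `load_state_dict(True, "v3", ...)` forwards `"v3"` to the base-run loader, while the default call loads the latest checkpoint of the active run.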