nip.trainers.spg.SpgTrainer#
- class nip.trainers.spg.SpgTrainer(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#
Stackelberg Policy Gradient [FCR20] trainer.
Implements an n-player version of Stackelberg Policy Gradient / opponent shaping.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
scenario_instance (ScenarioInstance) – The instantiated components of the experiment scenario.
settings (ExperimentSettings) – The instance-specific settings of the experiment, such as the device to use for training.
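A minimal usage sketch, assuming the three constructor arguments have already been built by the rest of the nip pipeline (the factory functions that produce them are not shown here):

```python
from nip.trainers.spg import SpgTrainer

def run_spg_experiment(hyper_params, scenario_instance, settings) -> SpgTrainer:
    """Construct the trainer and run the full training procedure."""
    trainer = SpgTrainer(hyper_params, scenario_instance, settings)
    trainer.train()  # the public entry point documented below
    return trainer
```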
Methods Summary
__init__(hyper_params, scenario_instance, ...)
_add_file_to_wandb_artifact(artifact_name, ...) – Add a file to a W&B artifact, creating the artifact if it doesn't exist.
_add_normalization_transforms() – Add observation normalization transforms to the environments.
Get the policy and value operators for the agents.
_build_test_context(stack) – Build the context manager ExitStack for testing.
_build_train_context(stack) – Build the context manager ExitStack for training.
_get_data_collectors() – Construct the data collectors, which generate rollouts from the environment.
_get_log_stats(rollouts[, mean_loss_vals, train]) – Compute the statistics to log during training or testing.
_get_loss_module_and_gae() – Construct the loss module and the generalized advantage estimator.
_get_metadata() – Get metadata relevant to the current experiment, for saving the checkpoint.
_get_optimizer_and_param_freezer(loss_module) – Construct the optimizer for the loss module and the model parameter freezer.
_get_replay_buffer([transform]) – Construct the replay buffer, which will store the rollouts.
_initialise_state() – Initialise the state of the trainer.
_load_state_dict_from_active_run_checkpoint() – Load the experiment state from a checkpoint for the active run, if available.
_load_state_dict_from_base_run_checkpoint([version]) – Load the experiment state from a base run checkpoint.
_pretrain_agents() – Pretrain the agent bodies in isolation.
_run_test_loop(iteration_context) – Run the test loop.
_run_train_loop(iteration_context) – Run the training loop.
Run generic RL training and test loops.
_train_on_replay_buffer() – Train the agents on data in the replay buffer.
get_checkpoint_base_dir_from_run_id(run_id) – Get the base directory for a checkpoint from a run ID.
get_total_num_iterations() – Get the total number of iterations that the trainer will run for.
load_and_set_state_from_checkpoint([from_base_run, version]) – Load and set the experiment state from a checkpoint, if available.
save_checkpoint([log]) – Save the state of the experiment to a checkpoint.
train() – Train the agents.
Attributes
PRETRAINED_MODEL_CACHE_PARAM_KEYS
agent_names – The names of the agents in the scenario.
checkpoint_base_dir – The path to the directory containing the checkpoint.
checkpoint_metadata_path – The path to the checkpoint metadata file.
checkpoint_params_path – The path to the parameters file for the checkpoint.
checkpoint_state_path – The path to the checkpoint state file.
max_message_rounds – The maximum number of message rounds in the protocol.
num_agents – The number of agents in the scenario.
protocol_handler – The protocol handler for the experiment.
policy_operator
value_operator
Methods
- __init__(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#
- _add_file_to_wandb_artifact(artifact_name: str, artifact_type: str, file_path: Path)[source]#
Add a file to a W&B artifact, creating the artifact if it doesn’t exist.
If the artifact already exists, we add the file to the existing artifact, creating a new version.
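For reference, a minimal sketch of this versioning behaviour using the public wandb API (an illustration, not the trainer's actual implementation; W&B versions artifacts automatically when one is logged under an existing name):

```python
import wandb

def log_file_as_artifact(run: "wandb.sdk.wandb_run.Run", artifact_name: str,
                         artifact_type: str, file_path: str) -> None:
    # Logging under an existing artifact name does not overwrite it:
    # W&B records the result as a new version (v0, v1, ...).
    artifact = wandb.Artifact(artifact_name, type=artifact_type)
    artifact.add_file(file_path)
    run.log_artifact(artifact)
```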
- _add_normalization_transforms()[source]#
Add observation normalization transforms to the environments.
- _build_test_context(stack: ExitStack) → list[ContextManager] [source]#
Build the context manager ExitStack for testing.
Takes as input an ExitStack and adds the appropriate context managers to it, then returns the context managers.
- Parameters:
stack (ExitStack) – The ExitStack to add the context managers to. Note that this is modified in-place.
- Returns:
context_managers (list[ContextManager]) – The target context managers to be used in the testing loop.
- _build_train_context(stack: ExitStack) → list[ContextManager] [source]#
Build the context manager ExitStack for training.
Takes as input an ExitStack and adds the appropriate context managers to it, then returns the context managers.
- Parameters:
stack (ExitStack) – The ExitStack to add the context managers to. Note that this is modified in-place.
- Returns:
context_managers (list[ContextManager]) – The target context managers to be used in the training loop.
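Both context-building methods follow the same contextlib.ExitStack pattern. A self-contained sketch, with nullcontext standing in for whatever context managers the trainer actually uses:

```python
from contextlib import ExitStack, nullcontext

def build_context(stack: ExitStack) -> list:
    # Enter each manager on the caller's stack (modifying it in place) and
    # return the entered targets; everything is closed together when the
    # caller's `with` block exits.
    context_managers = [nullcontext("cm_a"), nullcontext("cm_b")]
    return [stack.enter_context(cm) for cm in context_managers]

with ExitStack() as stack:
    targets = build_context(stack)
    # ... run the training or test loop using `targets` ...
```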
- _get_data_collectors() → tuple[SyncDataCollector, SyncDataCollector] [source]#
Construct the data collectors, which generate rollouts from the environment.
Constructs a collector for both the train and the test environment.
- Returns:
train_collector (SyncDataCollector) – The train data collector.
test_collector (SyncDataCollector) – The test data collector.
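A sketch of how such a pair of collectors is commonly built with TorchRL; the trainer's real environments, policy, and frame counts come from the scenario and hyper-parameters, so the arguments here are placeholders:

```python
from torchrl.collectors import SyncDataCollector

def make_collectors(train_env, test_env, policy, frames_per_batch: int,
                    total_frames: int):
    train_collector = SyncDataCollector(
        train_env, policy,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
    )
    # The test collector runs the same policy on the held-out environment;
    # total_frames=-1 makes it collect until it is explicitly shut down.
    test_collector = SyncDataCollector(
        test_env, policy,
        frames_per_batch=frames_per_batch,
        total_frames=-1,
    )
    return train_collector, test_collector
```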
- _get_log_stats(rollouts: TensorDictBase, mean_loss_vals: TensorDictBase | None = None, *, train=True) → dict[str, float] [source]#
Compute the statistics to log during training or testing.
- Parameters:
rollouts (TensorDict) – The data sampled from the data collector.
mean_loss_vals (TensorDict, optional) – The average loss values.
train (bool, default=True) – Whether the statistics are for training or testing.
- Returns:
log_stats (dict[str, float]) – The statistics to log.
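As an illustration, computing a couple of such scalars from a rollout TensorDict, assuming TorchRL's standard ("next", "reward") key layout (the real method logs many more statistics):

```python
from tensordict import TensorDictBase

def basic_log_stats(rollouts: TensorDictBase, *, train: bool = True) -> dict[str, float]:
    prefix = "train" if train else "test"
    return {
        f"{prefix}/mean_reward": rollouts.get(("next", "reward")).mean().item(),
        f"{prefix}/done_rate": rollouts.get(("next", "done")).float().mean().item(),
    }
```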
- _get_loss_module_and_gae() → tuple[SpgLoss, GAE] [source]#
Construct the loss module and the generalized advantage estimator.
- Returns:
loss_module (SpgLoss) – The loss module.
gae (GAE) – The generalized advantage estimator.
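SpgLoss is this repository's own loss module, but the advantage estimator is TorchRL's GAE. A sketch of its construction (the discount and lambda values here are illustrative, not the experiment's defaults):

```python
from torchrl.objectives.value import GAE

def make_gae(value_operator, gamma: float = 0.99, lmbda: float = 0.95) -> GAE:
    # GAE bootstraps advantage estimates from the value network's predictions.
    return GAE(gamma=gamma, lmbda=lmbda, value_network=value_operator)
```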
- _get_metadata() → dict [source]#
Get metadata relevant to the current experiment, for saving the checkpoint.
- Returns:
metadata (dict) – The metadata to record.
- _get_optimizer_and_param_freezer(loss_module: Objective) → tuple[Adam, ParamGroupFreezer] [source]#
Construct the optimizer for the loss module and the model parameter freezer.
- Parameters:
loss_module (Objective) – The loss module.
- Returns:
optimizer (torch.optim.Optimizer) – The optimizer.
param_group_freezer (ParamGroupFreezer) – The parameter dictionaries for each agent.
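A sketch of the underlying idea: the optimizer holds one parameter group per agent, and a freezer toggles requires_grad so that only the agents currently being updated receive gradients. SimpleParamGroupFreezer below is a stand-in, not the repo's ParamGroupFreezer:

```python
import torch

class SimpleParamGroupFreezer:
    """Stand-in freezer: enable gradients for one agent's parameters only."""

    def __init__(self, param_groups: dict[str, list[torch.nn.Parameter]]):
        self.param_groups = param_groups

    def freeze_all_except(self, agent_name: str) -> None:
        for name, params in self.param_groups.items():
            for param in params:
                param.requires_grad_(name == agent_name)

def make_optimizer(param_groups: dict[str, list[torch.nn.Parameter]],
                   lr: float = 3e-4) -> torch.optim.Adam:
    # One optimizer parameter group per agent, so per-agent settings can differ.
    groups = [{"params": params, "name": name} for name, params in param_groups.items()]
    return torch.optim.Adam(groups, lr=lr)
```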
- _get_replay_buffer(transform: Transform | None = None) → ReplayBuffer [source]#
Construct the replay buffer, which will store the rollouts.
- Parameters:
transform (Transform, optional) – The transform to apply to the data before storing it in the replay buffer.
- Returns:
ReplayBuffer – The replay buffer.
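A sketch of a typical TorchRL buffer construction under these assumptions (LazyTensorStorage and the capacity are illustrative choices):

```python
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.transforms import Transform

def make_replay_buffer(capacity: int, transform: Transform | None = None) -> ReplayBuffer:
    # The optional transform is attached to the buffer so that stored
    # rollouts pass through it.
    return ReplayBuffer(storage=LazyTensorStorage(capacity), transform=transform)
```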
- _load_state_dict_from_active_run_checkpoint() → dict [source]#
Load the experiment state from a checkpoint for the active run, if available.
- Returns:
state_dict (dict) – The state as a dictionary, loaded from W&B.
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
- _load_state_dict_from_base_run_checkpoint(version: str = 'latest') → dict [source]#
Load the experiment state from a base run checkpoint.
- Parameters:
version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B.
- Returns:
state_dict (dict) – The state as a dictionary, loaded from W&B.
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
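For reference, fetching a specific checkpoint artifact version with the public wandb API looks roughly like this (the entity/project/name path format here is an assumption for illustration):

```python
import wandb

def download_checkpoint_artifact(entity: str, project: str, artifact_name: str,
                                 version: str = "latest") -> str:
    api = wandb.Api()
    artifact = api.artifact(f"{entity}/{project}/{artifact_name}:{version}")
    # Returns the local directory into which the artifact files were downloaded.
    return artifact.download()
```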
- _pretrain_agents()[source]#
Pretrain the agent bodies in isolation.
This just uses the SoloAgentTrainer class.
- _run_test_loop(iteration_context: IterationContext)[source]#
Run the test loop.
- Parameters:
iteration_context (IterationContext) – The context used during testing. This controls the progress bar.
- _run_train_loop(iteration_context: IterationContext)[source]#
Run the training loop.
- Parameters:
iteration_context (IterationContext) – The context used during training. This controls the progress bar.
- _train_on_replay_buffer() → TensorDict [source]#
Train the agents on data in the replay buffer.
- Returns:
mean_loss_vals (TensorDict) – The mean loss values over the training iterations.
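An illustrative version of such an update loop (num_epochs and batch_size stand in for the trainer's hyper-parameters, and the "loss_" key prefix follows the TorchRL loss-module convention):

```python
import torch
from tensordict import TensorDict

def train_on_buffer(replay_buffer, loss_module, optimizer,
                    num_epochs: int, batch_size: int) -> TensorDict:
    history = []
    for _ in range(num_epochs):
        batch = replay_buffer.sample(batch_size)
        loss_vals = loss_module(batch)  # TensorDict of individual loss terms
        total_loss = sum(val for key, val in loss_vals.items()
                         if key.startswith("loss_"))
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        history.append(loss_vals.detach())
    # Average each loss term over the updates into a single TensorDict.
    return torch.stack(history).apply(lambda t: t.mean(0))
```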
- classmethod get_checkpoint_base_dir_from_run_id(run_id: str) → Path [source]#
Get the base directory for a checkpoint from a run ID.
- Parameters:
run_id (str) – The run ID.
- Returns:
checkpoint_base_dir (Path) – The path to the base directory for the checkpoint.
- get_total_num_iterations() → int [source]#
Get the total number of iterations that the trainer will run for.
This is the sum of the numbers of iterations declared by methods decorated with attach_progress_bar.
- Returns:
total_iterations (int) – The total number of iterations.
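A self-contained sketch of this pattern (the real attach_progress_bar decorator lives in the nip codebase and may differ in detail):

```python
def attach_progress_bar(get_num_iterations):
    """Illustrative decorator: tag a loop method with its iteration count."""
    def decorator(method):
        method.get_num_iterations = get_num_iterations
        return method
    return decorator

class TrainerSketch:
    num_train_iterations = 1000
    num_test_iterations = 10

    @attach_progress_bar(lambda self: self.num_train_iterations)
    def _run_train_loop(self):
        ...

    @attach_progress_bar(lambda self: self.num_test_iterations)
    def _run_test_loop(self):
        ...

    def get_total_num_iterations(self) -> int:
        # Sum the declared counts over all decorated loop methods.
        loops = [self._run_train_loop, self._run_test_loop]
        return sum(loop.get_num_iterations(self) for loop in loops)

assert TrainerSketch().get_total_num_iterations() == 1010
```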
- load_and_set_state_from_checkpoint(from_base_run: bool = False, version: str = 'latest')[source]#
Load and set the experiment state from a checkpoint, if available.
- Parameters:
from_base_run (bool, default=False) – Whether to load the checkpoint artifact from the base run. By default, the artifact is loaded from the current active run.
version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B. This must be “latest” if from_base_run is False.
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
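Typical resume-or-start usage might look as follows (the import location of CheckPointNotFoundError is an assumption):

```python
from nip.trainers import CheckPointNotFoundError  # assumed import path

def resume_or_start(trainer) -> None:
    try:
        # Pin a specific base-run checkpoint version; use the default
        # version="latest" when loading from the active run.
        trainer.load_and_set_state_from_checkpoint(from_base_run=True, version="v2")
    except CheckPointNotFoundError:
        pass  # no checkpoint found: start training from scratch
    trainer.train()
```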