nip.trainers.rl_tensordict_base.TensorDictRlTrainer#
- class nip.trainers.rl_tensordict_base.TensorDictRlTrainer(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#
- Base class for all reinforcement learning trainers which use tensordicts.
- This class implements a standard RL training loop using TorchRL. To subclass it, implement the _get_loss_module_and_gae method, which should return the loss module and, optionally, the generalized advantage estimator.
- Parameters:
- hyper_params (HyperParameters) – The parameters of the experiment. 
- scenario_instance (ScenarioInstance) – The components of the experiment. 
- settings (ExperimentSettings) – The settings of the experiment. 
 
- Methods Summary
- __init__(hyper_params, scenario_instance, ...)
- _add_file_to_wandb_artifact(artifact_name, ...) – Add a file to a W&B artifact, creating the artifact if it doesn't exist.
- _add_normalization_transforms() – Add observation normalization transforms to the environments.
- Get the policy and value operators for the agents.
- _build_test_context(stack) – Build the context manager ExitStack for testing.
- _build_train_context(stack) – Build the context manager ExitStack for training.
- _get_data_collectors() – Construct the data collectors, which generate rollouts from the environment.
- _get_log_stats(rollouts[, mean_loss_vals, train]) – Compute the statistics to log during training or testing.
- _get_loss_module_and_gae() – Construct the loss module and the generalized advantage estimator.
- _get_metadata() – Get metadata relevant to the current experiment, for saving the checkpoint.
- _get_optimizer_and_param_freezer(loss_module) – Construct the optimizer for the loss module and the model parameter freezer.
- _get_replay_buffer([transform]) – Construct the replay buffer, which will store the rollouts.
- Initialise the state of the trainer.
- _load_state_dict_from_active_run_checkpoint() – Load the experiment state from a checkpoint for the active run, if available.
- _load_state_dict_from_base_run_checkpoint([version]) – Load the experiment state from a base run checkpoint.
- _pretrain_agents() – Pretrain the agent bodies in isolation.
- _run_test_loop(iteration_context) – Run the test loop.
- _run_train_loop(iteration_context) – Run the training loop.
- Run generic RL training and test loops.
- _train_on_replay_buffer() – Train the agents on data in the replay buffer.
- get_checkpoint_base_dir_from_run_id(run_id) – Get the base directory for a checkpoint from a run ID.
- get_total_num_iterations() – Get the total number of iterations that the trainer will run for.
- load_and_set_state_from_checkpoint([from_base_run, version]) – Load and set the experiment state from a checkpoint, if available.
- save_checkpoint([log]) – Save the state of the experiment to a checkpoint.
- train() – Train the agents.
- Attributes
- PRETRAINED_MODEL_CACHE_PARAM_KEYS
- agent_names – The names of the agents in the scenario.
- checkpoint_base_dir – The path to the directory containing the checkpoint.
- checkpoint_metadata_path – The path to the checkpoint metadata file.
- checkpoint_params_path – The path to the parameters file for the checkpoint.
- checkpoint_state_path – The path to the checkpoint state file.
- max_message_rounds – The maximum number of message rounds in the protocol.
- num_agents – The number of agents in the scenario.
- protocol_handler – The protocol handler for the experiment.
- trainer_type – The type of trainer this is.
- policy_operator
- value_operator
- Methods
- __init__(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#
 - _add_file_to_wandb_artifact(artifact_name: str, artifact_type: str, file_path: Path)[source]#
- Add a file to a W&B artifact, creating the artifact if it doesn't exist.
- If the artifact already exists, we add the file to the existing artifact, creating a new version.
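- For reference, the underlying W&B calls follow this pattern. This is a minimal sketch using the public wandb API; the project name, artifact name, and file path are placeholders, not the trainer's actual values:

```python
import wandb

# Placeholder project, artifact name and file path; the trainer
# derives the real ones from the experiment settings.
run = wandb.init(project="my-project")

artifact = wandb.Artifact("checkpoint", type="model")
artifact.add_file("checkpoints/state.pt")
# Logging the artifact creates it, or adds a new version if an
# artifact with this name already exists.
run.log_artifact(artifact)
```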
 - _add_normalization_transforms()[source]#
- Add observation normalization transforms to the environments. 
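- In TorchRL, observation normalization is typically an ObservationNorm transform whose statistics are initialised from rollouts. A minimal sketch, with a placeholder Gym environment standing in for the trainer's environments:

```python
from torchrl.envs import GymEnv, ObservationNorm, TransformedEnv

# Placeholder environment; the trainer applies the transform to its
# own train and test environments.
env = TransformedEnv(
    GymEnv("CartPole-v1"),
    ObservationNorm(in_keys=["observation"]),
)
# Estimate the normalization statistics (loc and scale) from 1000
# random environment steps.
env.transform.init_stats(num_iter=1000)
```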
- _build_test_context(stack: ExitStack) → list[ContextManager][source]#
- Build the context manager ExitStack for testing.
- Takes as input an ExitStack and adds the appropriate context managers to it, then returns the context managers.
- Parameters:
- stack (ExitStack) – The ExitStack to add the context managers to. Note that this is modified in-place. 
- Returns:
- context_managers (list[ContextManager]) – The target context managers to be used in the testing loop. 
 
- _build_train_context(stack: ExitStack) → list[ContextManager][source]#
- Build the context manager ExitStack for training.
- Takes as input an ExitStack and adds the appropriate context managers to it, then returns the context managers.
- Parameters:
- stack (ExitStack) – The ExitStack to add the context managers to. Note that this is modified in-place. 
- Returns:
- context_managers (list[ContextManager]) – The target context managers to be used in the training loop. 
 
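- The pattern these two methods follow is standard contextlib usage. A minimal sketch, where the context managers themselves are stand-ins for the trainer's real ones:

```python
from contextlib import ExitStack, nullcontext
from typing import ContextManager

def build_context(stack: ExitStack) -> list[ContextManager]:
    # Stand-ins for the real context managers (e.g. switching
    # networks to eval mode or disabling gradient tracking).
    context_managers: list[ContextManager] = [nullcontext(), nullcontext()]
    for cm in context_managers:
        # Entering on the stack means they are all exited together
        # when the stack unwinds.
        stack.enter_context(cm)
    return context_managers

with ExitStack() as stack:
    targets = build_context(stack)
    # ... run the training or testing loop using the targets ...
```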
- _get_data_collectors() → tuple[SyncDataCollector, SyncDataCollector][source]#
- Construct the data collectors, which generate rollouts from the environment.
- Constructs a collector for both the train and the test environment.
- Returns:
- train_collector (SyncDataCollector) – The train data collector. 
- test_collector (SyncDataCollector) – The test data collector. 
 
 
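- As a rough illustration of what such a collector looks like in TorchRL (the environment and batch sizes here are placeholders, not the trainer's actual objects):

```python
from torchrl.collectors import SyncDataCollector
from torchrl.envs import GymEnv

env = GymEnv("CartPole-v1")  # placeholder environment
# With policy=None, TorchRL rolls out a random policy drawn from
# the environment's action spec.
collector = SyncDataCollector(
    env, policy=None, frames_per_batch=256, total_frames=1_024
)
for rollout in collector:
    # Each rollout is a TensorDict of 256 transitions: observations,
    # actions, rewards and done flags.
    ...
```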
- _get_log_stats(rollouts: TensorDictBase, mean_loss_vals: TensorDictBase | None = None, *, train=True) → dict[str, float][source]#
- Compute the statistics to log during training or testing.
- Parameters:
- rollouts (TensorDict) – The data sampled from the data collector. 
- mean_loss_vals (TensorDict, optional) – The average loss values. 
- train (bool, default=True) – Whether the statistics are for training or testing. 
 
- Returns:
- log_stats (dict[str, float]) – The statistics to log. 
 
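- For example, a typical entry in such a dictionary is a mean reward computed from the rollout TensorDict. The key layout below follows TorchRL's conventions and is an assumption, not this trainer's documented schema:

```python
import torch
from tensordict import TensorDict

# Dummy rollout batch standing in for collector output.
rollouts = TensorDict(
    {"next": {"reward": torch.randn(256, 1)}}, batch_size=[256]
)

# TorchRL stores per-transition rewards under ("next", "reward").
log_stats = {
    "train/mean_reward": rollouts.get(("next", "reward")).mean().item()
}
```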
- abstract _get_loss_module_and_gae() → tuple[Objective, GAE | None][source]#
- Construct the loss module and the generalized advantage estimator.
- Returns:
- loss_module (Objective) – The loss module. 
- gae (GAE | None) – The generalized advantage estimator, or None if the loss module doesn’t use one. 
 
 
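- As a hedged sketch of what an implementation might return, here is roughly how a PPO-style subclass could build the pair with TorchRL. The subclass name and the hyper-parameter values are illustrative; the policy_operator and value_operator attributes are the ones listed above:

```python
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

class MyPpoTrainer(TensorDictRlTrainer):
    def _get_loss_module_and_gae(self):
        # Build the loss from the policy and value operators exposed
        # as attributes above. Gamma and lambda values here are
        # arbitrary placeholders.
        loss_module = ClipPPOLoss(
            actor_network=self.policy_operator,
            critic_network=self.value_operator,
        )
        gae = GAE(
            gamma=0.99,
            lmbda=0.95,
            value_network=self.value_operator,
        )
        return loss_module, gae
```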
- _get_metadata() → dict[source]#
- Get metadata relevant to the current experiment, for saving the checkpoint.
- Returns:
- metadata (dict) – The metadata to record.
 
- _get_optimizer_and_param_freezer(loss_module: Objective) → tuple[Adam, ParamGroupFreezer][source]#
- Construct the optimizer for the loss module and the model parameter freezer.
- Parameters:
- loss_module (Objective) – The loss module. 
- Returns:
- optimizer (torch.optim.Optimizer) – The optimizer. 
- param_group_freezer (ParamGroupFreezer) – The parameter freezer, which holds the parameter dictionaries for each agent. 
 
 
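- The optimizer side of this is plain PyTorch. A minimal version, ignoring the project-specific ParamGroupFreezer and with the learning rate as an arbitrary placeholder:

```python
from torch.optim import Adam

def get_optimizer(loss_module) -> Adam:
    # Optimize all trainable parameters registered on the loss
    # module, which include those of the networks it wraps.
    return Adam(loss_module.parameters(), lr=3e-4)
```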
- _get_replay_buffer(transform: Transform | None = None) → ReplayBuffer[source]#
- Construct the replay buffer, which will store the rollouts.
- Parameters:
- transform (Transform, optional) – The transform to apply to the data before storing it in the replay buffer. 
- Returns:
- ReplayBuffer – The replay buffer. 
 
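- In TorchRL terms, a buffer with an optional transform can be assembled like this (the storage capacity is an arbitrary placeholder):

```python
from torchrl.data import LazyTensorStorage, ReplayBuffer

def get_replay_buffer(transform=None) -> ReplayBuffer:
    # The optional transform is applied to data passing through
    # the buffer.
    return ReplayBuffer(
        storage=LazyTensorStorage(max_size=10_000),
        transform=transform,
    )
```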
- _load_state_dict_from_active_run_checkpoint() → dict[source]#
- Load the experiment state from a checkpoint for the active run, if available.
- Returns:
- state_dict (dict) – The state as a dictionary, loaded from W&B.
- Raises:
- CheckPointNotFoundError – If the checkpoint file is not found. 
 
- _load_state_dict_from_base_run_checkpoint(version: str = 'latest') → dict[source]#
- Load the experiment state from a base run checkpoint.
- Parameters:
- version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B. 
- Returns:
- state_dict (dict) – The state as a dictionary, loaded from W&B.
- Raises:
- CheckPointNotFoundError – If the checkpoint file is not found. 
 
 - _pretrain_agents()[source]#
- Pretrain the agent bodies in isolation.
- This just uses the SoloAgentTrainer class.
 - _run_test_loop(iteration_context: IterationContext)[source]#
- Run the test loop.
- Parameters:
- iteration_context (IterationContext) – The context used during testing. This controls the progress bar. 
 
 - _run_train_loop(iteration_context: IterationContext)[source]#
- Run the training loop.
- Parameters:
- iteration_context (IterationContext) – The context used during training. This controls the progress bar. 
 
- _train_on_replay_buffer() → TensorDict[source]#
- Train the agents on data in the replay buffer.
- Returns:
- mean_loss_vals (TensorDict) – The mean loss values over the training iterations. 
 
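- Schematically, the inner update this performs resembles the standard TorchRL pattern below. The batch size, number of update iterations, and loss-aggregation rule are illustrative assumptions, not the trainer's documented behaviour:

```python
import torch

def train_on_replay_buffer(replay_buffer, loss_module, optimizer, num_batches=8):
    loss_tds = []
    for _ in range(num_batches):
        batch = replay_buffer.sample(batch_size=64)
        # TorchRL loss modules return a TensorDict of named loss terms.
        loss_td = loss_module(batch)
        # Total loss: sum of all "loss_*" entries (e.g. loss_objective,
        # loss_critic, loss_entropy for PPO).
        loss = sum(v for k, v in loss_td.items() if k.startswith("loss_"))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_tds.append(loss_td.detach())
    # Mean of each loss term over the update iterations.
    return torch.stack(loss_tds).apply(lambda t: t.mean(0), batch_size=[])
```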
- classmethod get_checkpoint_base_dir_from_run_id(run_id: str) → Path[source]#
- Get the base directory for a checkpoint from a run ID.
- Parameters:
- run_id (str) – The run ID. 
- Returns:
- checkpoint_base_dir (Path) – The path to the base directory for the checkpoint. 
 
- get_total_num_iterations() → int[source]#
- Get the total number of iterations that the trainer will run for.
- This is the sum of the numbers of iterations declared by methods decorated with attach_progress_bar.
- Returns:
- total_iterations (int) – The total number of iterations. 
 
 - load_and_set_state_from_checkpoint(from_base_run: bool = False, version: str = 'latest')[source]#
- Load and set the experiment state from a checkpoint, if available.
- Parameters:
- from_base_run (bool, default=False) – Whether to load the checkpoint artifact from the base run. By default, the artifact is loaded from the current active run. 
- version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B. This must be “latest” if from_base_run is False.
 
- Raises:
- CheckPointNotFoundError – If the checkpoint file is not found.
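- Typical usage, assuming a trainer instance already exists, might look like the sketch below. The import of CheckPointNotFoundError is omitted since its exact location depends on the package layout, and "v2" is an illustrative W&B artifact version:

```python
# Resume from the active run's latest checkpoint, if one exists.
try:
    trainer.load_and_set_state_from_checkpoint()
except CheckPointNotFoundError:
    pass  # no checkpoint yet; start from scratch

# Or warm-start from a specific version of the base run's checkpoint.
trainer.load_and_set_state_from_checkpoint(from_base_run=True, version="v2")
```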