nip.trainers.solo_agent.SoloAgentTrainer#
- class nip.trainers.solo_agent.SoloAgentTrainer(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#
Trainer for training tensordict agents in isolation.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
scenario_instance (ScenarioInstance) – The components of the experiment.
settings (ExperimentSettings) – The instance-specific settings of the experiment, like device, logging, etc.
Methods Summary
- __init__(hyper_params, scenario_instance, ...)
- _add_file_to_wandb_artifact(artifact_name, ...) – Add a file to a W&B artifact, creating the artifact if it doesn't exist.
- _build_test_context(stack) – Build the context manager ExitStack for testing.
- _build_train_context(stack) – Build the context manager ExitStack for training.
- _get_metadata() – Get metadata relevant to the current experiment, for saving the checkpoint.
- _initialise_state() – Initialise the state of the experiment.
- _load_state_dict_from_active_run_checkpoint() – Load the experiment state from a checkpoint for the active run, if available.
- _load_state_dict_from_base_run_checkpoint([version]) – Load the experiment state from a base run checkpoint.
- _run_test_loop(test_dataset, agents_params, ...) – Run the testing loop.
- _run_train_loop(train_dataset, ...) – Run the training loop.
- get_checkpoint_base_dir_from_run_id(run_id) – Get the base directory for a checkpoint from a run ID.
- get_total_num_iterations() – Get the total number of iterations that the trainer will run for.
- load_and_set_state_from_checkpoint([from_base_run, version]) – Load and set the experiment state from a checkpoint, if available.
- save_checkpoint([log]) – Save the state of the experiment to a checkpoint.
- train([as_pretraining]) – Train the agents.
Attributes
- agent_names – The names of the agents in the scenario.
- checkpoint_base_dir – The path to the directory containing the checkpoint.
- checkpoint_metadata_path – The path to the checkpoint metadata file.
- checkpoint_params_path – The path to the parameters file for the checkpoint.
- checkpoint_state_path – The path to the checkpoint state file.
- max_message_rounds – The maximum number of message rounds in the protocol.
- num_agents – The number of agents in the scenario.
- protocol_handler – The protocol handler for the experiment.
Methods
- __init__(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#
- _add_file_to_wandb_artifact(artifact_name: str, artifact_type: str, file_path: Path)[source]#
Add a file to a W&B artifact, creating the artifact if it doesn’t exist.
If the artifact already exists, we add the file to the existing artifact, creating a new version.
- _build_test_context(stack: ExitStack) → list[ContextManager] [source]#
Build the context manager ExitStack for testing.
Takes as input an ExitStack and adds the appropriate context managers to it, then returns the context managers.
- Parameters:
stack (ExitStack) – The ExitStack to add the context managers to. Note that this is modified in-place.
- Returns:
context_managers (list[ContextManager]) – The target context managers to be used in the testing loop.
- _build_train_context(stack: ExitStack) → list[ContextManager] [source]#
Build the context manager ExitStack for training.
Takes as input an ExitStack and adds the appropriate context managers to it, then returns the context managers.
- Parameters:
stack (ExitStack) – The ExitStack to add the context managers to. Note that this is modified in-place.
- Returns:
context_managers (list[ContextManager]) – The target context managers to be used in the training loop.
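The contract shared by _build_train_context and _build_test_context (the ExitStack is modified in place, and the entered context managers are returned as a list) can be illustrated with a stdlib-only sketch. The context managers here are placeholder stand-ins, not the ones the trainer actually builds:

```python
from contextlib import ExitStack, contextmanager


@contextmanager
def dummy_context(name, log):
    # Stand-in for a real context manager (e.g. a W&B run or an
    # inference-mode guard); records enter/exit order for inspection.
    log.append(f"enter {name}")
    try:
        yield name
    finally:
        log.append(f"exit {name}")


def build_train_context(stack, log):
    # Mirrors the documented contract: add context managers to the
    # given ExitStack in place, then return them as a list.
    return [
        stack.enter_context(dummy_context(name, log))
        for name in ("dataset", "model")
    ]


log = []
with ExitStack() as stack:
    contexts = build_train_context(stack, log)
    # The training loop would run here, inside all entered contexts.
# Leaving the with-block exits the contexts in reverse order.
print(contexts)  # ['dataset', 'model']
print(log)       # ['enter dataset', 'enter model', 'exit model', 'exit dataset']
```

Passing the stack in rather than creating it inside the method lets the caller control the lifetime of all contexts with a single with-block.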
- _get_metadata() → dict [source]#
Get metadata relevant to the current experiment, for saving the checkpoint.
- Returns:
metadata (dict) – The metadata to record
- _initialise_state()[source]#
Initialise the state of the experiment.
This method should be implemented by subclasses to initialise the state of the experiment. This is called at the beginning of training when starting from scratch (i.e. not restoring from a checkpoint).
- _load_state_dict_from_active_run_checkpoint() → dict [source]#
Load the experiment state from a checkpoint for the active run, if available.
- Returns:
state_dict (dict) – The state as a dictionary, loaded from W&B
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
- _load_state_dict_from_base_run_checkpoint(version: str = 'latest') → dict [source]#
Load the experiment state from a base run checkpoint.
- Parameters:
version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B.
- Returns:
state_dict (dict) – The state as a dictionary, loaded from W&B
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
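The version-selection behaviour described above can be sketched with an in-memory stand-in for the W&B artifact store. The `"v<N>"` tag format and the helper below are illustrative assumptions, not the actual nip implementation, which loads artifacts from W&B:

```python
class CheckPointNotFoundError(Exception):
    """Raised when the requested checkpoint cannot be found."""


def load_state_dict(versions, version="latest"):
    # `versions` is a mapping like {"v0": {...}, "v1": {...}}, standing in
    # for the artifact versions of the base run's checkpoint.
    if not versions:
        raise CheckPointNotFoundError("no checkpoint artifact found")
    if version == "latest":
        # Assumes version tags of the form "v<N>"; pick the highest N.
        version = max(versions, key=lambda v: int(v.lstrip("v")))
    if version not in versions:
        raise CheckPointNotFoundError(f"version {version!r} not found")
    return versions[version]


store = {"v0": {"iteration": 10}, "v1": {"iteration": 20}}
print(load_state_dict(store))        # {'iteration': 20}
print(load_state_dict(store, "v0"))  # {'iteration': 10}
```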
- _run_test_loop(test_dataset: Dataset, agents_params: AgentsParameters, agents: dict[str, Agent], agent_models: dict[str, TensorDictSequential], as_pretraining: bool, logger: Logger)[source]#
Run the testing loop.
- Parameters:
test_dataset (Dataset) – The dataset to test on.
agents_params (AgentsParameters) – The parameters of the agents.
agents (dict[str, Agent]) – A dictionary of the classes which hold the agent components.
agent_models (dict[str, TensorDictSequential]) – A dictionary of the actual models we’re testing.
as_pretraining (bool) – Whether we’re testing the agents as a pretraining step.
logger (logging.Logger) – The logger to use.
- _run_train_loop(train_dataset: Dataset, agents_params: AgentsParameters, agents: dict[str, Agent], agent_models: dict[str, TensorDictSequential], as_pretraining: bool, torch_generator: Generator, iteration_context: IterationContext)[source]#
Run the training loop.
- Parameters:
train_dataset (Dataset) – The dataset to train on.
agents_params (AgentsParameters) – The parameters of the agents.
agents (dict[str, Agent]) – A dictionary of the classes which hold the agent components.
agent_models (dict[str, TensorDictSequential]) – A dictionary of the actual models we’re training.
as_pretraining (bool) – Whether we’re training the agents as a pretraining step.
torch_generator (torch.Generator) – The random number generator to use.
iteration_context (IterationContext) – The context to use for the training loop, which handles the progress bar.
- classmethod get_checkpoint_base_dir_from_run_id(run_id: str) → Path [source]#
Get the base directory for a checkpoint from a run ID.
- Parameters:
run_id (str) – The run ID.
- Returns:
checkpoint_base_dir (Path) – The path to the base directory for the checkpoint.
- get_total_num_iterations() → int [source]#
Get the total number of iterations that the trainer will run for.
This is the sum of the number of iterations declared by methods decorated with attach_progress_bar.
- Returns:
total_iterations (int) – The total number of iterations.
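A minimal sketch of how a decorator like attach_progress_bar could declare per-method iteration counts that get_total_num_iterations then sums. The decorator internals below are an illustrative assumption, not the actual nip implementation:

```python
def attach_progress_bar(num_iterations):
    # Illustrative stand-in: tag the decorated method with its
    # declared iteration count.
    def decorator(func):
        func.num_iterations = num_iterations
        return func
    return decorator


class SketchTrainer:
    @attach_progress_bar(num_iterations=100)
    def _run_train_loop(self):
        pass

    @attach_progress_bar(num_iterations=20)
    def _run_test_loop(self):
        pass

    def get_total_num_iterations(self):
        # Sum the counts declared by all decorated methods of the class.
        return sum(
            getattr(member, "num_iterations", 0)
            for member in vars(type(self)).values()
        )


print(SketchTrainer().get_total_num_iterations())  # 120
```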
- load_and_set_state_from_checkpoint(from_base_run: bool = False, version: str = 'latest')[source]#
Load and set the experiment state from a checkpoint, if available.
- Parameters:
from_base_run (bool, default=False) – Whether to load the checkpoint artifact from the base run. By default, the artifact is loaded from the current active run.
version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B. This must be “latest” if from_base_run is False.
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
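The resume-or-initialise control flow implied by this method and _initialise_state can be sketched with stdlib stand-ins. The in-memory "saved state" and the start method below are hypothetical; the real trainer loads state dicts from W&B artifacts:

```python
class CheckPointNotFoundError(Exception):
    """Mirrors the documented exception for a missing checkpoint."""


class ResumableExperiment:
    # Minimal stand-in for the trainer's resume logic.
    def __init__(self, saved_state=None):
        self._saved_state = saved_state
        self.state = None

    def load_and_set_state_from_checkpoint(self):
        if self._saved_state is None:
            raise CheckPointNotFoundError("no checkpoint found")
        self.state = dict(self._saved_state)

    def _initialise_state(self):
        # Called only when starting from scratch.
        self.state = {"iteration": 0}

    def start(self):
        # Resume from a checkpoint if one exists, else initialise fresh state.
        try:
            self.load_and_set_state_from_checkpoint()
        except CheckPointNotFoundError:
            self._initialise_state()
        return self.state


print(ResumableExperiment().start())                   # {'iteration': 0}
print(ResumableExperiment({"iteration": 42}).start())  # {'iteration': 42}
```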