nip.trainers.rl_pure_text_base.PureTextRlTrainer#
- class nip.trainers.rl_pure_text_base.PureTextRlTrainer(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#
- Base class for RL trainers for text-based environments that only use APIs. - Parameters:
- hyper_params (HyperParameters) – The parameters of the experiment. 
- scenario_instance (ScenarioInstance) – The components of the experiment. 
- settings (ExperimentSettings) – The instance-specific settings of the experiment, like device, logging, etc. 
 
 - Methods Summary
- __init__(hyper_params, scenario_instance, ...)
- _add_file_to_wandb_artifact(artifact_name, ...) – Add a file to a W&B artifact, creating the artifact if it doesn't exist.
- _check_if_run_test_loop() – Check if the test loop should be run in the current iteration.
- _extract_transcripts_and_prompts(rollouts, ...) – Extract the raw and processed transcripts, and prompts, from the rollouts.
- _get_fine_tune_job_name(shared_model_group) – Get a name for the fine-tune job for the given shared model group.
- _get_iteration_begin_message() – Get the message to log at the beginning of each iteration.
- _get_log_stats(rollouts, *[, train]) – Get the statistics to log for the given rollouts.
- _get_metadata() – Get metadata relevant to the current experiment, for saving the checkpoint.
- _get_verifier_guess_replacement_proportion(iteration) – Get the proportion of rollouts to replace the guess with the true label.
- _initialise_state() – Initialise the state of the experiment.
- _load_rollouts(iterations) – Load the rollouts from the checkpoint directory.
- _load_state_dict_from_active_run_checkpoint() – Load the experiment state from a checkpoint for the active run, if available.
- _load_state_dict_from_base_run_checkpoint([version]) – Load the experiment state from a base run checkpoint.
- _previous_compatible_iterations() – Get the previous iterations which are combinable with the current iteration.
- _sample_rollouts(environment, iteration[, ...]) – Sample rollouts in the environment.
- _sample_rollouts_for_single_environment(environment[, data_batch]) – Sample rollouts for a single environment.
- _save_rollouts(rollouts, environment[, ...]) – Save the rollouts to the checkpoint directory.
- _stage_await_fine_tune_jobs() – Training stage: await the completion of the fine-tune jobs.
- _stage_create_fine_tune_jobs(rollouts[, ...]) – Training stage: create fine-tune jobs for each agent.
- _stage_log_stats(rollouts) – Training stage: log the statistics of the rollouts.
- Training stage: run the test loop.
- _stage_sample_rollouts() – Training stage: sample rollouts from the training environment.
- _train() – Run the actual training loop implementation, which is asynchronous.
- get_checkpoint_base_dir_from_run_id(run_id) – Get the base directory for a checkpoint from a run ID.
- get_total_num_iterations() – Get the total number of iterations that the trainer will run for.
- load_and_set_state_from_checkpoint([from_base_run, version]) – Load and set the experiment state from a checkpoint, if available.
- run_analysers(analysers, model_name, *[, ...]) – Run the given analysers on the rollouts of the experiment.
- save_checkpoint([log]) – Save the state of the experiment to a checkpoint.
- train() – Train the agents in the environment.
- Attributes
- agent_names – The names of the agents in the scenario.
- agent_wholes – The 'whole' part of each agent.
- checkpoint_analysis_dir – The directory to save the rollout analysis to.
- checkpoint_base_dir – The path to the directory containing the checkpoint.
- checkpoint_metadata_path – The path to the checkpoint metadata file.
- checkpoint_params_path – The path to the parameters file for the checkpoint.
- checkpoint_rollouts_dir – The directory to save the rollouts to.
- checkpoint_state_path – The path to the checkpoint state file.
- combined_agent – The agents combined into a single operator.
- max_message_rounds – The maximum number of message rounds in the protocol.
- num_agents – The number of agents in the scenario.
- processed_transcripts_dir – The directory to save the processed transcripts to.
- prompts_dir – The directory to save the prompts to.
- protocol_handler – The protocol handler for the experiment.
- raw_transcripts_dir – The directory to save the raw transcripts to.
- shared_model_groups – The agents grouped by having a shared model.
- state – The state of the experiment.
- test_environment – The test environment.
- train_environment – The training environment.
- trainer_type – The type of trainer this is.
- Methods
- __init__(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#
 - _add_file_to_wandb_artifact(artifact_name: str, artifact_type: str, file_path: Path)[source]#
- Add a file to a W&B artifact, creating the artifact if it doesn’t exist. - If the artifact already exists, we add the file to the existing artifact, creating a new version. 
 - _check_if_run_test_loop() bool[source]#
- Check if the test loop should be run in the current iteration. - Returns:
- run_test_loop (bool) – Whether the test loop should be run. 
 
 - _extract_transcripts_and_prompts(rollouts: NestedArrayDict, environment: PureTextEnvironment) tuple[list[dict], list[dict], list[list[dict[str, list[PromptMessage]]]]][source]#
- Extract the raw and processed transcripts, and prompts, from the rollouts. - The raw transcript is the sequence of outputs generated by the models, per agent, while the processed transcript is the result of processing these and extracting the message per channel. - Note that in the raw transcripts the messages are per agent, while in the processed transcripts the messages are per channel. - The transcripts have variable length: if a round has no messages from any agent, we declare that to be the end of the transcript. - Parameters:
- rollouts (NestedArrayDict) – - The rollouts to extract the transcripts from. A NestedArrayDict with keys: - (“agents”, “message”) (batch round agent channel) : The processed message sent by each agent to each channel in each timestep. 
- (“agents”, “raw_message”) (batch round agent) : The raw message generated by each model in each timestep. 
- (“agents”, “prompt”) (batch round agent message field) : The prompt used to generate the message for each agent in each timestep. 
- (“agents”, “decision”) (batch round agent) : The decision made by each agent in each timestep. 
- (“agents”, “continuous_decision”) (batch round agent) : A float version of the decision made by each agent at each timestep, which is a value between -1 and 1. 
- (“agents”, “raw_decision”) (batch round agent) : The raw decision text sent by each agent in each timestep. 
- (“agents”, “reward”) (batch round agent) : The reward received by each agent in each timestep. 
 - The nested array dict also contains keys which specify the datapoint for each rollout, as extracted by `environment.get_datapoint_from_env_state_as_dict`.
- environment (PureTextEnvironment) – The environment the rollouts were sampled in. 
 
- Returns:
- raw_transcripts (list[dict]) – The raw transcripts. This is a list of transcripts, where each transcript is a dictionary containing metadata and a “transcript” key. The value at “transcript” is a list of dictionaries whose keys are the agent names and whose values are the messages generated by the agents. 
- processed_transcripts (list[dict]) – The processed transcripts. This is a list of transcripts, where each transcript is a dictionary containing metadata and a “transcript” key. The value at “transcript” is a list of dictionaries whose keys are `f"{active_agent_name}@{channel_name}"` and whose values are the messages in each channel.
- prompts (list[list[dict[str, list[PromptMessage]]]]) – The prompts used to generate the messages at each timestep. This is a list containing, for each batch item, a list of dictionaries, one for each round. Each dictionary has the agent names as keys and the prompts used by the agents as values. Each prompt is a list of dictionaries whose type is specified by the `PromptMessage` class.
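As an illustration of the processed-transcript structure described above, the following sketch builds a hypothetical transcript (the agent names, channel name, metadata key, and messages are all invented for illustration) and walks it round by round:

```python
# Hypothetical processed transcripts, shaped as described above: each entry
# holds metadata plus a "transcript" key, whose value is a list of per-round
# dictionaries keyed by f"{active_agent_name}@{channel_name}".
processed_transcripts = [
    {
        "datapoint_id": 0,  # example metadata key (hypothetical)
        "transcript": [
            {"prover@main": "The claim follows from lemma 1.",
             "verifier@main": "Why does lemma 1 apply here?"},
            {"prover@main": "Its premise holds for this instance.",
             "verifier@main": "Decision: accept"},
        ],
    },
]

# Walk the structure, splitting each key back into agent and channel names.
parsed = []
for transcript in processed_transcripts:
    for round_index, messages in enumerate(transcript["transcript"]):
        for agent_channel, message in messages.items():
            agent_name, channel_name = agent_channel.split("@")
            parsed.append((round_index, agent_name, channel_name, message))
```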
 
 
 - _get_fine_tune_job_name(shared_model_group: PureTextSharedModelGroup) str[source]#
- Get a name for the fine-tune job for the given shared model group. - This name is generated from the run ID, the iteration number, and the shared model group name, and is used to make the job more easily identifiable. - Parameters:
- shared_model_group (PureTextSharedModelGroup) – The shared model group to create the fine-tune job for. 
- Returns:
- job_name (str) – The name of the fine-tune job. 
 
 - _get_iteration_begin_message() str[source]#
- Get the message to log at the beginning of each iteration. - Returns:
- message (str) – The message to log at the beginning of each iteration. 
 
 - _get_log_stats(rollouts: NestedArrayDict, *, train=True) dict[source]#
- Get the statistics to log for the given rollouts. - Parameters:
- rollouts (NestedArrayDict) – The rollouts to get the statistics for. 
- train (bool, default=True) – Whether the rollouts are from the training environment. 
 
- Returns:
- stats (dict) – The statistics to log. 
 
 - _get_metadata() dict[source]#
- Get metadata relevant to the current experiment, for saving the checkpoint. - Returns:
- metadata (dict) – The metadata to record 
 
 - _get_verifier_guess_replacement_proportion(iteration: int) float[source]#
- Get the proportion of rollouts to replace the guess with the true label. - For this proportion of the sampled rollouts, we replace the verifier guess with either “Decision: accept” or “Decision: reject” based on the true label. - This value can be annealed over the course of the training. - Parameters:
- iteration (int) – The current iteration number. 
- Returns:
- proportion (float) – The proportion of rollouts where we replace the guess with the true label. 
- Raises:
- ValueError – If the annealing type is invalid. 
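The annealing schedule itself is configured through the hyper-parameters; as a minimal sketch, assuming a hypothetical linear schedule (the function name, defaults, and schedule are illustrative, not the library's API):

```python
def linear_anneal_proportion(
    iteration: int, *, start: float = 1.0, end: float = 0.0, num_iterations: int = 10
) -> float:
    """Hypothetical linear anneal from `start` to `end` over `num_iterations`."""
    if num_iterations <= 1:
        return end
    # Clamp so iterations past the schedule's end stay at the final value.
    fraction = min(iteration / (num_iterations - 1), 1.0)
    return start + (end - start) * fraction
```

With these defaults, the full proportion of guesses is replaced at iteration 0, decaying linearly to zero by the final iteration.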
 
 - _initialise_state()[source]#
- Initialise the state of the experiment. - This method should be implemented by subclasses to initialise the state of the experiment. This is called at the beginning of training when starting from scratch (i.e. not restoring from a checkpoint). 
 - _load_rollouts(iterations: int | Iterable[int]) NestedArrayDict[source]#
- Load the rollouts from the checkpoint directory. 
 - _load_state_dict_from_active_run_checkpoint() dict[source]#
- Load the experiment state from a checkpoint for the active run, if available. - Returns:
- state_dict (dict) – The state as a dictionary, loaded from W&B 
- Raises:
- CheckPointNotFoundError – If the checkpoint file is not found. 
 
 - _load_state_dict_from_base_run_checkpoint(version: str = 'latest') dict[source]#
- Load the experiment state from a base run checkpoint. - Parameters:
- version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B. 
- Returns:
- state_dict (dict) – The state as a dictionary, loaded from W&B 
- Raises:
- CheckPointNotFoundError – If the checkpoint file is not found. 
 
 - _previous_compatible_iterations() Iterable[int][source]#
- Get the previous iterations which are combinable with the current iteration. - This method is used when combining rollouts from different iterations, and returns an iterable of the previous iteration numbers which can be combined with the current iteration. - Returns:
- previous_iterations (Iterable[int]) – The previous iterations which are combinable with the current iteration. 
 
 - async _sample_rollouts(environment: PureTextEnvironment, iteration: int | Literal['test'], use_tqdm: bool = False, tqdm_desc: str = 'Sampling rollouts') NestedArrayDict[source]#
- Sample rollouts in the environment. - We sample `environment.num_envs` rollouts from the environment. A rollout is a sequence of length `max_message_rounds` of states in the environment. The sampled rollout nested array dict thus has shape (num_envs, max_message_rounds). - Parameters:
- environment (PureTextEnvironment) – The environment to sample rollouts in. 
- iteration (int | Literal["test"]) – The iteration number, or “test” if the rollouts are from the test set. 
- use_tqdm (bool) – Whether to create a tqdm progress bar for the rollouts. 
- tqdm_desc (str) – The description to use for the tqdm progress bar. 
 
- Returns:
- rollouts (NestedArrayDict) – The rollouts in the environment. Has batch size (num_envs, max_message_rounds) 
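Since this method is asynchronous and each environment instance corresponds to one datapoint, the sampling pattern resembles gathering one coroutine per environment. A self-contained sketch of that pattern only (the stand-in coroutine and the dictionary it returns are hypothetical, not the library's API):

```python
import asyncio

async def sample_single_rollout(env_index: int) -> dict:
    # Stand-in for stepping a single environment to completion (hypothetical).
    await asyncio.sleep(0)
    return {"env": env_index, "num_rounds": 3}

async def sample_rollouts(num_envs: int) -> list[dict]:
    # Sample all environments concurrently, one coroutine per environment;
    # asyncio.gather preserves the order of the inputs.
    return await asyncio.gather(*(sample_single_rollout(i) for i in range(num_envs)))

rollouts = asyncio.run(sample_rollouts(4))
```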
 
 - async _sample_rollouts_for_single_environment(environment: PureTextEnvironment, data_batch: NestedArrayDict | None = None) list[NestedArrayDict][source]#
- Sample rollouts for a single environment. - A single environment is associated with a single datapoint. This method samples rollouts from it. Subclasses may reimplement this if they need to sample rollouts in a different way. - In this default implementation, we sample a single rollout by stepping the environment until it is done, and then padding the rollout with zero states up to the maximum number of message rounds. - Parameters:
- environment (PureTextEnvironment) – The environment to sample rollouts in. 
- data_batch (NestedArrayDict, optional) – The data batch to use for the rollout. If None, the data batch will be sampled from the dataset. 
 
- Returns:
- list[NestedArrayDict] – A single-element list containing the rollout in the environment. Has batch size (max_message_rounds,) 
 
 - _save_rollouts(rollouts: NestedArrayDict, environment: PureTextEnvironment, iteration: int | None = None)[source]#
- Save the rollouts to the checkpoint directory. - Parameters:
- rollouts (NestedArrayDict) – The rollouts to save. 
- environment (PureTextEnvironment) – The environment the rollouts were sampled in. 
- iteration (int, optional) – The iteration number. If not provided, the current iteration number is used. 
 
 
 - async _stage_await_fine_tune_jobs()[source]#
- Training stage: await the completion of the fine-tune jobs. - Raises:
- ExceptionGroup[FineTuneJobError] – If any of the fine-tune jobs fail or are cancelled. Note that since we use a task group to await the fine-tune jobs, the exceptions are raised as an `ExceptionGroup`. This can be caught using an `except*` statement. If `exception_group` is caught, then `exception_group.exceptions` will contain the individual exceptions for each fine-tune job that failed.
 - Example
>>> try:
>>>     await trainer._stage_await_fine_tune_jobs()
>>> except* FineTuneJobError as exception_group:
>>>     for exception in exception_group.exceptions:
>>>         logger.error(exception)
 - abstract async _stage_create_fine_tune_jobs(rollouts: NestedArrayDict, only_failed: bool = False)[source]#
- Training stage: create fine-tune jobs for each agent. - Parameters:
- rollouts (NestedArrayDict) – The rollouts sampled in this iteration. 
- only_failed (bool, default=False) – Whether to only create fine-tune jobs for shared model groups whose previous fine-tune job failed or was cancelled. If False, fine-tune jobs are created for all shared model groups. 
 
 
 - _stage_log_stats(rollouts: NestedArrayDict)[source]#
- Training stage: log the statistics of the rollouts. - Parameters:
- rollouts (NestedArrayDict) – The rollouts sampled in this iteration. 
 
 - async _stage_sample_rollouts() NestedArrayDict[source]#
- Training stage: sample rollouts from the training environment. - Returns:
- rollouts (NestedArrayDict) – The sampled rollouts. 
 
 - classmethod get_checkpoint_base_dir_from_run_id(run_id: str) Path[source]#
- Get the base directory for a checkpoint from a run ID. - Parameters:
- run_id (str) – The run ID. 
- Returns:
- checkpoint_base_dir (Path) – The path to the base directory for the checkpoint. 
 
 - get_total_num_iterations() int[source]#
- Get the total number of iterations that the trainer will run for. - This is the sum of the number of iterations declared by methods decorated with `attach_progress_bar`. - Returns:
- total_iterations (int) – The total number of iterations. 
 
 - load_and_set_state_from_checkpoint(from_base_run: bool = False, version: str = 'latest')[source]#
- Load and set the experiment state from a checkpoint, if available. - Parameters:
- from_base_run (bool, default=False) – Whether to load the checkpoint artifact from the base run. By default, the artifact is loaded from the current active run. 
- version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B. This must be “latest” if `from_base_run` is False.
 
- Raises:
- CheckPointNotFoundError – If the checkpoint file is not found. 
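A caller can treat `CheckPointNotFoundError` as "start from scratch". A self-contained sketch of that fallback pattern (the exception class here is a local stand-in definition, not imported from the library, and `load_or_initialise` is an illustrative helper):

```python
class CheckPointNotFoundError(Exception):
    """Stand-in for the library's exception (hypothetical local definition)."""

def load_or_initialise(load_state, initialise_state):
    # Resume from a checkpoint if one exists; otherwise start from scratch.
    try:
        return load_state()
    except CheckPointNotFoundError:
        return initialise_state()

def missing_checkpoint():
    raise CheckPointNotFoundError("no checkpoint artifact found")

# No checkpoint is available, so the fresh initial state is used.
state = load_or_initialise(missing_checkpoint, lambda: {"iteration": 0})
```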
 
 - run_analysers(analysers: list[str | type[PureTextRolloutAnalyser]], model_name: str, *, overwrite=False, use_tqdm=True, dry_run=False)[source]#
- Run the given analysers on the rollouts of the experiment. - This method can only be called after the experiment has finished. - Parameters:
- analysers (list[str | type[PureTextRolloutAnalyser]]) – The analysers to run. Either the name of the analyser or the analyser class itself. 
- model_name (str) – The name of the model to use for the analysis. 
- overwrite (bool, default=False) – Whether to overwrite the existing analysis files, if they exist. 
- use_tqdm (bool, default=True) – Whether to create a progress bar for the analysis. 
- dry_run (bool, default=False) – Whether to do a dry run using a dummy API, not saving the results. 
 
 
 - save_checkpoint(log: bool = True)[source]#
- Save the state of the experiment to a checkpoint. - Parameters:
- log (bool, default=True) – Whether to log the checkpointing. 
 
 - train()[source]#
- Train the agents in the environment. - Runs the training loop for the specified number of iterations. The training loop consists of the following stages: - Sample rollouts from the training environment. 
- Log the statistics of the rollouts. 
- Run the test loop during training. 
- Create fine-tune jobs for each agent. 
- Await the completion of the fine-tune jobs. 
 - The training loop can be resumed from a previous checkpoint. If the training loop is resumed, the state of the experiment is loaded from the checkpoint and the loop resumes from the last stage.
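The stage order above can be sketched with a minimal stand-in trainer. The stage method names follow this page; `DummyTrainer` and the per-iteration checkpointing placement are purely illustrative assumptions, not the real implementation:

```python
import asyncio

class DummyTrainer:
    """Illustrative stand-in that records which stages run (not the real class)."""

    def __init__(self):
        self.calls = []

    async def _stage_sample_rollouts(self):
        self.calls.append("sample")
        return []

    def _stage_log_stats(self, rollouts):
        self.calls.append("log")

    def _check_if_run_test_loop(self):
        return False  # skip the test loop in this sketch

    async def _stage_create_fine_tune_jobs(self, rollouts):
        self.calls.append("create")

    async def _stage_await_fine_tune_jobs(self):
        self.calls.append("await")

    def save_checkpoint(self):
        self.calls.append("checkpoint")

async def train_loop(trainer, num_iterations):
    # The stage order listed above; checkpointing each iteration is assumed.
    for _ in range(num_iterations):
        rollouts = await trainer._stage_sample_rollouts()
        trainer._stage_log_stats(rollouts)
        if trainer._check_if_run_test_loop():
            pass  # the test loop would run here
        await trainer._stage_create_fine_tune_jobs(rollouts)
        await trainer._stage_await_fine_tune_jobs()
        trainer.save_checkpoint()

trainer = DummyTrainer()
asyncio.run(train_loop(trainer, 2))
```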