nip.trainers.malt_pure_text.PureTextMaltTrainer#
- class nip.trainers.malt_pure_text.PureTextMaltTrainer(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#
Multi-Agent LLM Training (MALT) for text-based environments that only use APIs.
In the MALT protocol [MSD+24], we sample multiple responses per timestep from the agents. This means that for each datapoint we have a tree of responses. For each agent A, at each decision point for A, we look at the expected reward for A for each of the responses. We then select preference pairs of responses from these and train using Direct Preference Optimization [RSM+23]. The way pairs are selected is determined by the hyper_params.pure_text_malt.pair_selection_method parameter, which can be one of the following:
“positive_negative”: Selects a response where the agent’s expected reward is above a certain threshold (by default the reward mid-point) and a response where the agent’s expected reward is below this threshold.
“interval”: Selects a pair of responses where the difference in expected reward is above a certain threshold. This threshold is computed as interval_threshold_proportion times the difference between the maximum and minimum possible reward for the agent.
A rough sketch of both rules is given after the parameter list below.
It is also possible to do some rounds of Expert Iteration (EI) before doing MALT. The PureTextMaltTrainer class inherits from the PureTextEiTrainer class, which implements the EI protocol, and allows running EI for a number of iterations specified by the hyper_params.pure_text_malt.num_initial_ei_iterations parameter.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
scenario_instance (ScenarioInstance) – The components of the experiment.
settings (ExperimentSettings) – The instance-specific settings of the experiment, like device, logging, etc.
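As a rough illustration of the two pair-selection rules, the following is a minimal, self-contained sketch; the function and parameter names are hypothetical and not part of the trainer's API.

import random


def select_preference_pair(
    expected_rewards: list[float],
    method: str,
    reward_midpoint: float,
    min_reward: float,
    max_reward: float,
    interval_threshold_proportion: float,
) -> tuple[int, int] | None:
    """Pick (positive_index, negative_index) among sibling responses, or None.

    Sketch of the two rules described above, not the trainer's actual code.
    """
    if method == "positive_negative":
        # A positive response scores above the threshold, a negative one below it.
        positives = [i for i, r in enumerate(expected_rewards) if r > reward_midpoint]
        negatives = [i for i, r in enumerate(expected_rewards) if r < reward_midpoint]
        if positives and negatives:
            return random.choice(positives), random.choice(negatives)
    elif method == "interval":
        # A pair is eligible when the gap in expected reward is large enough;
        # the higher-reward response is the positive one.
        threshold = interval_threshold_proportion * (max_reward - min_reward)
        pairs = [
            (i, j)
            for i, r_i in enumerate(expected_rewards)
            for j, r_j in enumerate(expected_rewards)
            if r_i - r_j > threshold
        ]
        if pairs:
            return random.choice(pairs)
    return None

With pair_selection_method set to “positive_negative”, the reward mid-point plays the role of reward_midpoint here; with “interval”, only the gap between expected rewards matters.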
Methods Summary
__init__(hyper_params, scenario_instance, ...)
_add_file_to_wandb_artifact(artifact_name, ...) – Add a file to a W&B artifact, creating the artifact if it doesn't exist.
_check_if_run_test_loop() – Check if the test loop should be run in the current iteration.
_compute_tree_expected_reward(partial_rollouts_by_level) – Compute the expected reward for each agent at each node of the tree.
_extract_transcripts_and_prompts(rollouts, ...) – Extract the raw and processed transcripts, and prompts, from the rollouts.
_generate_response_tree(environment[, ...]) – Generate the tree of responses for a single datapoint.
_get_fine_tune_job_name(shared_model_group) – Get a name for the fine-tune job for the given shared model group.
_get_iteration_begin_message() – Get the message to log at the beginning of each iteration.
_get_log_stats(rollouts, *[, train]) – Get the statistics to log for the given rollouts.
_get_metadata() – Get metadata relevant to the current experiment, for saving the checkpoint.
_get_unique_timesteps(rollouts) – Break the rollouts into timesteps, and remove duplicate nodes.
_get_verifier_guess_replacement_proportion(iteration) – Get the proportion of rollouts to replace the guess with the true label.
_initialise_state() – Initialise the state of the experiment.
_load_rollouts(iterations) – Load the rollouts from the checkpoint directory.
_load_state_dict_from_active_run_checkpoint() – Load the experiment state from a checkpoint for the active run, if available.
_load_state_dict_from_base_run_checkpoint([version]) – Load the experiment state from a base run checkpoint.
_previous_compatible_iterations() – Get the previous iterations which are combinable with the current iteration.
_sample_positive_and_negative_examples(partial_rollouts_by_level) – Sample positive and negative examples for each node in the tree of responses.
_sample_rollouts(environment, iteration[, ...]) – Sample rollouts in the environment.
_sample_rollouts_for_single_environment(environment[, ...]) – Sample rollouts for a single environment.
_save_rollouts(rollouts, environment[, ...]) – Save the rollouts to the checkpoint directory.
_select_rollouts_for_fine_tuning(rollouts, ...) – Select rollouts to fine-tune on, based on the reward.
_stage_await_fine_tune_jobs() – Training stage: await the completion of the fine-tune jobs.
_stage_create_fine_tune_jobs(rollouts[, ...]) – Training stage: create fine-tune jobs for each agent.
_stage_log_stats(rollouts) – Training stage: log the statistics of the rollouts.
_stage_run_test_loop() – Training stage: run the test loop.
_stage_sample_rollouts() – Training stage: sample rollouts from the training environment.
_train() – Run the actual training loop implementation, which is asynchronous.
get_checkpoint_base_dir_from_run_id(run_id) – Get the base directory for a checkpoint from a run ID.
get_total_num_iterations() – Get the total number of iterations that the trainer will run for.
load_and_set_state_from_checkpoint([from_base_run, version]) – Load and set the experiment state from a checkpoint, if available.
run_analysers(analysers, model_name, *[, ...]) – Run the given analysers on the rollouts of the experiment.
save_checkpoint([log]) – Save the state of the experiment to a checkpoint.
train() – Train the agents in the environment.
Attributes
agent_names – The names of the agents in the scenario.
agent_wholes – The 'whole' part of each agent.
checkpoint_analysis_dir – The directory to save the rollout analysis to.
checkpoint_base_dir – The path to the directory containing the checkpoint.
checkpoint_metadata_path – The path to the checkpoint metadata file.
checkpoint_params_path – The path to the parameters file for the checkpoint.
checkpoint_rollouts_dir – The directory to save the rollouts to.
checkpoint_state_path – The path to the checkpoint state file.
combined_agent – The agents combined into a single operator.
max_message_rounds – The maximum number of message rounds in the protocol.
num_agents – The number of agents in the scenario.
processed_transcripts_dir – The directory to save the processed transcripts to.
prompts_dir – The directory to save the prompts to.
protocol_handler – The protocol handler for the experiment.
raw_transcripts_dir – The directory to save the raw transcripts to.
shared_model_groups – The agents grouped by having a shared model.
state – The state of the experiment.
test_environment – The test environment.
train_environment – The training environment.
trainer_type – The type of trainer this is.
Methods
- __init__(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#
- _add_file_to_wandb_artifact(artifact_name: str, artifact_type: str, file_path: Path)[source]#
Add a file to a W&B artifact, creating the artifact if it doesn’t exist.
If the artifact already exists, we add the file to the existing artifact, creating a new version.
- _check_if_run_test_loop() bool [source]#
Check if the test loop should be run in the current iteration.
- Returns:
run_test_loop (bool) – Whether the test loop should be run.
- _compute_tree_expected_reward(partial_rollouts_by_level: list[list[_PartialRolloutNode]])[source]#
Compute the expected reward for each agent at each node of the tree.
The expected reward is the average reward that an agent receives over all branches passing through a node. This is stored in the ("agents", "expected_reward") field of the rollouts, which are modified in-place.
This is computed by summing up the total reward for all descendants, proceeding from the leaves to the root, and dividing by the number of branches passing through the node. A rough sketch of this computation is given after the parameter list below.
- Parameters:
partial_rollouts_by_level (list[list[_PartialRolloutNode]]) – The tree of responses, stratified by level. These are modified in-place, where we add the ("agents", "expected_reward") field containing the expected reward for each agent at each node.
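As mentioned above, the computation proceeds from the leaves to the root. The following is a minimal sketch of one plausible reading of that computation, using a hypothetical node class rather than the trainer's _PartialRolloutNode (which stores per-agent rewards in a NestedArrayDict).

from dataclasses import dataclass, field


@dataclass
class _Node:
    reward: float  # reward received at this node (per-agent in the real trainer)
    children: list["_Node"] = field(default_factory=list)
    expected_reward: float = 0.0


def compute_expected_rewards(node: _Node) -> tuple[float, int]:
    """Return (total reward over the subtree rooted at `node`, number of branches).

    Post-order recursion, i.e. leaves are processed before their ancestors.
    """
    subtree_reward = node.reward
    num_branches = 1 if not node.children else 0
    for child in node.children:
        child_reward, child_branches = compute_expected_rewards(child)
        subtree_reward += child_reward
        num_branches += child_branches
    # Average over all branches (root-to-leaf paths) passing through this node.
    node.expected_reward = subtree_reward / num_branches
    return subtree_reward, num_branches


# Toy tree: root with two branches, one of which continues for another round.
root = _Node(0.0, [_Node(1.0), _Node(0.0, [_Node(1.0)])])
compute_expected_rewards(root)  # root.expected_reward == (0 + 1 + 0 + 1) / 2 == 1.0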
- _extract_transcripts_and_prompts(rollouts: NestedArrayDict, environment: PureTextEnvironment) tuple[list[dict], list[dict], list[list[dict[str, list[PromptMessage]]]]] [source]#
Extract the raw and processed transcripts, and prompts, from the rollouts.
The raw transcript is the sequence of outputs generated by the models, per agent, while the processed transcript is the result of processing these and extracting the message per channel.
Note that in the raw transcripts the messages are per agent, while in the processed transcripts the messages are per channel.
The transcripts have variable length: if a round contains no messages from any agent, we treat that round as the end of the transcript.
- Parameters:
rollouts (NestedArrayDict) –
The rollouts to extract the transcripts from. A NestedArrayDict with keys:
(“agents”, “message”) (batch round agent channel) : The processed message sent by each agent to each channel in each timestep.
(“agents”, “raw_message”) (batch round agent) : The raw message generated by each model in each timestep.
(“agents”, “prompt”) (batch round agent message field) : The prompt used to generate the message for each agent in each timestep.
(“agents”, “decision”) (batch round agent) : The decision made by each agent in each timestep.
(“agents”, “continuous_decision”) (batch round agent) : A float version of the decision made by each agent at each timestep, which is a value between -1 and 1.
(“agents”, “raw_decision”) (batch round agent) : The raw decision text sent by each agent in each timestep.
(“agents”, “reward”) (batch round agent) : The reward received by each agent in each timestep.
The nested array dict also contains keys which specify the datapoint for each rollout, as extracted by environment.get_datapoint_from_env_state_as_dict.
environment (PureTextEnvironment) – The environment the rollouts were sampled in.
- Returns:
raw_transcripts (list[dict]) – The raw transcripts. This is a list of transcripts, where each transcript is a dictionary containing metadata and a “transcript” key. The value at “transcript” is a list of dictionaries whose keys are the agent names and whose values are the messages generated by the agents.
processed_transcripts (list[dict]) – The processed transcripts. This is a list of transcripts, where each transcript is a dictionary containing metadata and a “transcript” key. The value at “transcript” is a list of dictionaries whose keys are f"{active_agent_name}@{channel_name}" and whose values are the messages in each channel.
prompts (list[list[dict[str, list[PromptMessage]]]]) – The prompts used to generate the messages at each timestep. This is a list containing, for each batch item, a list of dictionaries, one for each round. Each dictionary has the agent names as keys and the prompts used by the agents as values. The prompts are a list of dictionaries, whose type is specified by the PromptMessage class.
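For illustration only, a single processed transcript might look roughly like the following; the agent, channel, and metadata names are made up.

processed_transcript = {
    "question": "...",  # metadata extracted from the environment state
    "transcript": [
        # One dict per round, keyed by f"{active_agent_name}@{channel_name}".
        {"prover@main": "The answer is 42 because ..."},
        {"verifier@main": "Decision: accept"},
    ],
}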
- async _generate_response_tree(environment: PureTextEnvironment, data_batch: NestedArrayDict | None = None) list[list[_PartialRolloutNode]] [source]#
Generate the tree of responses for a single datapoint.
This generates a tree of partial rollouts, where the children of each node are the one-step continuations of the node formed by generating multiple different responses for each active agent at that time step. At each step we sample hyper_params.pure_text_malt.num_responses_per_timestep responses.
The output tree is stratified by the level in the tree, with the root node (empty partial rollout) at the first level. Note that in general, the tree will not be fully generated, because the environment may terminate before the maximum number of message rounds is reached. A simplified sketch of this expansion is given after the parameter list below.
- Parameters:
environment (PureTextEnvironment) – The environment to sample rollouts in.
data_batch (NestedArrayDict, optional) – The data batch to use for the rollout. If None, the data batch will be sampled from the dataset.
- Returns:
partial_rollouts_by_level (list[list[_PartialRolloutNode]]) – The tree of responses, stratified by level.
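The following is a simplified, synchronous sketch of the level-by-level expansion. The real method is asynchronous and works with _PartialRolloutNode objects and the environment's stepping logic; the callables here stand in for those.

import random
from typing import Callable


def generate_response_tree_sketch(
    sample_response: Callable[[list[str]], str],
    is_terminated: Callable[[list[str]], bool],
    num_responses_per_timestep: int,
    max_message_rounds: int,
) -> list[list[list[str]]]:
    """Build a tree of partial rollouts, stratified by level (sketch).

    A partial rollout is represented here simply as a list of messages.
    """
    root: list[str] = []  # the empty partial rollout
    levels: list[list[list[str]]] = [[root]]
    for _ in range(max_message_rounds):
        next_level: list[list[str]] = []
        for partial_rollout in levels[-1]:
            if is_terminated(partial_rollout):
                # The environment may terminate early, so the tree is not
                # necessarily fully expanded.
                continue
            # Sample several one-step continuations of this node.
            for _ in range(num_responses_per_timestep):
                next_level.append(partial_rollout + [sample_response(partial_rollout)])
        if not next_level:
            break
        levels.append(next_level)
    return levels


# Toy usage: random "responses", terminating with probability 0.2 per node.
tree = generate_response_tree_sketch(
    sample_response=lambda history: random.choice(["accept", "reject", "continue"]),
    is_terminated=lambda history: bool(history) and random.random() < 0.2,
    num_responses_per_timestep=2,
    max_message_rounds=3,
)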
- _get_fine_tune_job_name(shared_model_group: PureTextSharedModelGroup) str [source]#
Get a name for the fine-tune job for the given shared model group.
This name is generated from the run ID, the iteration number, and the shared model group name, and is used to make the job more easily identifiable.
- Parameters:
shared_model_group (PureTextSharedModelGroup) – The shared model group to create the fine-tune job for.
- Returns:
job_name (str) – The name of the fine-tune job.
- _get_iteration_begin_message() str [source]#
Get the message to log at the beginning of each iteration.
- Returns:
message (str) – The message to log at the beginning of each iteration.
- _get_log_stats(rollouts: NestedArrayDict, *, train=True) dict [source]#
Get the statistics to log for the given rollouts.
This method extends the base class method to include the MALT-specific statistics.
- Parameters:
rollouts (NestedArrayDict) – The rollouts to get the statistics for.
train (bool, default=True) – Whether the rollouts are from the training environment.
- Returns:
stats (dict) – The statistics to log.
- _get_metadata() dict [source]#
Get metadata relevant to the current experiment, for saving the checkpoint.
- Returns:
metadata (dict) – The metadata to record
- _get_unique_timesteps(rollouts: NestedArrayDict) NestedArrayDict [source]#
Break the rollouts into timesteps, and remove duplicate nodes.
Each timestep is a unique node in the tree of responses.
- Parameters:
rollouts (NestedArrayDict) – The rollouts to get the timesteps for. Has batch size (batch round).
- Returns:
timesteps (NestedArrayDict) – The rollouts, broken into timesteps, with the duplicate nodes removed. Has batch size (timestep).
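One plausible way to picture the deduplication is shown below, under the assumption that duplicates are identified by the stored node IDs; the real method operates on a NestedArrayDict with batch size (batch round) rather than plain dictionaries.

def unique_timesteps(rollouts: list[list[dict]]) -> list[dict]:
    """Flatten rollouts into timesteps and drop duplicates (sketch).

    Each rollout is assumed to be a list of per-round dicts containing a
    "_node_id" entry; rounds with the same node ID share the same history.
    """
    seen: set[int] = set()
    timesteps: list[dict] = []
    for rollout in rollouts:
        for timestep in rollout:
            node_id = timestep["_node_id"]
            if node_id in seen:
                continue
            seen.add(node_id)
            timesteps.append(timestep)
    return timesteps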
- _get_verifier_guess_replacement_proportion(iteration: int) float [source]#
Get the proportion of rollouts to replace the guess with the true label.
For this proportion of the sampled rollouts, we replace the verifier guess with either “Decision: accept” or “Decision: reject” based on the true label.
This value can be annealed over the course of training; a loose illustration of such a schedule is given after the parameter list below.
- Parameters:
iteration (int) – The current iteration number.
- Returns:
proportion (float) – The proportion of rollouts where we replace the guess with the true label.
- Raises:
ValueError – If the annealing type is invalid.
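The available annealing schedules and their hyper-parameters are not listed here; purely as a loose illustration, with hypothetical schedule names, an annealed proportion might be computed like this.

def verifier_guess_replacement_proportion(
    iteration: int,
    total_iterations: int,
    annealing_type: str,
    start_proportion: float,
    end_proportion: float,
) -> float:
    """Illustrative annealing of the replacement proportion (hypothetical schedules)."""
    if annealing_type == "constant":
        return start_proportion
    if annealing_type == "linear":
        # Interpolate linearly from the start to the end proportion over training.
        fraction = iteration / max(total_iterations - 1, 1)
        return start_proportion + fraction * (end_proportion - start_proportion)
    raise ValueError(f"Invalid annealing type: {annealing_type!r}")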
- _initialise_state()[source]#
Initialise the state of the experiment.
This method should be implemented by subclasses to initialise the state of the experiment. This is called at the beginning of training when starting from scratch (i.e. not restoring from a checkpoint).
- _load_rollouts(iterations: int | Iterable[int]) NestedArrayDict [source]#
Load the rollouts from the checkpoint directory.
- _load_state_dict_from_active_run_checkpoint() dict [source]#
Load the experiment state from a checkpoint for the active run, if available.
- Returns:
state_dict (dict) – The state as a dictionary, loaded from W&B
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
- _load_state_dict_from_base_run_checkpoint(version: str = 'latest') dict [source]#
Load the experiment state from a base run checkpoint.
- Parameters:
version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B.
- Returns:
state_dict (dict) – The state as a dictionary, loaded from W&B
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
- _previous_compatible_iterations() Iterable[int] [source]#
Get the previous iterations which are combinable with the current iteration.
The method is used when combining rollouts from different iterations, and returns an iterable of the previous iteration numbers which are able to be combined with the current iteration.
When initial EI iterations are used, an iteration that does MALT combines only rollouts from other MALT iterations, not from EI iterations, because the two kinds of rollouts have different structures and should not be mixed.
- Returns:
previous_iterations (Iterable[int]) – The previous iterations which are combinable with the current iteration.
- _sample_positive_and_negative_examples(partial_rollouts_by_level: list[list[_PartialRolloutNode]])[source]#
Sample positive and negative examples for each node in the tree of responses.
The way this is done depends on the pair_selection_method hyper-parameter, which can be one of the following:
“positive_negative”: We look at each node and check if in its children there is a positive and a negative example.
“interval”: We look at each node and check if in its children there is a pair of nodes whose expected rewards differ by more than a certain threshold. This threshold is interval_threshold_proportion times the difference between the maximum and minimum possible reward for the agent.
If we find a valid pair we randomly sample one and set the ("agents", "is_pair_positive") and ("agents", "is_pair_negative") fields to True in the positive and negative example, respectively. A structural sketch of this marking pass is given after the parameter list below.
- Parameters:
partial_rollouts_by_level (list[list[_PartialRolloutNode]]) – The tree of responses, stratified by level. These are modified in-place, where we add ("agents", "is_pair_positive") and ("agents", "is_pair_negative") fields to the rollouts.
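Structurally, the marking pass might look like the following sketch, where select_pair stands in for whichever rule pair_selection_method chooses, and the node class and field names are hypothetical.

from dataclasses import dataclass, field
from typing import Callable


@dataclass
class _NodeSketch:
    expected_reward: float
    children: list["_NodeSketch"] = field(default_factory=list)
    is_pair_positive: bool = False
    is_pair_negative: bool = False


def mark_preference_pairs(
    levels: list[list[_NodeSketch]],
    select_pair: Callable[[list[float]], tuple[int, int] | None],
) -> None:
    """Mark at most one positive/negative child pair per node, in place (sketch)."""
    for level in levels:
        for node in level:
            if len(node.children) < 2:
                continue
            pair = select_pair([child.expected_reward for child in node.children])
            if pair is None:
                continue
            positive_index, negative_index = pair
            node.children[positive_index].is_pair_positive = True
            node.children[negative_index].is_pair_negative = True

Here, select_pair could be, for example, the positive/negative rule sketched in the class description above.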
- async _sample_rollouts(environment: PureTextEnvironment, iteration: int | Literal['test'], use_tqdm: bool = False, tqdm_desc: str = 'Sampling rollouts') NestedArrayDict [source]#
Sample rollouts in the environment.
We sample environment.num_envs rollouts from the environment. A rollout is a sequence of length max_message_rounds of states in the environment. The sampled rollout nested array dict thus has shape (num_envs, max_message_rounds).
- Parameters:
environment (PureTextEnvironment) – The environment to sample rollouts in.
iteration (int | Literal["test"]) – The iteration number, or “test” if the rollouts are from the test set.
use_tqdm (bool) – Whether to create a tqdm progress bar for the rollouts.
tqdm_desc (str) – The description to use for the tqdm progress bar.
- Returns:
rollouts (NestedArrayDict) – The rollouts in the environment. Has batch size (num_envs, max_message_rounds)
- _sample_rollouts_for_single_environment(environment: PureTextEnvironment, data_batch: NestedArrayDict | None = None) list[NestedArrayDict] [source]#
Sample rollouts for a single environment.
A single environment is associated with a single datapoint. This method samples rollouts from it.
To implement the MALT training scheme, we need to sample multiple responses per timestep from the agents, and generate a tree of responses.
We also do additional processing and compute various statistics for each node in the tree of responses. It’s more efficient and easier to do this now rather than later, because we have access to the full tree structure. While it’s possible to recover this later, it takes a bit of work because the rollouts are stored in arrays.
We compute the expected reward for each agent at each node of the tree by summing up the total reward for all descendants, proceeding from the leaves to the root, and dividing by the number of branches passing through the node. This is stored in the ("agents", "expected_reward") field of the rollouts.
We look at each node and check if among its children there is a valid preference pair. If so, we randomly sample one and set the ("agents", "is_pair_positive") and ("agents", "is_pair_negative") fields to True for the positive and negative example, respectively.
Each node in the response tree gets a unique ID, stored in _node_id, which has shape (max_message_rounds, ). This allows reconstructing the tree of responses later, if required, because if the same node ID appears in two different rollouts, then those points in the message history are the same. We also store the parent node ID in the _parent_node_id field. The first timestep of each rollout has a ‘pseudo-parent’ node ID; this is important because the tree may branch immediately at the first timestep. A sketch of how the tree can be rebuilt from these IDs is given after the return value below.
Shapes
The following are the shapes of the additional fields added to each rollout.
(“agents”, “expected_reward”): “round agent”
(“agents”, “is_pair_positive”): “round agent”
(“agents”, “is_pair_negative”): “round agent”
“_node_id”: “round”
“_parent_node_id”: “round”
- Parameters:
environment (PureTextEnvironment) – The environment to sample rollouts in.
data_batch (NestedArrayDict, optional) – The data batch to use for the rollout. If None, the data batch will be sampled from the dataset.
- Returns:
sampled_rollouts (list[NestedArrayDict]) – The list of sampled rollouts, each of which has batch size (max_message_rounds, ).
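For instance, the parent/child structure could later be recovered from the stored IDs roughly as follows; how the NestedArrayDict is indexed here is an assumption of this sketch.

from collections import defaultdict


def rebuild_tree_edges(rollouts) -> dict[int, set[int]]:
    """Recover child node IDs per parent node ID from saved rollouts (sketch).

    Assumes rollouts["_node_id"] and rollouts["_parent_node_id"] are integer
    arrays of shape (num_rollouts, max_message_rounds).
    """
    children: dict[int, set[int]] = defaultdict(set)
    node_ids = rollouts["_node_id"]
    parent_ids = rollouts["_parent_node_id"]
    for rollout_node_ids, rollout_parent_ids in zip(node_ids, parent_ids):
        for parent_id, node_id in zip(rollout_parent_ids, rollout_node_ids):
            # Identical node IDs across rollouts mean identical message
            # history, so duplicate edges collapse naturally in the set.
            children[int(parent_id)].add(int(node_id))
    return children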
- _save_rollouts(rollouts: NestedArrayDict, environment: PureTextEnvironment, iteration: int | None = None)[source]#
Save the rollouts to the checkpoint directory.
- Parameters:
rollouts (NestedArrayDict) – The rollouts to save.
environment (PureTextEnvironment) – The environment the rollouts were sampled in.
iteration (int, optional) – The iteration number. If not provided, the current iteration number is used.
- _select_rollouts_for_fine_tuning(rollouts: NestedArrayDict, agent_name: str) NestedArrayDict [source]#
Select rollouts to fine-tune on, based on the reward.
- Parameters:
rollouts (NestedArrayDict) – The rollouts to select from.
agent_name (str) – The name of the agent for which to select the rollouts.
- Returns:
selected_rollouts (NestedArrayDict) – The selected rollouts.
- async _stage_await_fine_tune_jobs()[source]#
Training stage: await the completion of the fine-tune jobs.
- Raises:
ExceptionGroup[FineTuneJobError] – If any of the fine-tune jobs fail or are cancelled. Note that since we use a task group to await the fine-tune jobs, the exceptions are raised as an ExceptionGroup. This can be caught using an except* statement. If exception_group is caught, then exception_group.exceptions will contain the individual exceptions for each fine-tune job that failed.
Example
>>> try:
>>>     await trainer._stage_await_fine_tune_jobs()
>>> except* FineTuneJobError as exception_group:
>>>     for exception in exception_group.exceptions:
>>>         logger.error(exception)
- _stage_create_fine_tune_jobs(rollouts: NestedArrayDict, only_failed: bool = False)[source]#
Training stage: create fine-tune jobs for each agent.
- Parameters:
rollouts (NestedArrayDict, optional) – The rollouts sampled in this iteration.
only_failed (bool, default=False) – Whether to only create fine-tune jobs for shared model groups whose previous fine-tune job failed or was cancelled. If False, fine-tune jobs are created for all shared model groups.
- _stage_log_stats(rollouts: NestedArrayDict)[source]#
Training stage: log the statistics of the rollouts.
- Parameters:
rollouts (NestedArrayDict) – The rollouts sampled in this iteration.
- async _stage_sample_rollouts() NestedArrayDict [source]#
Training stage: sample rollouts from the training environment.
- Returns:
rollouts (NestedArrayDict) – The sampled rollouts.
- classmethod get_checkpoint_base_dir_from_run_id(run_id: str) Path [source]#
Get the base directory for a checkpoint from a run ID.
- Parameters:
run_id (str) – The run ID.
- Returns:
checkpoint_base_dir (Path) – The path to the base directory for the checkpoint.
- get_total_num_iterations() int [source]#
Get the total number of iterations that the trainer will run for.
This is the sum of the number of iterations declared by methods decorated with attach_progress_bar.
- Returns:
total_iterations (int) – The total number of iterations.
- load_and_set_state_from_checkpoint(from_base_run: bool = False, version: str = 'latest')[source]#
Load and set the experiment state from a checkpoint, if available.
- Parameters:
from_base_run (bool, default=False) – Whether to load the checkpoint artifact from the base run. By default, the artifact is loaded from the current active run.
version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B. This must be “latest” if from_base_run is False.
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
- run_analysers(analysers: list[str | type[PureTextRolloutAnalyser]], model_name: str, *, overwrite=False, use_tqdm=True, dry_run=False)[source]#
Run the given analysers on the rollouts of the experiment.
This method can only be called after the experiment has finished.
- Parameters:
analysers (list[str | type[PureTextRolloutAnalyser]]) – The analysers to run. Either the name of the analyser or the analyser class itself.
model_name (str) – The name of the model to use for the analysis.
overwrite (bool, default=False) – Whether to overwrite the existing analysis files, if they exist.
use_tqdm (bool, default=True) – Whether to create a progress bar for the analysis.
dry_run (bool, default=False) – Whether to do a dry run using a dummy API, not saving the results.
- save_checkpoint(log: bool = True)[source]#
Save the state of the experiment to a checkpoint.
- Parameters:
log (bool, default=True) – Whether to log the checkpointing.
- train()[source]#
Train the agents in the environment.
Runs the training loop for the specified number of iterations. The training loop consists of the following stages:
Sample rollouts from the training environment.
Log the statistics of the rollouts.
Run the test loop during training.
Create fine-tune jobs for each agent.
Await the completion of the fine-tune jobs.
The training loop can be resumed from a previous checkpoint. If the training loop is resumed, the state of the experiment is loaded from the checkpoint, and the training loop is resumed from the last stage.
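In terms of the public interface, a run might look roughly like the following. How hyper_params, scenario_instance, and settings are constructed, and where CheckPointNotFoundError is imported from, are outside the scope of this page; whether train() performs the checkpoint restoration itself depends on the configuration, so the explicit call here is purely illustrative.

trainer = PureTextMaltTrainer(hyper_params, scenario_instance, settings)

# Resume from a saved checkpoint if one exists for this run; otherwise start fresh.
try:
    trainer.load_and_set_state_from_checkpoint()
except CheckPointNotFoundError:
    pass

trainer.train()
trainer.save_checkpoint()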