nip.trainers.malt_pure_text.PureTextMaltTrainer#
- class nip.trainers.malt_pure_text.PureTextMaltTrainer(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#
Multi-Agent LLM Training (MALT) for text-based environments that only use APIs.
In the MALT protocol [MSD+24], we sample multiple responses per timestep from the agents. This means that for each datapoint we have a tree of responses. For each agent A, at each decision point for A we look at the expected reward for A for each of the responses. We threshold this expected reward to get a binary classification label for each response. We select good-bad pairs of these, and train using Direct Preference Optimization [RSM+23]. (A sketch of the pair construction follows the parameter list below.)
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
scenario_instance (ScenarioInstance) – The components of the experiment.
settings (ExperimentSettings) – The instance-specific settings of the experiment, like device, logging, etc.
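The following is a minimal sketch of the good-bad pair construction described above. The threshold value, data layout and function name are assumptions for illustration, not the trainer's actual internals:

    import random

    # Hypothetical illustration of MALT preference-pair construction. Each
    # decision point holds the agent's sampled responses, each annotated with
    # an expected reward estimated from its subtree of continuations.
    REWARD_THRESHOLD = 0.5  # assumed mid-point; the trainer estimates this

    def build_dpo_pairs(decision_points: list[list[dict]]) -> list[tuple[str, str]]:
        """Pair one above-threshold response with one below-threshold response
        at each decision point, giving (chosen, rejected) pairs for DPO."""
        pairs = []
        for responses in decision_points:
            positives = [r for r in responses if r["expected_reward"] >= REWARD_THRESHOLD]
            negatives = [r for r in responses if r["expected_reward"] < REWARD_THRESHOLD]
            if positives and negatives:
                pairs.append(
                    (random.choice(positives)["text"], random.choice(negatives)["text"])
                )
        return pairs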
Methods Summary

__init__(hyper_params, scenario_instance, ...)

_add_file_to_wandb_artifact(artifact_name, ...)
    Add a file to a W&B artifact, creating the artifact if it doesn't exist.

_check_if_run_test_loop()
    Check if the test loop should be run in the current iteration.

_extract_transcripts_and_prompts(rollouts, ...)
    Extract the raw and processed transcripts, and prompts, from the rollouts.

_get_fine_tune_job_name(shared_model_group)
    Get a name for the fine-tune job for the given shared model group.

_get_log_stats(rollouts, *[, train])
    Get the statistics to log for the given rollouts.

_get_metadata()
    Get metadata relevant to the current experiment, for saving the checkpoint.

_get_verifier_guess_replacement_proportion(iteration)
    Get the proportion of rollouts to replace the guess with the true label.

_initialise_state()
    Initialise the state of the experiment.

_load_rollouts(iterations)
    Load the rollouts from the checkpoint directory.

_load_state_dict_from_active_run_checkpoint()
    Load the experiment state from a checkpoint for the active run, if available.

_load_state_dict_from_base_run_checkpoint([version])
    Load the experiment state from a base run checkpoint.

_sample_rollouts(environment, iteration[, ...])
    Sample rollouts in the environment.

_sample_rollouts_for_single_environment(args)
    Sample rollouts for a single environment.

_save_rollouts(rollouts, environment[, ...])
    Save the rollouts to the checkpoint directory.

    Training stage: await the completion of the fine-tune jobs.

_stage_create_fine_tune_jobs(rollouts)
    Training stage: create fine-tune jobs for each agent.

_stage_log_stats(rollouts)
    Training stage: log the statistics of the rollouts.

    Training stage: run the test loop.

_stage_sample_rollouts()
    Training stage: sample rollouts from the training environment.

get_checkpoint_base_dir_from_run_id(run_id)
    Get the base directory for a checkpoint from a run ID.

get_total_num_iterations()
    Get the total number of iterations that the trainer will run for.

load_and_set_state_from_checkpoint([from_base_run, version])
    Load and set the experiment state from a checkpoint, if available.

run_analysers(analysers, model_name, *[, ...])
    Run the given analysers on the rollouts of the experiment.

save_checkpoint([log])
    Save the state of the experiment to a checkpoint.

train()
    Train the agents in the environment.
Attributes
agent_names
The names of the agents in the scenario.
agent_wholes
The 'whole' part of each agent.
checkpoint_analysis_dir
The directory to save the rollout analysis to.
checkpoint_base_dir
The path to the directory containing the checkpoint.
checkpoint_metadata_path
The path to the checkpoint metadata file.
checkpoint_params_path
The path to the parameters file for the checkpoint.
checkpoint_rollouts_dir
The directory to save the rollouts to.
checkpoint_state_path
The path to the checkpoint state file.
combined_agent
The agents combined into a single operator.
max_message_rounds
The maximum number of message rounds in the protocol.
num_agents
The number of agents in the scenario.
processed_transcripts_dir
The directory to save the processed transcripts to.
prompts_dir
The directory to save the prompts to.
protocol_handler
The protocol handler for the experiment.
raw_transcripts_dir
The directory to save the raw transcripts to.
shared_model_groups
The agents grouped by having a shared model.
state
The state of the experiment.
test_environment
The test environment.
train_environment
The training environment.
Methods
- __init__(hyper_params: HyperParameters, scenario_instance: ScenarioInstance, settings: ExperimentSettings)[source]#
- _add_file_to_wandb_artifact(artifact_name: str, artifact_type: str, file_path: Path)[source]#
Add a file to a W&B artifact, creating the artifact if it doesn’t exist.
If the artifact already exists, we add the file to the existing artifact, creating a new version.
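For orientation, the usual wandb pattern for this behaviour looks roughly as follows (a sketch; the artifact and file names here are invented, and the trainer's actual names and types may differ):

    import wandb

    # Sketch: adding a file to a (possibly pre-existing) W&B artifact.
    # Logging an artifact under an existing name creates a new version
    # (v0, v1, ...) rather than overwriting the old one.
    run = wandb.init(project="my-project")          # hypothetical project
    artifact = wandb.Artifact(name="checkpoint", type="model")
    artifact.add_file("checkpoints/state.pt")       # hypothetical file path
    run.log_artifact(artifact)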
- _check_if_run_test_loop() bool [source]#
Check if the test loop should be run in the current iteration.
- Returns:
run_test_loop (bool) – Whether the test loop should be run.
- _extract_transcripts_and_prompts(rollouts: NestedArrayDict, environment: PureTextEnvironment) tuple[list[dict], list[dict], list[list[dict[str, list[PromptMessage]]]]] [source]#
Extract the raw and processed transcripts, and prompts, from the rollouts.
The raw transcript is the sequence of outputs generated by the models, with one message per agent, while the processed transcript is the result of processing these and extracting one message per channel. Note that in the raw transcripts the messages are therefore keyed by agent, while in the processed transcripts they are keyed by channel.
The transcripts have variable length: if a round contains no messages from any agent, we treat that round as the end of the transcript.
- Parameters:
rollouts (NestedArrayDict) –
The rollouts to extract the transcripts from. A NestedArrayDict with keys:
“message_history” (batch round round channel) : The message history for each rollout. In each timestep this gives the history of all messages generated up to that point.
“message_agent_id” (batch round round channel) : The ID of the agent that generated each message in the message history.
(“agents”, “raw_message”) (batch round agent) : The raw message generated by each model in each timestep.
(“agents”, “prompt”) (batch round agent message field) : The prompt used to generate the message for each agent in each timestep.
(“agents”, “decision”) (batch round agent) : The decision made by each agent in each timestep.
(“agents”, “reward”) (batch round agent) : The reward received by each agent in each timestep.
The nested array dict also contains keys which specify the datapoint for each rollout, as extracted by environment.get_datapoint_from_env_state_as_dict.
environment (PureTextEnvironment) – The environment the rollouts were sampled in.
- Returns:
raw_transcripts (list[dict]) – The raw transcripts. This is a list of transcripts, where each transcript is a dictionary containing metadata and a “transcript” key. The value at “transcript” is a list of dictionaries whose keys are the agent names and whose values are the messages generated by the agents.
processed_transcripts (list[dict]) – The processed transcripts. This is a list of transcripts, where each transcript is a dictionary containing metadata and a “transcript” key. The value at “transcript” is a list of dictionaries whose keys are f"{active_agent_name}@{channel_name}" and whose values are the messages in each channel. (Both transcript formats are illustrated below.)
prompts (list[list[dict[str, list[PromptMessage]]]]) – The prompts used to generate the messages at each timestep. This is a list containing, for each batch item, a list of dictionaries, one for each round. Each dictionary has the agent names as keys and the prompts used by the agents as values. Each prompt is a list of dictionaries whose type is specified by the PromptMessage class.
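To make the two return formats concrete, here is a hypothetical pair of entries (the agent names, channel name and metadata fields are invented):

    # Hypothetical raw transcript entry: one message per agent per round.
    raw_transcript = {
        "datapoint_id": 0,  # metadata fields depend on the environment
        "transcript": [
            {"prover": "The answer is 42 because ...",
             "verifier": "Decision: accept"},
        ],
    }

    # Hypothetical processed transcript entry: one message per
    # "{active_agent_name}@{channel_name}" key per round.
    processed_transcript = {
        "datapoint_id": 0,
        "transcript": [
            {"prover@main": "The answer is 42 because ...",
             "verifier@main": "Decision: accept"},
        ],
    }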
- _get_fine_tune_job_name(shared_model_group: PureTextSharedModelGroup) str [source]#
Get a name for the fine-tune job for the given shared model group.
This name is generated from the run id, the iteration number, and the shared model group name, and is used to make the job more easily identifiable.
- Parameters:
shared_model_group (PureTextSharedModelGroup) – The shared model group to create the fine-tune job for.
- Returns:
job_name (str) – The name of the fine-tune job.
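A sketch of what such a naming scheme might look like (the exact format string is an assumption, not the trainer's actual scheme):

    # Hypothetical naming scheme combining the three components named above.
    def fine_tune_job_name(run_id: str, iteration: int, group_name: str) -> str:
        return f"{run_id}_iter{iteration}_{group_name}"

    # e.g. fine_tune_job_name("run_42", 3, "provers") -> "run_42_iter3_provers"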
- _get_log_stats(rollouts: NestedArrayDict, *, train=True) dict [source]#
Get the statistics to log for the given rollouts.
This method extends the base class method to include the MALT-specific statistics.
- Parameters:
rollouts (NestedArrayDict) – The rollouts to get the statistics for.
train (bool, default=True) – Whether the rollouts are from the training environment.
- Returns:
stats (dict) – The statistics to log.
- _get_metadata() dict [source]#
Get metadata relevant to the current experiment, for saving the checkpoint.
- Returns:
metadata (dict) – The metadata to record
- _get_verifier_guess_replacement_proportion(iteration: int) float [source]#
Get the proportion of rollouts to replace the guess with the true label.
For this proportion of the sampled rollouts, we replace the verifier guess with either “Decision: accept” or “Decision: reject” based on the true label.
This value can be annealed over the course of the training.
- Parameters:
iteration (int) – The current iteration number.
- Returns:
proportion (float) – The proportion of rollouts where we replace the guess with the true label.
- Raises:
ValueError – If the annealing type is invalid.
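As an illustration of the annealing described above, a linear schedule might look as follows (the schedule name, endpoints and defaults are assumptions; the trainer reads its own configuration from hyper_params):

    def replacement_proportion(
        iteration: int,
        start: float = 0.5,          # hypothetical initial proportion
        end: float = 0.0,            # hypothetical final proportion
        num_iterations: int = 100,
        annealing_type: str = "linear",
    ) -> float:
        """Anneal the verifier-guess replacement proportion over training."""
        if annealing_type == "linear":
            fraction = min(iteration / max(num_iterations - 1, 1), 1.0)
            return start + (end - start) * fraction
        raise ValueError(f"Invalid annealing type: {annealing_type!r}")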
- _initialise_state()[source]#
Initialise the state of the experiment.
This method should be implemented by subclasses to initialise the state of the experiment. This is called at the beginning of training when starting from scratch (i.e. not restoring from a checkpoint).
- _load_rollouts(iterations: int | Iterable[int]) NestedArrayDict [source]#
Load the rollouts from the checkpoint directory.
- _load_state_dict_from_active_run_checkpoint() dict [source]#
Load the experiment state from a checkpoint for the active run, if available.
- Returns:
state_dict (dict) – The state as a dictionary, loaded from W&B
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
- _load_state_dict_from_base_run_checkpoint(version: str = 'latest') dict [source]#
Load the experiment state from a base run checkpoint.
- Parameters:
version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B.
- Returns:
state_dict (dict) – The state as a dictionary, loaded from W&B
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
- _sample_rollouts(environment: PureTextEnvironment, iteration: int | Literal['test'], use_tqdm: bool = False, tqdm_desc: str = 'Sampling rollouts') NestedArrayDict [source]#
Sample rollouts in the environment.
We sample environment.num_envs rollouts from the environment. A rollout is a sequence of length max_message_rounds of states in the environment. The sampled rollout nested array dict thus has shape (num_envs, max_message_rounds).
- Parameters:
environment (PureTextEnvironment) – The environment to sample rollouts in.
iteration (int | Literal["test"]) – The iteration number, or “test” if the rollouts are from the test set.
use_tqdm (bool) – Whether to create a tqdm progress bar for the rollouts.
tqdm_desc (str) – The description to use for the tqdm progress bar.
- Returns:
rollouts (NestedArrayDict) – The rollouts in the environment. Has batch size (num_envs, max_message_rounds)
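For concreteness, the indexing this shape implies (assuming NestedArrayDict supports tensordict-style key and integer indexing, which is an assumption about its API):

    # Sketch: rollouts are indexed as (environment, round), so the reward of
    # a given agent in a given round of a given environment would be:
    def reward_at(rollouts, env: int, round_: int, agent: int) -> float:
        return rollouts[("agents", "reward")][env, round_, agent]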
- static _sample_rollouts_for_single_environment(args: tuple[HyperParameters, ProtocolHandler, PureTextEnvironment, PureTextCombinedWhole, NestedArrayDict | None]) list[NestedArrayDict] [source]#
Sample rollouts for a single environment.
A single environment is associated with a single datapoint. This method samples rollouts from it.
To implement the MALT training scheme, we need to sample multiple responses per timestep from the agents, and generate a tree of responses.
We also do additional processing and compute various statistics for each node in the tree of responses. It's more efficient and easier to do this now rather than later, because we have access to the full tree structure. While it's possible to recover this later, it takes a bit of work because the rollouts are stored in arrays.
1. We compute the expected reward for each agent at each node of the tree by summing up the total reward for all descendants, proceeding from the leaves to the root, and dividing by the number of branches passing through the node. This is stored in the ("agents", "expected_reward") field of the rollouts. (A sketch of this computation follows the Returns section below.)
2. The expected reward is thresholded using an estimate of the reward mid-points to get a binary classification label for each response, into ‘positive’ and ‘negative’ examples. This is stored in ("agents", "is_positive_example").
3. We look at each node and check whether its children contain both a positive and a negative example. If so, we set the ("agents", "has_positive_and_negative") field to True. In this case, we randomly sample a positive and a negative example from the children and set the ("agents", "sampled_positive_example") and ("agents", "sampled_negative_example") fields to the corresponding node IDs. Otherwise these fields are set to -1.
4. Each node in the response tree gets a unique ID, stored in "_node_id", which has shape (max_message_rounds,). This allows reconstructing the tree of responses later, if required, because if the same node ID appears in two different rollouts, then those points in the message history are the same.
Shapes
The following are the shapes of the additional fields added to each rollout.
(“agents”, “expected_reward”): “round agent”
(“agents”, “is_positive_example”): “round agent”
(“agents”, “has_positive_and_negative”): “round agent”
(“agents”, “sampled_positive_example”): “round agent”
(“agents”, “sampled_negative_example”): “round agent”
“_node_id”: “round”
Notes
This function is intended to be applied by a pool of workers. As such, it must be a static function and take all required trainer attributes as arguments.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
protocol_handler (ProtocolHandler) – The interaction protocol handler for the experiment.
environment (PureTextEnvironment) – The environment to sample a rollout in.
combined_agent (PureTextCombinedWhole) – The combined agent to use for the rollout.
data_batch (NestedArrayDict, optional) – The data batch to use for the rollout. If None, the data batch will be sampled from the dataset.
- Returns:
sampled_rollouts (list[NestedArrayDict]) – The list of sampled rollouts, each of which has batch size (max_message_rounds,).
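The following is a minimal sketch of step 1 above, computing each node's expected reward from the leaves up. The tree class and field names are hypothetical; in the trainer the tree is stored implicitly in rollout arrays:

    from dataclasses import dataclass, field

    @dataclass
    class ResponseNode:
        reward: float
        children: list["ResponseNode"] = field(default_factory=list)
        expected_reward: float = 0.0

    def compute_expected_rewards(node: ResponseNode) -> tuple[float, int]:
        """Return (total reward summed over all branches through `node`,
        number of branches). Fills in `expected_reward` at every node."""
        if not node.children:
            node.expected_reward = node.reward
            return node.reward, 1
        total, branches = 0.0, 0
        for child in node.children:
            child_total, child_branches = compute_expected_rewards(child)
            # Every branch through the child also passes through this node,
            # so this node's own reward is counted once per such branch.
            total += child_total + node.reward * child_branches
            branches += child_branches
        node.expected_reward = total / branches
        return total, branches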
- _save_rollouts(rollouts: NestedArrayDict, environment: PureTextEnvironment, iteration: int | None = None)[source]#
Save the rollouts to the checkpoint directory.
- Parameters:
rollouts (NestedArrayDict) – The rollouts to save.
environment (PureTextEnvironment) – The environment the rollouts were sampled in.
iteration (int, optional) – The iteration number. If not provided, the current iteration number is used.
- _stage_create_fine_tune_jobs(rollouts: NestedArrayDict)[source]#
Training stage: create fine-tune jobs for each agent.
- Parameters:
rollouts (NestedArrayDict) – The rollouts sampled in this iteration.
- _stage_log_stats(rollouts: NestedArrayDict)[source]#
Training stage: log the statistics of the rollouts.
- Parameters:
rollouts (NestedArrayDict) – The rollouts sampled in this iteration.
- _stage_sample_rollouts() NestedArrayDict [source]#
Training stage: sample rollouts from the training environment.
- Returns:
rollouts (NestedArrayDict) – The sampled rollouts.
- classmethod get_checkpoint_base_dir_from_run_id(run_id: str) Path [source]#
Get the base directory for a checkpoint from a run ID.
- Parameters:
run_id (str) – The run ID.
- Returns:
checkpoint_base_dir (Path) – The path to the base directory for the checkpoint.
- get_total_num_iterations() int [source]#
Get the total number of iterations that the trainer will run for.
This is the sum of the number of iterations declared by methods decorated with attach_progress_bar. (The mechanism is sketched below.)
- Returns:
total_iterations (int) – The total number of iterations.
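A sketch of the idea (the decorator below is a hypothetical stand-in; the real attach_progress_bar lives in the nip codebase and may work differently):

    # Hypothetical registry: each decorated stage method declares how many
    # iterations it contributes, and the total is simply their sum.
    _iteration_counts: dict[str, int] = {}

    def attach_progress_bar(num_iterations: int):
        def decorator(func):
            _iteration_counts[func.__name__] = num_iterations
            return func
        return decorator

    def get_total_num_iterations() -> int:
        return sum(_iteration_counts.values())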
- load_and_set_state_from_checkpoint(from_base_run: bool = False, version: str = 'latest')[source]#
Load and set the experiment state from a checkpoint, if available.
- Parameters:
from_base_run (bool, default=False) – Whether to load the checkpoint artifact from the base run. By default, the artifact is loaded from the current active run.
version (str, default="latest") – The version of the checkpoint to load. If “latest”, the latest checkpoint is loaded. Otherwise, the version should be a string representing the version of the checkpoint to load from W&B. This must be “latest” if from_base_run is False.
- Raises:
CheckPointNotFoundError – If the checkpoint file is not found.
- run_analysers(analysers: list[str | type[PureTextRolloutAnalyser]], model_name: str, *, overwrite=False, use_tqdm=True, dry_run=False)[source]#
Run the given analysers on the rollouts of the experiment.
This method can only be called after the experiment has finished.
- Parameters:
analysers (list[str | type[PureTextRolloutAnalyser]]) – The analysers to run. Either the name of the analyser or the analyser class itself.
model_name (str) – The name of the model to use for the analysis.
overwrite (bool, default=False) – Whether to overwrite the existing analysis files, if they exist.
use_tqdm (bool, default=True) – Whether to create a progress bar for the analysis.
dry_run (bool, default=False) – Whether to do a dry run using a dummy API, not saving the results.
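A hypothetical invocation (the analyser name and model name are invented for illustration):

    # Run a single analyser over the saved rollouts, without saving results.
    trainer.run_analysers(
        ["rollout_quality"],     # or pass the analyser class itself
        model_name="gpt-4o-mini",
        overwrite=False,
        dry_run=True,            # uses a dummy API; nothing is written
    )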
- save_checkpoint(log: bool = True)[source]#
Save the state of the experiment to a checkpoint.
- Parameters:
log (bool, default=True) – Whether to log the checkpointing.
- train()[source]#
Train the agents in the environment.
Runs the training loop for the specified number of iterations. The training loop consists of the following stages:
1. Sample rollouts from the training environment.
2. Log the statistics of the rollouts.
3. Run the test loop during training.
4. Create fine-tune jobs for each agent.
5. Await the completion of the fine-tune jobs.
The training loop can be resumed from a previous checkpoint. If the training loop is resumed, the state of the experiment is loaded from the checkpoint, and the training loop is resumed from the last stage.
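For orientation, one pass through these stages might look as follows (a sketch only: stage bookkeeping and resumption logic are elided, and the test-loop and job-awaiting stages are left as comments since their method names are not part of this page):

    def train_sketch(trainer, num_iterations: int, start_iteration: int = 0):
        # Resuming from a checkpoint amounts to starting at a later iteration
        # (and, in the real trainer, at the stage recorded in the checkpoint).
        for iteration in range(start_iteration, num_iterations):
            rollouts = trainer._stage_sample_rollouts()
            trainer._stage_log_stats(rollouts)
            if trainer._check_if_run_test_loop():
                pass  # run the test loop stage
            trainer._stage_create_fine_tune_jobs(rollouts)
            # ... await the completion of the fine-tune jobs ...
            trainer.save_checkpoint()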