nip.trainers.malt_pure_text


Multi-Agent LLM Training (MALT) for text-based environments that only use APIs.

In the MALT protocol [MSD+24], we sample multiple responses per timestep from the agents, so each datapoint yields a tree of responses. For each agent A, at every decision point belonging to A, we compute A’s expected reward for each of the sampled responses. We then select preference pairs from these responses and train using Direct Preference Optimization [RSM+23]. How pairs are selected is determined by the hyper_params.pure_text_malt.pair_selection_method parameter, which can be one of the following:

  • “positive_negative”: Selects a response whose expected reward for the agent is above a threshold (by default, the mid-point of the agent’s reward range) and a response whose expected reward is below that threshold.

  • “interval”: Selects a pair of responses where the difference in expected reward is above a certain threshold. This threshold is computed as interval_threshold_proportion times the difference between the maximum and minimum possible reward for the agent.
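The two selection rules can be sketched as follows. This is a minimal illustration, not the nip implementation. The function names `select_positive_negative` and `select_interval` are hypothetical, and the sketch assumes that one agent's expected rewards at one decision point are already available as a list of numbers. Two further details are assumptions: which responses `positive_negative` picks from either side of the threshold, and whether `interval` keeps one qualifying pair or several.

```python
from itertools import combinations
from typing import Optional, Sequence


def select_positive_negative(
    expected_rewards: Sequence[float],
    min_reward: float,
    max_reward: float,
    threshold: Optional[float] = None,
) -> Optional[tuple[int, int]]:
    """Return (chosen, rejected) indices, or None if no pair straddles the threshold."""
    if threshold is None:
        # Default threshold: the mid-point of the agent's reward range
        threshold = (min_reward + max_reward) / 2
    above = [i for i, r in enumerate(expected_rewards) if r > threshold]
    below = [i for i, r in enumerate(expected_rewards) if r < threshold]
    if not above or not below:
        return None
    # Assumption: take the best response above and the worst response below
    chosen = max(above, key=lambda i: expected_rewards[i])
    rejected = min(below, key=lambda i: expected_rewards[i])
    return chosen, rejected


def select_interval(
    expected_rewards: Sequence[float],
    min_reward: float,
    max_reward: float,
    interval_threshold_proportion: float,
) -> list[tuple[int, int]]:
    """Return every (chosen, rejected) pair whose reward gap exceeds the threshold."""
    threshold = interval_threshold_proportion * (max_reward - min_reward)
    pairs = []
    for i, j in combinations(range(len(expected_rewards)), 2):
        gap = expected_rewards[i] - expected_rewards[j]
        if abs(gap) > threshold:
            # Order each pair so the higher-reward response comes first
            pairs.append((i, j) if gap > 0 else (j, i))
    return pairs
```

For example, consider expected rewards [9, 2, 6] on a 0 to 10 scale. `select_positive_negative` returns (0, 1). `select_interval` with a proportion of 0.35, which gives a threshold of 3.5, returns [(0, 1), (2, 1)].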

It is also possible to do some rounds of Expert Iteration (EI) before doing MALT. The PureTextMaltTrainer class inherits from the PureTextEiTrainer class, which implements the EI protocol, and can run EI for the number of iterations specified by the hyper_params.pure_text_malt.num_initial_ei_iterations parameter.
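The warm-up schedule can be pictured with two hypothetical stand-in classes. `EiTrainerSketch` and `MaltTrainerSketch` are not part of nip. The sketch shows only the control flow that the PureTextMaltTrainer/PureTextEiTrainer inheritance implies. The subclass defers to its parent's EI step for the first `num_initial_ei_iterations` iterations and switches to MALT afterwards. The string return values are placeholders for real training steps.

```python
class EiTrainerSketch:
    """Hypothetical stand-in for an expert-iteration trainer."""

    def run_iteration(self, iteration: int) -> str:
        # Placeholder for an EI step (keep high-reward rollouts, then fine-tune)
        return "ei"


class MaltTrainerSketch(EiTrainerSketch):
    """Hypothetical stand-in showing EI warm-up followed by MALT."""

    def __init__(self, num_initial_ei_iterations: int) -> None:
        self.num_initial_ei_iterations = num_initial_ei_iterations

    def run_iteration(self, iteration: int) -> str:
        # The first num_initial_ei_iterations iterations reuse the parent's EI step
        if iteration < self.num_initial_ei_iterations:
            return super().run_iteration(iteration)
        # Placeholder for a MALT step (build response trees, select pairs, run DPO)
        return "malt"

    def train(self, num_iterations: int) -> list[str]:
        # Record which protocol ran at each iteration
        return [self.run_iteration(i) for i in range(num_iterations)]
```

With `MaltTrainerSketch(2).train(5)`, the schedule is two EI iterations followed by three MALT iterations. Inheriting from the EI trainer lets the subclass reuse the parent's step through `super()` rather than duplicating it.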

Functions

_dispatch_to_trainer(method)

Dispatch a method to the appropriate trainer.

_tree_iter(partial_rollouts_by_level[, ...])

Iterate over the tree of responses, either downwards or upwards.

Classes

PureTextMaltTrainer(hyper_params, ...)

Multi-Agent LLM Training (MALT) for text-based environments that only use APIs.

_PartialRolloutNode(current_env_state, ...)

A node in the tree of responses, which is a partially generated rollout.
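The response tree and its level-wise traversal can be sketched as below. `NodeSketch`, `tree_iter` and `fill_expected_rewards` are hypothetical stand-ins for `_PartialRolloutNode` and `_tree_iter`. The sketch makes three assumptions:

  • nodes are grouped into levels by depth, as the `partial_rollouts_by_level` argument suggests;
  • each node carries a single scalar reward for one agent;
  • a node's expected reward is estimated as the mean over its children.

Iterating upwards visits every child before its parent, which a bottom-up expected-reward pass requires.

```python
from dataclasses import dataclass, field
from typing import Iterator, Optional


@dataclass
class NodeSketch:
    """Hypothetical node: a partially generated rollout in the response tree."""

    reward: Optional[float] = None  # set only on leaves (complete rollouts)
    children: list["NodeSketch"] = field(default_factory=list)
    expected_reward: Optional[float] = None


def tree_iter(
    levels: list[list[NodeSketch]], upwards: bool = False
) -> Iterator[NodeSketch]:
    """Yield nodes level by level: root first by default, leaves first if upwards."""
    for level in (reversed(levels) if upwards else levels):
        yield from level


def fill_expected_rewards(levels: list[list[NodeSketch]]) -> None:
    # The upward pass reaches every child before its parent
    for node in tree_iter(levels, upwards=True):
        if not node.children:
            node.expected_reward = node.reward
        else:
            child_values = [c.expected_reward for c in node.children]
            node.expected_reward = sum(child_values) / len(child_values)


# Example: a root with two responses, each of which has two complete continuations
leaves = [NodeSketch(reward=r) for r in (1.0, 0.0, 1.0, 1.0)]
left = NodeSketch(children=leaves[:2])
right = NodeSketch(children=leaves[2:])
root = NodeSketch(children=[left, right])
levels = [[root], [left, right], leaves]
fill_expected_rewards(levels)
print(left.expected_reward, right.expected_reward, root.expected_reward)  # 0.5 1.0 0.75
```

In the real trainer, each agent would have its own expected reward per node. The two intermediate nodes here are the sibling responses whose values a pair-selection rule would compare.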