nip.trainers.malt_pure_text

Multi-Agent LLM Training (MALT) for text-based environments that only use APIs.

In the MALT protocol [MSD+24], we sample multiple responses per timestep from the agents, so each datapoint yields a tree of responses. For each agent A, at each of A's decision points we compute A's expected reward for each of the sampled responses. We threshold this expected reward to obtain a binary (good or bad) label for each response, select good-bad pairs of responses, and train on these pairs using Direct Preference Optimization [RSM+23].
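For reference, the standard DPO objective from [RSM+23] for a single preference pair can be written out in plain Python. This is a minimal sketch for intuition, not code from `nip`. The function name `dpo_loss` and the default `beta=0.1` are illustrative assumptions, and each log-probability is assumed to be summed over the tokens of a full response.

```python
import math


def dpo_loss(
    policy_logp_chosen: float,
    policy_logp_rejected: float,
    ref_logp_chosen: float,
    ref_logp_rejected: float,
    beta: float = 0.1,
) -> float:
    """DPO loss for one (chosen, rejected) pair of responses.

    Each log-probability is the total log-likelihood of a full response under
    either the policy being trained or the frozen reference model.
    """
    # Implicit reward margin: how much more the policy, relative to the
    # reference model, prefers the chosen response over the rejected one.
    margin = (policy_logp_chosen - ref_logp_chosen) - (
        policy_logp_rejected - ref_logp_rejected
    )
    x = beta * margin
    # -log(sigmoid(x)) = log(1 + exp(-x)), split by sign to avoid overflow.
    if x >= 0:
        return math.log1p(math.exp(-x))
    return -x + math.log1p(math.exp(x))
```

Minimising this loss raises the likelihood of the chosen response relative to the rejected one, measured against the reference model; `beta` plays the role of the KL-penalty coefficient in the objective DPO is derived from.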

Functions

_compute_tree_expected_reward(...)

Compute the expected reward for each agent at each node of the tree.
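One way to compute these quantities is a bottom-up pass. A leaf's expected reward is the reward its completed rollout received, and an internal node's expected reward is the mean over its children. The sketch below assumes that Monte Carlo estimate. The `TreeNode` class and its fields are hypothetical, not taken from `nip`, and the agent names `"prover"` and `"verifier"` in the usage are illustrative.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class TreeNode:
    """Hypothetical node: leaves carry the final reward for each agent."""

    children: list[TreeNode] = field(default_factory=list)
    leaf_rewards: Optional[dict[str, float]] = None  # set only on leaves
    expected_reward: dict[str, float] = field(default_factory=dict)


def compute_expected_reward(node: TreeNode) -> dict[str, float]:
    """Fill in `expected_reward` bottom-up and return it for `node`.

    A leaf's expected reward is its observed reward. An internal node's is the
    mean over its children, which equals the mean over all leaves below it
    when every node at a given depth has the same number of children.
    """
    if not node.children:
        assert node.leaf_rewards is not None, "leaves must carry rewards"
        node.expected_reward = dict(node.leaf_rewards)
        return node.expected_reward
    # Recurse first so every child's estimate is available.
    child_rewards = [compute_expected_reward(child) for child in node.children]
    agents = child_rewards[0].keys()
    node.expected_reward = {
        agent: sum(r[agent] for r in child_rewards) / len(child_rewards)
        for agent in agents
    }
    return node.expected_reward
```

Keeping one value per agent matters because, at a given decision point, only the expected reward of the agent that produced the responses is used to label them.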

_generate_response_tree(hyper_params, ...[, ...])

Generate the tree of responses for a single datapoint.
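A level-by-level expansion is one way to build such a tree. Each level corresponds to one agent's turn, and every node at that level receives `num_responses` sampled continuations. In this sketch `RolloutNode`, `agent_order`, `num_responses`, and the `sample_response` callback (standing in for an API call to the model) are all hypothetical; the real function's parameters beyond `hyper_params` are not shown here.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable


@dataclass
class RolloutNode:
    """Hypothetical node holding the messages generated so far."""

    history: list[tuple[str, str]]  # (agent, message) pairs
    children: list[RolloutNode] = field(default_factory=list)


def generate_response_tree(
    prompt: str,
    agent_order: list[str],
    num_responses: int,
    sample_response: Callable[[str, list[tuple[str, str]], str], str],
) -> RolloutNode:
    """Expand the tree level by level, one agent turn per level.

    Every node at a level gets `num_responses` children, each holding one
    freshly sampled response from the agent whose turn it is.
    """
    root = RolloutNode(history=[])
    frontier = [root]
    for agent in agent_order:
        next_frontier = []
        for node in frontier:
            for _ in range(num_responses):
                message = sample_response(prompt, node.history, agent)
                # Each child copies its parent's history plus the new message.
                child = RolloutNode(history=node.history + [(agent, message)])
                node.children.append(child)
                next_frontier.append(child)
        frontier = next_frontier
    return root
```

With `b` responses per node and `L` turns the tree has `b**L` leaves, so the number of model calls grows exponentially with the number of turns.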

_sample_positive_and_negative_examples(...)

Sample positive and negative examples for each node in the tree of responses.
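At a single decision point, the labelling and pairing steps might look like the following. This is a minimal sketch: the function name, the strict `>` comparison against `threshold`, and pairing by shuffling the positive-negative cross product are assumptions, not necessarily how `nip` samples its examples.

```python
from __future__ import annotations

import random


def sample_preference_pairs(
    responses: list[str],
    expected_rewards: list[float],
    threshold: float,
    max_pairs: int,
    rng: random.Random,
) -> list[tuple[str, str]]:
    """Return up to `max_pairs` (chosen, rejected) pairs for one decision point.

    Responses whose expected reward (for the agent that produced them) exceeds
    `threshold` are labelled positive; the rest are labelled negative.
    """
    positives = [r for r, v in zip(responses, expected_rewards) if v > threshold]
    negatives = [r for r, v in zip(responses, expected_rewards) if v <= threshold]
    # A usable pair needs at least one response on each side of the threshold.
    if not positives or not negatives:
        return []
    all_pairs = [(p, n) for p in positives for n in negatives]
    rng.shuffle(all_pairs)
    return all_pairs[:max_pairs]
```

Decision points where every response lands on the same side of the threshold contribute no pairs, since DPO needs both a preferred and a dispreferred response.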

_tree_iter(partial_rollouts_by_level[, ...])

Iterate over the tree of responses, either downwards or upwards.
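Assuming, as the parameter name `partial_rollouts_by_level` suggests, that the nodes are stored as one list per depth, iteration in either direction reduces to walking the levels forwards or backwards. The name `tree_iter` and the `downwards` flag below are illustrative, not the real signature.

```python
from __future__ import annotations

from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def tree_iter(
    nodes_by_level: Sequence[Sequence[T]], downwards: bool = True
) -> Iterator[tuple[int, T]]:
    """Yield (level, node) pairs, root level first or deepest level first.

    Iterating upwards visits every node's children (which live on the next
    level) before the node itself.
    """
    num_levels = len(nodes_by_level)
    order = range(num_levels) if downwards else range(num_levels - 1, -1, -1)
    for level in order:
        for node in nodes_by_level[level]:
            yield level, node
```

Upward order suits reward propagation, where a node's value depends on its children; downward order suits generation, where a node must exist before its children are sampled.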

Classes

PureTextMaltTrainer(hyper_params, ...)

Multi-Agent LLM Training (MALT) for text-based environments that only use APIs.

_PartialRolloutNode(current_env_state, ...)

A node in the tree of responses, which is a partially generated rollout.
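A node of this kind can be modelled as a small dataclass that links to its parent, so the partial transcript can be recovered by walking back to the root. Apart from `current_env_state`, which appears in the signature above, the fields and methods here (`message`, `parent`, `children`, `add_child`, `transcript`, `is_leaf`) are hypothetical.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass(eq=False)
class PartialRolloutNode:
    """Hypothetical stand-in for a partially generated rollout."""

    current_env_state: Any  # environment state after the messages so far
    message: Optional[str] = None  # response that produced this node; None at root
    parent: Optional[PartialRolloutNode] = field(default=None, repr=False)
    children: list[PartialRolloutNode] = field(default_factory=list)

    def add_child(self, env_state: Any, message: str) -> PartialRolloutNode:
        """Attach and return a child reached by appending `message`."""
        child = PartialRolloutNode(env_state, message, parent=self)
        self.children.append(child)
        return child

    def transcript(self) -> list[str]:
        """Messages from the root down to this node, in order."""
        messages: list[str] = []
        node: Optional[PartialRolloutNode] = self
        while node is not None:
            if node.message is not None:
                messages.append(node.message)
            node = node.parent
        return messages[::-1]

    @property
    def is_leaf(self) -> bool:
        """True while no continuations have been sampled from this node."""
        return not self.children
```

`eq=False` keeps identity-based equality, and `repr=False` on `parent` keeps the printed form from cycling back up the tree.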