nip.trainers.malt_pure_text._compute_tree_expected_reward

nip.trainers.malt_pure_text._compute_tree_expected_reward(partial_rollouts_by_level: list[list[_PartialRolloutNode]], hyper_params: HyperParameters, protocol_handler: ProtocolHandler)

Compute the expected reward for each agent at each node of the tree.

The expected reward is the average reward that an agent receives over all branches passing through a node. This is stored in the ("agents", "expected_reward") field of the rollouts, which are modified in-place.

This is computed by summing up the total reward for all descendants, proceeding from the leaves to the root, and dividing by the number of branches passing through the node.
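
The leaves-to-root pass described above can be illustrated with a minimal sketch. This is not the library's implementation: the Node class, its field names, and the function compute_expected_reward are hypothetical stand-ins, and the sketch assumes that each leaf stores the total reward of its branch as one value per agent.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Node:
    """Hypothetical stand-in for a tree node (not the library's _PartialRolloutNode)."""

    children: list["Node"] = field(default_factory=list)
    # Assumption: each leaf stores the total reward of its branch, one value per agent.
    branch_reward: Optional[list[float]] = None
    reward_sum: list[float] = field(default_factory=list)
    num_branches: int = 0
    expected_reward: list[float] = field(default_factory=list)


def compute_expected_reward(levels: list[list[Node]]) -> None:
    """Fill in expected_reward for every node, in place, from the leaves to the root."""
    # Visit the deepest level first so children are finished before their parents.
    for level in reversed(levels):
        for node in level:
            if not node.children:
                # A leaf closes exactly one branch.
                node.reward_sum = list(node.branch_reward)
                node.num_branches = 1
            else:
                # Branches through a node are the branches through its children.
                node.num_branches = sum(c.num_branches for c in node.children)
                # Add up the children's per-agent reward sums.
                node.reward_sum = [
                    sum(totals) for totals in zip(*(c.reward_sum for c in node.children))
                ]
            node.expected_reward = [s / node.num_branches for s in node.reward_sum]


# Two agents, four branches: the root has two children with two leaves each.
a = Node(children=[Node(branch_reward=[1.0, 0.0]), Node(branch_reward=[0.0, 0.0])])
b = Node(children=[Node(branch_reward=[1.0, 1.0]), Node(branch_reward=[1.0, 1.0])])
root = Node(children=[a, b])
compute_expected_reward([[root], [a, b], a.children + b.children])
# root.expected_reward is [0.75, 0.5]; a.expected_reward is [0.5, 0.0]
```

Keeping a running reward sum and branch count on each node means every node is visited once, rather than re-walking its subtree for each average.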

We also threshold the expected reward to get a binary classification label for each response. This is stored in the ("agents", "is_positive_example") field.
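
As a rough illustration of the thresholding step, the sketch below turns per-agent expected rewards into binary labels. The function name, the threshold argument, and the choice of a non-strict comparison are all assumptions made for the example; the source does not state where the threshold comes from or how ties are handled.

```python
def label_positive_examples(expected_reward: list[float], threshold: float) -> list[bool]:
    """Return one binary label per agent (hypothetical helper, not the library's API)."""
    # Assumption: a response counts as positive when its expected reward
    # is at least the threshold.
    return [reward >= threshold for reward in expected_reward]


labels = label_positive_examples([0.75, 0.5, 0.2], threshold=0.5)
# labels is [True, True, False]
```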

Parameters:
  • partial_rollouts_by_level (list[list[_PartialRolloutNode]]) – The tree of responses, stratified by level. These are modified in-place: the ("agents", "expected_reward") and ("agents", "is_positive_example") fields are added, containing the expected reward and the binary classification label, respectively, for each agent at each node.

  • hyper_params (HyperParameters) – The parameters of the experiment.

  • protocol_handler (ProtocolHandler) – The interaction protocol handler for the experiment.