nip.trainers.malt_pure_text._generate_response_tree


nip.trainers.malt_pure_text._generate_response_tree(hyper_params: HyperParameters, protocol_handler: ProtocolHandler, environment: PureTextEnvironment, combined_agent: PureTextCombinedWhole, data_batch: NestedArrayDict | None = None) → list[list[_PartialRolloutNode]]

Generate the tree of responses for a single datapoint.

This generates a tree of partial rollouts in which the children of each node are its one-step continuations, formed by generating multiple different responses for each agent that is active at that time step. At each step, hyper_params.pure_text_malt.num_responses_per_timestep responses are sampled.
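As an illustration of the branching described above, here is a minimal breadth-first sketch. It assumes each node gets exactly num_responses_per_timestep children per round and never terminates early. `PartialRolloutNode`, `sample_response` and `generate_tree` are hypothetical stand-ins, not the library's `_PartialRolloutNode` or its actual agent calls.

```python
from __future__ import annotations

from dataclasses import dataclass, field


# Hypothetical stand-in for _PartialRolloutNode: it only records the
# messages generated so far and the node's children.
@dataclass
class PartialRolloutNode:
    messages: tuple[str, ...] = ()
    children: list[PartialRolloutNode] = field(default_factory=list)


def sample_response(node: PartialRolloutNode, sample_id: int) -> str:
    """Hypothetical placeholder for querying the active agent(s) once."""
    return f"round{len(node.messages)}-sample{sample_id}"


def generate_tree(
    num_responses_per_timestep: int, max_rounds: int
) -> list[list[PartialRolloutNode]]:
    """Expand the tree level by level, returning nodes grouped by depth."""
    root = PartialRolloutNode()
    levels = [[root]]  # level 0 holds only the empty partial rollout
    for _ in range(max_rounds):
        next_level = []
        for node in levels[-1]:
            # Each sampled response gives one one-step continuation (child).
            for i in range(num_responses_per_timestep):
                child = PartialRolloutNode(
                    messages=node.messages + (sample_response(node, i),)
                )
                node.children.append(child)
                next_level.append(child)
        levels.append(next_level)
    return levels
```

With 3 responses per timestep and 2 rounds, `[len(level) for level in generate_tree(3, 2)]` gives `[1, 3, 9]`. Without early termination, level d holds k**d nodes for branching factor k.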

The output tree is stratified by level, with the root node (the empty partial rollout) at the first level. In general the tree will not be complete, because the environment may terminate before the maximum number of message rounds is reached.
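The effect of early termination on the level-stratified output can be sketched as follows. This assumes that a terminated partial rollout simply receives no children, and that no further levels are created once every branch has ended. `Node`, `generate_levels`, `is_terminated` and `max_tree_size` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Node:  # hypothetical stand-in for _PartialRolloutNode
    depth: int
    path: tuple[int, ...]  # which sample index was taken at each round


def generate_levels(
    k: int, max_rounds: int, is_terminated: Callable[[Node], bool]
) -> list[list[Node]]:
    """Build the tree level by level, skipping children of terminated nodes."""
    levels = [[Node(depth=0, path=())]]
    for depth in range(1, max_rounds + 1):
        # Terminated nodes are leaves: they are not expanded further.
        frontier = [n for n in levels[-1] if not is_terminated(n)]
        if not frontier:
            break  # every branch has ended, so deeper levels are never created
        levels.append([Node(depth, n.path + (i,)) for n in frontier for i in range(k)])
    return levels


def max_tree_size(k: int, max_rounds: int) -> int:
    """Node count of the complete tree: sum of k**d for d = 0..max_rounds."""
    return sum(k**d for d in range(max_rounds + 1))
```

For example, with `k=2`, `max_rounds=3` and `is_terminated=lambda n: bool(n.path) and n.path[-1] == 0`, the level sizes are `[1, 2, 2, 2]` (7 nodes), compared with `max_tree_size(2, 3) == 15` for the complete tree.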

Parameters:
  • hyper_params (HyperParameters) – The parameters of the experiment.

  • protocol_handler (ProtocolHandler) – The interaction protocol handler for the experiment.

  • environment (PureTextEnvironment) – The environment to sample a rollout in.

  • combined_agent (PureTextCombinedWhole) – The combined agent to use for the rollout.

  • data_batch (NestedArrayDict, optional) – The data batch to use for the rollout. If None, a data batch is sampled from the dataset.

Returns:

partial_rollouts_by_level (list[list[_PartialRolloutNode]]) – The tree of responses, stratified by level.