nip.utils.malt_forest.forest.reconstruct_malt_forest

nip.utils.malt_forest.forest.reconstruct_malt_forest(rollouts: NestedArrayDict) → list[MaltTree]

Reconstruct a forest of trees from MALT rollouts.

The MALT [MSD+24] trainer samples a set of trees of responses, which are stored in a flat array. This function reconstructs the trees from that flat array.
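As an illustration only, the general idea of rebuilding trees from a flat array can be sketched with a parent-pointer representation. This sketch assumes each flat entry records the index of its parent, with -1 marking a root. The real fields of `NestedArrayDict` and the structure of `MaltNode`/`MaltTree` in nip may differ. `SketchNode` and `reconstruct_forest` are hypothetical names, not part of the library.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class SketchNode:
    """Hypothetical stand-in for a tree node (not nip's MaltNode)."""

    index: int
    children: list[SketchNode] = field(default_factory=list)


def reconstruct_forest(parent_index: list[int]) -> list[SketchNode]:
    """Rebuild a forest from a flat parent-pointer array (hypothetical layout).

    parent_index[i] is the flat index of node i's parent, or -1 if node i is a root.
    Returns the root nodes in the order they appear in the flat array.
    """
    # One node per flat entry, so flat indices map directly to nodes.
    nodes = [SketchNode(index=i) for i in range(len(parent_index))]
    roots: list[SketchNode] = []
    for i, parent in enumerate(parent_index):
        if parent == -1:
            roots.append(nodes[i])
        elif 0 <= parent < len(nodes):
            # Attach the child under its parent, preserving flat order.
            nodes[parent].children.append(nodes[i])
        else:
            raise ValueError(f"node {i} has invalid parent index {parent}")
    return roots
```

For example, `reconstruct_forest([-1, 0, 0, 1, -1, 4])` yields two roots (flat indices 0 and 4). Root 0 has children 1 and 2, node 1 has child 3, and root 4 has child 5. Building every node first and then linking lets parents appear after their children in the flat array without any special handling.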

Parameters:

rollouts (NestedArrayDict) – The rollouts to reconstruct the trees from.

Returns:

malt_forest (list[MaltNode]) – A list of MaltNode objects representing the root nodes of the trees in the forest.