nip.scenario_base.agents.PureTextSharedModelGroup

class nip.scenario_base.agents.PureTextSharedModelGroup(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, agent_wholes: Iterable[PureTextWholeAgent], group_name: str)

A class representing a group of pure text agents which share the same model.

The shared model is fine-tuned on the data from all agents in the group.

Parameters:
  • hyper_params (HyperParameters) – The parameters of the experiment.

  • settings (ExperimentSettings) – The settings of the experiment.

  • protocol_handler (ProtocolHandler) – The protocol handler for the experiment.

  • agent_wholes (Iterable[PureTextWholeAgent]) – The agents in the shared model group.

  • group_name (str) – The name of the shared model group.

Methods Summary

__init__(hyper_params, settings, ...)

agent_ids_and_names()

Get an iterable of agent IDs and names.

create_dpo_fine_tune_job(...[, job_name])

Create a DPO fine-tune job for the agent group given sampled timesteps.

create_supervised_fine_tune_job(...[, ...])

Create a supervised fine-tune job for the agent group given sampled rollouts.

get_fine_tune_job_error_repr()

Get a string representation of the error for the fine-tune job.

get_fine_tune_job_status()

Get the status of the fine-tune job.

get_state()

Get the state of the shared model group.

get_state_dict()

Get the state of the shared model group as a dict.

set_state(checkpoint)

Set the state of the shared model group from a checkpoint.

switch_to_next_model()

Switch to the next model after fine-tuning.

Attributes

max_message_rounds

The maximum number of message rounds in the protocol.

model_name

The current model name, which may be the base model or a fine-tuned model.

Methods

__init__(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, agent_wholes: Iterable[PureTextWholeAgent], group_name: str)

agent_ids_and_names() → Iterable[tuple[int, str]]

Get an iterable of agent IDs and names.

Yields:
  • agent_id (int) – The ID of the agent.

  • agent_name (str) – The name of the agent.
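As a rough illustration of the yielded pairs, the sketch below iterates over a group and builds an ID-to-name lookup. `StubAgent` and `StubGroup` are hypothetical stand-ins written for this example, not classes from the library, and the `agent_id` and `agent_name` attribute names are assumptions.

```python
from typing import Iterable, Iterator


class StubAgent:
    """Hypothetical stand-in for PureTextWholeAgent (not the real class)."""

    def __init__(self, agent_id: int, agent_name: str):
        self.agent_id = agent_id
        self.agent_name = agent_name


class StubGroup:
    """Hypothetical stand-in that mimics only agent_ids_and_names()."""

    def __init__(self, agent_wholes: Iterable[StubAgent]):
        self.agent_wholes = list(agent_wholes)

    def agent_ids_and_names(self) -> Iterator[tuple[int, str]]:
        # Yield one (agent_id, agent_name) pair per agent in the group.
        for agent in self.agent_wholes:
            yield agent.agent_id, agent.agent_name


group = StubGroup([StubAgent(0, "prover"), StubAgent(1, "verifier")])

# Collect the yielded pairs into a lookup table keyed by agent ID.
name_by_id = {agent_id: name for agent_id, name in group.agent_ids_and_names()}
```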

abstract create_dpo_fine_tune_job(timesteps_per_agent: dict[str, NestedArrayDict], positive_examples_per_agent: dict[str, NestedArrayDict], negative_examples_per_agent: dict[str, NestedArrayDict], job_name: str | None = None)

Create a DPO fine-tune job for the agent group given sampled timesteps.

Parameters:
  • timesteps_per_agent (dict[str, NestedArrayDict]) – The data for each agent in the group. Each agent’s data is a nested dictionary of arrays, which are timesteps selected from the rollouts.

  • positive_examples_per_agent (dict[str, NestedArrayDict]) – The next timestep in the preferred response for each of the timesteps in timesteps_per_agent.

  • negative_examples_per_agent (dict[str, NestedArrayDict]) – The next timestep in the non-preferred response for each of the timesteps in timesteps_per_agent.

  • job_name (str, optional) – A name for the job, to make it more easily identifiable.
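The three dictionaries are keyed by agent name and describe the same timesteps, so each timestep has exactly one preferred and one non-preferred continuation. The sketch below shows that alignment with plain dicts of lists standing in for NestedArrayDict. The `"message"` key and the `count_preference_triples` helper are hypothetical and exist only for this example.

```python
# Plain dicts of lists stand in for NestedArrayDict; the "message" key is a
# hypothetical field chosen for illustration.
timesteps_per_agent = {
    "verifier": {"message": ["Is the proof valid?", "Explain step 2."]},
}
positive_examples_per_agent = {
    "verifier": {"message": ["Yes, each step follows.", "Step 2 applies the lemma."]},
}
negative_examples_per_agent = {
    "verifier": {"message": ["Maybe.", "I am not sure."]},
}


def count_preference_triples(timesteps, positives, negatives, key="message"):
    """Return the number of (timestep, positive, negative) triples per agent.

    Raises ValueError if the three inputs are not aligned for some agent.
    """
    counts = {}
    for agent_name, data in timesteps.items():
        n = len(data[key])
        n_pos = len(positives[agent_name][key])
        n_neg = len(negatives[agent_name][key])
        if n_pos != n or n_neg != n:
            raise ValueError(f"Misaligned examples for agent {agent_name!r}")
        counts[agent_name] = n
    return counts


triple_counts = count_preference_triples(
    timesteps_per_agent,
    positive_examples_per_agent,
    negative_examples_per_agent,
)
```

A check like this before submitting a job can surface misaligned inputs early. Whether a concrete subclass performs such a check itself is not stated here.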

abstract create_supervised_fine_tune_job(rollouts_per_agent: dict[str, NestedArrayDict], guess_replaced_rollouts: dict[str, NestedArrayDict] = {}, job_name: str | None = None)

Create a supervised fine-tune job for the agent group given sampled rollouts.

Use this method for supervised fine-tuning, as opposed to other fine-tuning methods such as reinforcement learning.

Parameters:
  • rollouts_per_agent (dict[str, NestedArrayDict]) – The data for each agent in the group, sampled from the environment.

  • guess_replaced_rollouts (dict[str, NestedArrayDict], default={}) – Additional rollouts for the verifier agents where the verifier’s guess is to be replaced with the true label.

  • job_name (str, optional) – A name for the job, to make it more easily identifiable.
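To make the idea of a guess-replaced rollout concrete, the sketch below shows what replacing a verifier's guess with the true label amounts to. The page does not say whether the caller or the method performs the replacement, so this only illustrates the transformation. Plain dicts of lists stand in for NestedArrayDict, and the `"guess"` and `"true_label"` keys and the `replace_guess_with_label` helper are hypothetical.

```python
import copy

# Hypothetical rollout layout: plain dicts of lists stand in for
# NestedArrayDict, and the key names are illustrative assumptions.
rollouts_per_agent = {
    "verifier": {"guess": [0, 1, 1], "true_label": [1, 1, 0]},
}


def replace_guess_with_label(rollout: dict) -> dict:
    """Return a copy of the rollout whose guesses equal the true labels."""
    replaced = copy.deepcopy(rollout)  # leave the sampled rollout untouched
    replaced["guess"] = list(replaced["true_label"])
    return replaced


guess_replaced_rollouts = {
    agent_name: replace_guess_with_label(rollout)
    for agent_name, rollout in rollouts_per_agent.items()
}
```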

abstract get_fine_tune_job_error_repr() → str

Get a string representation of the error for the fine-tune job.

abstract get_fine_tune_job_status() → Literal['pending', 'running', 'succeeded', 'failed', 'cancelled']

Get the status of the fine-tune job.
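One way a caller might use the status values is to poll until the job stops changing. The sketch below is a minimal polling loop under two assumptions: that 'succeeded', 'failed' and 'cancelled' are terminal, and that 'pending' and 'running' are not. The `wait_for_job` helper and the fake status source are hypothetical and not part of the library.

```python
import time
from typing import Callable, Literal

Status = Literal["pending", "running", "succeeded", "failed", "cancelled"]

# Assumption: these three statuses mean the job will not change any further.
TERMINAL_STATUSES = {"succeeded", "failed", "cancelled"}


def wait_for_job(
    get_status: Callable[[], Status],
    poll_seconds: float = 0.0,
    max_polls: int = 100,
) -> Status:
    """Poll get_status until a terminal status is seen, then return it."""
    for _ in range(max_polls):
        status = get_status()
        if status in TERMINAL_STATUSES:
            return status
        time.sleep(poll_seconds)
    raise TimeoutError(f"Job still not finished after {max_polls} polls")


# A fake status source standing in for group.get_fine_tune_job_status.
_fake_statuses = iter(["pending", "running", "succeeded"])
final_status = wait_for_job(lambda: next(_fake_statuses))
```

With a real group, the bound method `get_fine_tune_job_status` would be passed as `get_status`. On 'failed', `get_fine_tune_job_error_repr()` gives a description of the error, and on 'succeeded', `switch_to_next_model()` is the natural next call.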

get_state() → PureTextSharedModelGroupState

Get the state of the shared model group.

get_state_dict() → dict

Get the state of the shared model group as a dict.

This method should be implemented by subclasses capable of saving their state.

Returns:

state_dict (dict) – The state of the shared model group.

set_state(checkpoint: PureTextSharedModelGroupState)

Set the state of the shared model group from a checkpoint.

This method should be overridden by subclasses to restore the state of the shared model group from a checkpoint.

Parameters:

checkpoint (PureTextSharedModelGroupState) – The checkpoint to restore the state from.
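The state methods are meant to round-trip: a state captured with get_state() can later be handed to set_state() to restore the group. The sketch below shows that pattern with stand-in classes. `StubGroupState`, `StubStatefulGroup` and the `fine_tuned_model_name` field are hypothetical, and the real PureTextSharedModelGroupState may hold different fields.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class StubGroupState:
    """Hypothetical stand-in for PureTextSharedModelGroupState."""

    fine_tuned_model_name: Optional[str] = None


class StubStatefulGroup:
    """Hypothetical stand-in showing only the state round-trip."""

    def __init__(self, base_model_name: str):
        self.base_model_name = base_model_name
        self.fine_tuned_model_name: Optional[str] = None

    @property
    def model_name(self) -> str:
        # The current model is the fine-tuned one if it exists, else the base.
        return self.fine_tuned_model_name or self.base_model_name

    def get_state(self) -> StubGroupState:
        return StubGroupState(fine_tuned_model_name=self.fine_tuned_model_name)

    def get_state_dict(self) -> dict:
        return asdict(self.get_state())

    def set_state(self, checkpoint: StubGroupState) -> None:
        self.fine_tuned_model_name = checkpoint.fine_tuned_model_name


# Save the state of a group that has moved to a fine-tuned model...
trained = StubStatefulGroup("base-model")
trained.fine_tuned_model_name = "base-model-ft-1"
saved_state = trained.get_state()

# ...and restore it into a freshly constructed group.
restored = StubStatefulGroup("base-model")
restored.set_state(saved_state)
```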

abstract switch_to_next_model()

Switch to the next model after fine-tuning.