nip.scenario_base.agents.PureTextSharedModelGroup

class nip.scenario_base.agents.PureTextSharedModelGroup(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, agent_wholes: Iterable[PureTextWholeAgent], group_name: str)

A class representing a group of pure-text agents that share the same model.

The shared model is fine-tuned on the data from all agents in the group.
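As a rough illustration of what fine-tuning one model on the data from all agents in the group can mean, the sketch below pools each agent's examples into a single training set. It is a minimal sketch, not this library's implementation: plain lists of dicts stand in for NestedArrayDict, and `pool_group_data` and the `agent_name` tag are hypothetical.

```python
from typing import Any


def pool_group_data(data_per_agent: dict[str, list[dict[str, Any]]]) -> list[dict[str, Any]]:
    """Combine every agent's examples into one training set for the shared model.

    Hypothetical helper: lists of dicts stand in for NestedArrayDict.
    """
    pooled: list[dict[str, Any]] = []
    for agent_name, examples in data_per_agent.items():
        for example in examples:
            # Tag each example with the agent that produced it, so the shared
            # model can still be conditioned on the agent's role (an assumption).
            pooled.append({**example, "agent_name": agent_name})
    return pooled


dataset = pool_group_data(
    {
        "prover": [{"text": "p1"}, {"text": "p2"}],
        "verifier": [{"text": "v1"}],
    }
)
print(len(dataset))  # 3
```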

Parameters:
  • hyper_params (HyperParameters) – The parameters of the experiment.

  • settings (ExperimentSettings) – The settings of the experiment.

  • protocol_handler (ProtocolHandler) – The protocol handler for the experiment.

  • agent_wholes (Iterable[PureTextWholeAgent]) – The agents in the shared model group.

  • group_name (str) – The name of the shared model group.

Methods Summary

__init__(hyper_params, settings, ...)

agent_ids_and_names()

Get an iterable of agent IDs and names.

create_dpo_fine_tune_job(...[, job_name])

Create a DPO fine-tune job for the agent group given sampled timesteps.

create_supervised_fine_tune_job(...[, ...])

Create a supervised fine-tune job for the agent group given sampled rollouts.

eval()

Set the agent group to evaluation mode.

fine_tune_job_failed()

Check if the fine-tune job has failed.

get_fine_tune_job_error_repr()

Get a string representation of the error for the fine-tune job.

get_fine_tune_job_status()

Get the status of the fine-tune job.

get_state()

Get the state of the shared model group.

get_state_dict()

Get the state of the shared model group as a dict.

set_state(checkpoint)

Set the state of the shared model group from a checkpoint.

switch_to_next_model()

Switch to the next model after fine-tuning.

train()

Set the agent group to training mode.

wait_for_ready([timeout])

Wait for the agent group to be ready.

Attributes

is_trainable

lora_alpha

The computed LoRA alpha value for the group.

max_message_rounds

The maximum number of message rounds in the protocol.

model_name

The current model name, which may be the base model or a fine-tuned model.

num_epochs

The number of epochs to train the model for.

rl_learning_rate

The learning rate for this group when using reinforcement learning.

Methods

__init__(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, agent_wholes: Iterable[PureTextWholeAgent], group_name: str)

agent_ids_and_names() → Iterable[tuple[int, str]]

Get an iterable of agent IDs and names.

Yields:
  • agent_id (int) – The ID of the agent.

  • agent_name (str) – The name of the agent.
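The sketch below shows the shape of the pairs that agent_ids_and_names() yields. `FakeWholeAgent` and `FakeGroup` are hypothetical stand-ins for PureTextWholeAgent and the group; how the real class assigns agent IDs is not described here, so the IDs are simply stored on each stand-in agent.

```python
from dataclasses import dataclass
from typing import Iterator


@dataclass
class FakeWholeAgent:
    """Hypothetical stand-in for PureTextWholeAgent with just an ID and a name."""

    agent_id: int
    agent_name: str


class FakeGroup:
    """Hypothetical minimal group holding its member agents."""

    def __init__(self, agent_wholes: list[FakeWholeAgent]):
        self.agent_wholes = list(agent_wholes)

    def agent_ids_and_names(self) -> Iterator[tuple[int, str]]:
        # Yield one (agent_id, agent_name) pair per member agent.
        for agent in self.agent_wholes:
            yield agent.agent_id, agent.agent_name


group = FakeGroup([FakeWholeAgent(0, "prover"), FakeWholeAgent(1, "verifier")])
for agent_id, agent_name in group.agent_ids_and_names():
    print(agent_id, agent_name)
```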

abstract async create_dpo_fine_tune_job(positive_examples_per_agent: dict[str, NestedArrayDict], negative_examples_per_agent: dict[str, NestedArrayDict], job_name: str | None = None)

Create a DPO fine-tune job for the agent group given sampled timesteps.

Parameters:
  • positive_examples_per_agent (dict[str, NestedArrayDict]) – For each agent, the next timestep in the preferred response at each of the sampled timesteps. Each value is a nested array dict with batch size (timestep,) rather than the usual (batch, round), because the sampled timesteps have been selected from the first two dimensions of the batch.

  • negative_examples_per_agent (dict[str, NestedArrayDict]) – For each agent, the next timestep in the non-preferred response at each of the sampled timesteps. Each value is a nested array dict with batch size (timestep,) rather than the usual (batch, round), because the sampled timesteps have been selected from the first two dimensions of the batch.

  • job_name (str, optional) – A name for the job, to make it more easily identifiable.
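The change of batch size from (batch, round) to (timestep,) can be illustrated with plain Python lists. This is a sketch under assumptions: lists stand in for NestedArrayDict, the `selected` mask, `select_timesteps` and the example strings are hypothetical, and positive and negative examples are assumed to be aligned index by index along the timestep dimension.

```python
# Hypothetical per-agent data with batch size (batch, round) = (2, 3):
# batch[b][r] is what the agent produced in round r of rollout b.
batch = [
    ["a0", "a1", "a2"],
    ["b0", "b1", "b2"],
]

# Hypothetical mask marking the (batch, round) positions sampled as timesteps.
selected = [
    [False, True, False],
    [True, False, True],
]


def select_timesteps(data, mask):
    """Flatten the (batch, round) dimensions into a single (timestep,) dimension."""
    return [
        data[b][r]
        for b in range(len(data))
        for r in range(len(data[b]))
        if mask[b][r]
    ]


timesteps = select_timesteps(batch, selected)
print(timesteps)  # ['a1', 'b0', 'b2']

# Assumed alignment: index i of the positive and negative examples both
# continue timesteps[i], one preferred and one non-preferred next message.
positive_examples_per_agent = {"prover": ["good after a1", "good after b0", "good after b2"]}
negative_examples_per_agent = {"prover": ["bad after a1", "bad after b0", "bad after b2"]}

dpo_pairs = [
    {"timestep": t, "chosen": pos, "rejected": neg}
    for t, pos, neg in zip(
        timesteps,
        positive_examples_per_agent["prover"],
        negative_examples_per_agent["prover"],
    )
]
print(len(dpo_pairs))  # 3
```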

abstract async create_supervised_fine_tune_job(rollouts_per_agent: dict[str, NestedArrayDict], guess_replaced_rollouts: dict[str, NestedArrayDict] = {}, job_name: str | None = None)

Create a supervised fine-tune job for the agent group given sampled rollouts.

This method is for supervised fine-tuning, as opposed to other methods of fine-tuning such as reinforcement learning.

Parameters:
  • rollouts_per_agent (dict[str, NestedArrayDict]) – The data for each agent in the group, sampled from the environment.

  • guess_replaced_rollouts (dict[str, NestedArrayDict], default={}) – Additional rollouts for the verifier agents where the verifier’s guess is to be replaced with the true label.

  • job_name (str, optional) – A name for the job, to make it more easily identifiable.
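One way to picture guess_replaced_rollouts is as a copy of the verifier's rollouts in which the verifier's guess has been overwritten with the true label. In this sketch the helper `replace_verifier_guess` and the field names `decision` and `y` are hypothetical, and lists of dicts stand in for NestedArrayDict.

```python
import copy


def replace_verifier_guess(rollouts, guess_key="decision", label_key="y"):
    """Return a deep copy of the rollouts with each guess set to the true label."""
    replaced = copy.deepcopy(rollouts)
    for rollout in replaced:
        rollout[guess_key] = rollout[label_key]
    return replaced


# Hypothetical verifier rollouts: "decision" is the guess, "y" the true label.
verifier_rollouts = [
    {"transcript": "...", "decision": 0, "y": 1},
    {"transcript": "...", "decision": 1, "y": 1},
]

guess_replaced_rollouts = {"verifier": replace_verifier_guess(verifier_rollouts)}
print([r["decision"] for r in guess_replaced_rollouts["verifier"]])  # [1, 1]
print(verifier_rollouts[0]["decision"])  # 0, the original is left untouched
```

Copying rather than mutating keeps the original rollouts available, so both versions can be passed to the fine-tuning job.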

async eval()

Set the agent group to evaluation mode.

This method may be overridden by subclasses if anything needs to be done when the agent group is set to evaluation mode.
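Since the default train() and eval() hooks need not do anything, a subclass overrides them only when switching modes has side effects. The hypothetical subclass below changes a sampling temperature per mode; `BaseGroup`, `SamplingTemperatureGroup` and the `temperature` attribute are illustrative assumptions, not part of this class.

```python
import asyncio


class BaseGroup:
    """Hypothetical base class mirroring the default train/eval hooks."""

    async def train(self):
        # Default: nothing to do when switching to training mode.
        pass

    async def eval(self):
        # Default: nothing to do when switching to evaluation mode.
        pass


class SamplingTemperatureGroup(BaseGroup):
    """Hypothetical subclass that changes its sampling temperature per mode."""

    def __init__(self):
        super().__init__()
        self.temperature = 1.0

    async def train(self):
        self.temperature = 1.0  # sample more diversely while collecting training data

    async def eval(self):
        self.temperature = 0.0  # sample greedily when evaluating


async def main():
    group = SamplingTemperatureGroup()
    await group.eval()
    print(group.temperature)  # 0.0
    await group.train()
    print(group.temperature)  # 1.0


asyncio.run(main())
```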

async fine_tune_job_failed() → bool

Check if the fine-tune job has failed.

Returns:

failed (bool) – True if the fine-tune job has failed, False otherwise.

abstract async get_fine_tune_job_error_repr() → str

Get a string representation of the error for the fine-tune job.

abstract async get_fine_tune_job_status() → Literal['pending', 'running', 'succeeded', 'failed', 'cancelled', 'not_found']

Get the status of the fine-tune job.
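The status literal returned by get_fine_tune_job_status() lends itself to a simple polling loop. The sketch assumes that "succeeded", "failed", "cancelled" and "not_found" are terminal states and that fine_tune_job_failed() amounts to checking for the "failed" status; `FakeJobGroup`, `poll_until_done` and `TERMINAL_STATUSES` are hypothetical.

```python
import asyncio
from typing import Literal

FineTuneJobStatus = Literal[
    "pending", "running", "succeeded", "failed", "cancelled", "not_found"
]

# Assumption: once a job reaches one of these statuses it will not change.
TERMINAL_STATUSES = {"succeeded", "failed", "cancelled", "not_found"}


class FakeJobGroup:
    """Hypothetical group whose job reports a scripted sequence of statuses."""

    def __init__(self, statuses: list[FineTuneJobStatus]):
        self._statuses = iter(statuses)
        self._last: FineTuneJobStatus = "pending"

    async def get_fine_tune_job_status(self) -> FineTuneJobStatus:
        # Report the next scripted status, repeating the last one when exhausted.
        self._last = next(self._statuses, self._last)
        return self._last


async def poll_until_done(group, interval: float = 0.0) -> FineTuneJobStatus:
    """Poll the job status until it reaches a terminal state."""
    while True:
        status = await group.get_fine_tune_job_status()
        if status in TERMINAL_STATUSES:
            return status
        await asyncio.sleep(interval)


final_status = asyncio.run(poll_until_done(FakeJobGroup(["pending", "running", "failed"])))
# Assumed equivalent of fine_tune_job_failed():
print(final_status == "failed")  # True
```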

get_state() → PureTextSharedModelGroupState

Get the state of the shared model group.

get_state_dict() → dict

Get the state of the shared model group as a dict.

This method should be implemented by subclasses capable of saving their state.

Returns:

state_dict (dict) – The state of the shared model group.

set_state(checkpoint: PureTextSharedModelGroupState)

Set the state of the shared model group from a checkpoint.

This method should be overridden by subclasses to restore the state of the shared model group from a checkpoint.

Parameters:

checkpoint (PureTextSharedModelGroupState) – The checkpoint to restore the state from.
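A sketch of how get_state(), get_state_dict() and set_state() can fit together for checkpointing, assuming the state is a simple dataclass. `GroupState` and its fields are hypothetical stand-ins for PureTextSharedModelGroupState, whose actual fields are not documented here, and `StatefulGroup` is likewise illustrative.

```python
import json
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class GroupState:
    """Hypothetical stand-in for PureTextSharedModelGroupState."""

    base_model_name: str
    fine_tuned_model_name: Optional[str] = None
    fine_tune_job_id: Optional[str] = None


class StatefulGroup:
    """Hypothetical group exposing the three state methods."""

    def __init__(self, base_model_name: str):
        self.state = GroupState(base_model_name)

    def get_state(self) -> GroupState:
        return self.state

    def get_state_dict(self) -> dict:
        # A plain dict of JSON-friendly values is easy to save to disk.
        return asdict(self.state)

    def set_state(self, checkpoint: GroupState):
        self.state = checkpoint


group = StatefulGroup("base-model")
group.state.fine_tune_job_id = "job-123"
saved = json.dumps(group.get_state_dict())

restored = StatefulGroup("base-model")
restored.set_state(GroupState(**json.loads(saved)))
print(restored.get_state() == group.get_state())  # True
```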

abstract async switch_to_next_model()

Switch to the next model after fine-tuning.
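The model_name attribute holds either the base model or a fine-tuned model, and switch_to_next_model() moves the group from one to the next. The sketch below assumes the group records the name of the model produced by a successful job and swaps to it on request; `ModelSwitchingGroup`, the `on_fine_tune_succeeded` hook and the model names are hypothetical.

```python
import asyncio
from typing import Optional


class ModelSwitchingGroup:
    """Hypothetical group tracking the current model and the next fine-tuned one."""

    def __init__(self, base_model_name: str):
        self._current_model = base_model_name
        self._pending_model: Optional[str] = None

    @property
    def model_name(self) -> str:
        # Either the base model or the most recent fine-tuned model.
        return self._current_model

    def on_fine_tune_succeeded(self, fine_tuned_model_name: str):
        # Hypothetical hook: remember the model produced by a finished job.
        self._pending_model = fine_tuned_model_name

    async def switch_to_next_model(self):
        if self._pending_model is None:
            raise RuntimeError("No fine-tuned model is available to switch to.")
        self._current_model = self._pending_model
        self._pending_model = None


group = ModelSwitchingGroup("base-model")
group.on_fine_tune_succeeded("base-model:ft-iter-1")
asyncio.run(group.switch_to_next_model())
print(group.model_name)  # base-model:ft-iter-1
```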

async train()

Set the agent group to training mode.

This method may be overridden by subclasses if anything needs to be done when the agent group is set to training mode.

async wait_for_ready(timeout: float = 300.0)

Wait for the agent group to be ready.

Parameters:

timeout (float, default=300.0) – The maximum time to wait for the agent group to be ready, in seconds.

Raises:

TimeoutError – If the agent group is not ready within the timeout period.
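A minimal sketch of the wait_for_ready() contract using asyncio.wait_for, assuming readiness can be checked by polling. `ReadinessGroup`, its `is_ready()` check and the `poll_interval` parameter are hypothetical. The explicit conversion to the built-in TimeoutError matters on Python versions before 3.11, where asyncio.TimeoutError is a separate class.

```python
import asyncio
import time


class ReadinessGroup:
    """Hypothetical group that becomes ready after a fixed delay."""

    def __init__(self, ready_after: float):
        self._ready_at = time.monotonic() + ready_after

    async def is_ready(self) -> bool:
        # Hypothetical readiness check, e.g. pinging a model server.
        return time.monotonic() >= self._ready_at

    async def wait_for_ready(self, timeout: float = 300.0, poll_interval: float = 0.01):
        async def poll():
            while not await self.is_ready():
                await asyncio.sleep(poll_interval)

        try:
            await asyncio.wait_for(poll(), timeout=timeout)
        except asyncio.TimeoutError:
            # Re-raise as the built-in TimeoutError on every Python version.
            raise TimeoutError(f"Agent group not ready within {timeout} seconds") from None


asyncio.run(ReadinessGroup(ready_after=0.05).wait_for_ready(timeout=1.0))
try:
    asyncio.run(ReadinessGroup(ready_after=10.0).wait_for_ready(timeout=0.05))
except TimeoutError as error:
    print(error)
```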