nip.code_validation.agents.OpenAiSharedModelGroup
- class nip.code_validation.agents.OpenAiSharedModelGroup(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler | CodeValidationProtocolHandler, agent_wholes: dict[str, OpenAiWholeAgent], group_name: str)
A class representing a group of code validation OpenAI agents sharing a model.
Methods Summary
__getstate__() – Get the state of the object for pickling.
__init__(hyper_params, settings, ...)
Get the fine-tune job from the OpenAI API.
_make_fine_tune_api_call(fine_tune_dataset, ...) – Make the API call to fine-tune the model.
agent_ids_and_names() – Get an iterable of agent IDs and names.
create_dpo_fine_tune_job(...[, job_name]) – Create a DPO fine-tune job for the agent group given sampled timesteps.
create_supervised_fine_tune_job(...[, ...]) – Create a supervised fine-tune job for the agent.
get_fine_tune_job_error_repr() – Get a string representation of the error for the fine-tune job.
get_fine_tune_job_status() – Get the status of the fine-tune job.
get_state() – Get the state of the shared model group.
get_state_dict() – Get the state dictionary of the agent.
set_state(checkpoint) – Set the state of the shared model group from a checkpoint.
Switch to the next model after fine-tuning.
Attributes
client
The OpenAI client to use for interacting with the OpenAI API.
max_message_rounds
The maximum number of message rounds in the protocol.
model_name
The current model name, which may be the base model or a fine-tuned model.
agent_wholes
Methods
- __getstate__() → dict[str, Any]
Get the state of the object for pickling.
We don’t pickle the OpenAI client, as it is not picklable.
- Returns:
state (dict[str, Any]) – The state of the object.
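The pattern of leaving an unpicklable client out of the pickled state can be sketched as follows. This is a minimal illustration and not the library's implementation: `FakeClient` and `SharedGroupSketch` are hypothetical stand-ins, and rebuilding the client in `__setstate__` is an assumption about how the client might be restored.

```python
import pickle
import threading


class FakeClient:
    """Hypothetical stand-in for an API client that cannot be pickled."""

    def __init__(self):
        # Holding a lock makes instances unpicklable, much like a live client.
        self._lock = threading.Lock()


class SharedGroupSketch:
    """Illustrative class only; not the real OpenAiSharedModelGroup."""

    def __init__(self, model_name):
        self.model_name = model_name
        self.client = FakeClient()

    def __getstate__(self):
        # Copy the instance dict and drop the unpicklable client.
        state = self.__dict__.copy()
        state.pop("client", None)
        return state

    def __setstate__(self, state):
        # Restore the saved attributes, then build a fresh client (assumption).
        self.__dict__.update(state)
        self.client = FakeClient()


group = SharedGroupSketch("base-model")
state = group.__getstate__()

# The reduced state contains only picklable values, so it serialises cleanly.
payload = pickle.dumps(state)

# Rebuild an object from the saved state, as unpickling would.
restored = SharedGroupSketch.__new__(SharedGroupSketch)
restored.__setstate__(pickle.loads(payload))
```

`pickle` calls `__getstate__` and `__setstate__` itself when an instance is dumped and loaded; they are called by hand here only to make each step visible.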
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler | CodeValidationProtocolHandler, agent_wholes: dict[str, OpenAiWholeAgent], group_name: str)
- _make_fine_tune_api_call(fine_tune_dataset: list[dict], method: Literal['supervised', 'dpo'], job_name: str | None = None)
Make the API call to fine-tune the model.
- agent_ids_and_names() → Iterable[tuple[int, str]]
Get an iterable of agent IDs and names.
- Yields:
agent_id (int) – The ID of the agent.
agent_name (str) – The name of the agent.
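A method with this return type is typically consumed by unpacking each yielded pair. The sketch below is a hypothetical generator with the same shape of output. It assumes, purely for illustration, that an agent's ID is its position in an ordered list of names; the source does not say how the real IDs are assigned.

```python
from typing import Iterable, Iterator


def agent_ids_and_names_sketch(agent_names: Iterable[str]) -> Iterator[tuple[int, str]]:
    """Yield (agent_id, agent_name) pairs; IDs are list positions (assumption)."""
    for agent_id, agent_name in enumerate(agent_names):
        yield agent_id, agent_name


# Hypothetical agent names, used only to show the unpacking pattern.
pairs = list(agent_ids_and_names_sketch(["verifier", "prover"]))
for agent_id, agent_name in pairs:
    print(f"agent {agent_id}: {agent_name}")
```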
- create_dpo_fine_tune_job(timesteps_per_agent: dict[str, NestedArrayDict], positive_examples_per_agent: dict[str, NestedArrayDict], negative_examples_per_agent: dict[str, NestedArrayDict], job_name: str | None = None)
Create a DPO fine-tune job for the agent group given sampled timesteps.
This method generates a dataset of examples ready to pass to the fine-tune API.
- Parameters:
timesteps_per_agent (dict[str, NestedArrayDict]) – The data for each agent in the group. Each agent’s data is a nested dictionary of arrays, which are timesteps selected from the rollouts.
positive_examples_per_agent (dict[str, NestedArrayDict]) – The next timestep in the preferred response for each of the timesteps in timesteps_per_agent.
negative_examples_per_agent (dict[str, NestedArrayDict]) – The next timestep in the non-preferred response for each of the timesteps in timesteps_per_agent.
job_name (str, optional) – A name for the job, to make it more easily identifiable.
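To make the pairing of a timestep with its preferred and non-preferred continuation concrete, here is a minimal sketch of one preference record. It assumes the JSONL layout used for OpenAI preference fine-tuning (`input`, `preferred_output`, `non_preferred_output`); the helper `make_dpo_record` and the example messages are hypothetical, and the source does not show how the library itself builds its dataset.

```python
import json


def make_dpo_record(prompt_messages, preferred_reply, non_preferred_reply):
    """Build one preference example (hypothetical helper, assumed layout)."""
    return {
        # The conversation so far, i.e. the sampled timestep.
        "input": {"messages": list(prompt_messages)},
        # The continuation that should be preferred.
        "preferred_output": [{"role": "assistant", "content": preferred_reply}],
        # The continuation that should be dispreferred.
        "non_preferred_output": [
            {"role": "assistant", "content": non_preferred_reply}
        ],
    }


# Made-up conversation history standing in for one sampled timestep.
history = [
    {"role": "system", "content": "You are a verifier checking a proposed solution."},
    {"role": "user", "content": "The solution returns 1 for every input."},
]

record = make_dpo_record(history, "Decision: reject", "Decision: accept")

# One JSON object per line is the usual JSONL convention for such datasets.
jsonl_line = json.dumps(record)
```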
- create_supervised_fine_tune_job(rollouts_per_agent: dict[str, NestedArrayDict], guess_replaced_rollouts: dict[str, NestedArrayDict] = {}, job_name: str | None = None)
Create a supervised fine-tune job for the agent.
This method generates a dataset of examples ready to pass to the fine-tune API.
- Parameters:
rollouts_per_agent (dict[str, NestedArrayDict]) – The sampled rollouts for each agent. Each is a nested dictionary of arrays with keys:
"round" (batch round): The current round number.
"message_history" (batch round round channel): The history of messages exchanged between the agents in each channel.
"message_agent_id" (batch round round channel): The id of the agent who messaged at a round-channel pair.
"raw_message_history" (batch round round agent): The raw message generated by each model in each timestep.
"question" (batch round): The problem text.
"solution" (batch round): The proposed solution text.
"y" (batch round): The true label (0 for incorrect, 1 for correct).
"prover_stance" (batch round): When randomizing the prover stance, the verdict that the prover is arguing for, where 0 means “reject” and 1 means “accept”.
guess_replaced_rollouts (dict[str, NestedArrayDict], default={}) – Additional rollouts for the verifier agents in which the verifier’s guess is replaced with either ‘Decision: accept’ or ‘Decision: reject’, according to the true label.
job_name (str, optional) – A name for the job, to make it more easily identifiable.
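The guess replacement described for guess_replaced_rollouts can be sketched on a plain chat transcript. The helpers `decision_text` and `replace_verifier_guess` are hypothetical, and the sketch assumes the verifier's guess is the last assistant message in the transcript; the real method operates on NestedArrayDict rollouts, which are not modelled here.

```python
def decision_text(y: int) -> str:
    """Map the true label to the decision string (1 = correct, 0 = incorrect)."""
    return "Decision: accept" if y == 1 else "Decision: reject"


def replace_verifier_guess(messages, y):
    """Return a copy of `messages` with the last assistant message replaced.

    Assumes the verifier's guess is the final assistant message (assumption).
    """
    replaced = [dict(message) for message in messages]
    for message in reversed(replaced):
        if message["role"] == "assistant":
            message["content"] = decision_text(y)
            break
    return replaced


# Made-up transcript in which the verifier guessed wrongly.
transcript = [
    {"role": "user", "content": "Is the proposed solution correct?"},
    {"role": "assistant", "content": "Decision: accept"},
]

# The true label is 0 (incorrect), so the guess becomes 'Decision: reject'.
fixed = replace_verifier_guess(transcript, y=0)
```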
- get_fine_tune_job_error_repr() → str
Get a string representation of the error for the fine-tune job.
- get_fine_tune_job_status() → Literal['pending', 'running', 'succeeded', 'failed', 'cancelled']
Get the status of the fine-tune job.
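Because the status is one of five literal values, a caller would usually poll until a terminal one is reached. The loop below is a sketch under stated assumptions: `wait_for_fine_tune`, its polling interval and its poll limit are hypothetical, and a fake status source stands in for a real group so that the example runs without API access.

```python
import time
from typing import Callable

# Statuses after which the job will not change again.
TERMINAL_STATUSES = {"succeeded", "failed", "cancelled"}


def wait_for_fine_tune(
    get_status: Callable[[], str],
    poll_seconds: float = 30.0,
    max_polls: int = 1000,
) -> str:
    """Poll `get_status` until it returns a terminal status (hypothetical helper)."""
    for _ in range(max_polls):
        status = get_status()
        if status in TERMINAL_STATUSES:
            return status
        # Still 'pending' or 'running': wait before asking again.
        time.sleep(poll_seconds)
    raise TimeoutError("fine-tune job did not finish within the poll limit")


# Fake status sequence used in place of a real fine-tune job.
fake_statuses = iter(["pending", "running", "running", "succeeded"])
final_status = wait_for_fine_tune(lambda: next(fake_statuses), poll_seconds=0.0)
```

With a real group, `group.get_fine_tune_job_status` could be passed as `get_status`, and `get_fine_tune_job_error_repr()` consulted when the result is 'failed'.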
- get_state() → PureTextSharedModelGroupState
Get the state of the shared model group.
- get_state_dict() → dict
Get the state dictionary of the agent.
- Returns:
state_dict (dict) – The state dictionary of the agent.