nip.code_validation.agents.OpenAiWholeAgent#

class nip.code_validation.agents.OpenAiWholeAgent(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler | CodeValidationProtocolHandler)[source]#

The whole agent for code validation, using OpenAI’s SDK.

The SDK can be used to interact with OpenAI’s API, as well as other APIs like OpenRouter.
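For illustration, a hedged sketch of how the same client might be configured for either backend. The `client_kwargs` helper and provider names are hypothetical, not part of this class, though the base URL shown is OpenRouter's documented OpenAI-compatible endpoint:

```python
def client_kwargs(provider: str, api_key: str) -> dict:
    """Build keyword arguments for an OpenAI SDK client (hypothetical helper)."""
    if provider == "openrouter":
        # OpenRouter exposes an OpenAI-compatible endpoint, so the same SDK
        # works with only the base URL overridden
        return {"api_key": api_key, "base_url": "https://openrouter.ai/api/v1"}
    # Default: the SDK targets OpenAI's own API
    return {"api_key": api_key}
```

One would then construct the client as, e.g., `openai.AsyncOpenAI(**client_kwargs("openrouter", key))`.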

Methods Summary

__call__(data, environment)

Run a forward pass through the agent, with some safety checks.

__init__(hyper_params, settings, agent_name, ...)

_build_chat_messages_prompt(message_history, ...)

Construct the chat history ready to feed to the API.

_generate_dummy_response(chat_messages_prompt)

Generate a dummy response to a chat prompt.

_generate_next_message_and_decision(...)

Generate the next message and decision for the agent, with retries.

_handle_chat_completion_error(completion)

Handle any errors in a chat completion, raising exceptions as necessary.

_make_generation_api_call(chat_messages_prompt)

Call the OpenAI API to generate the next message.

build_fine_tune_dataset(timesteps[, ...])

Build the dataset for fine-tuning the agent given sampled timesteps.

forward(data, environment)

Forward pass through the agent policy head.

get_state_dict()

Get the state of the agent part as a dict.

set_state(checkpoint)

Set the state of the agent from a checkpoint.

Attributes

agent_id

The ID of the agent.

agent_level_in_keys

agent_level_out_keys

agent_spec

The CodeValidationAgentSpec for the agent.

base_model_name

The base OpenAI model name, before any fine-tuning.

env_level_in_keys

env_level_out_keys

in_keys

The keys required by the module.

is_prover

Whether the agent is a prover.

is_verifier

Whether the agent is a verifier.

max_message_rounds

The maximum number of message rounds in the protocol.

model_name

The OpenAI model name, including any fine-tuning.

num_visible_message_channels

The number of message channels visible to the agent.

openai_client

The OpenAI client to use for interacting with the OpenAI SDK.

out_keys

The keys produced by the module.

required_pretrained_models

The pretrained models used by the agent.

shared_model_group

system_prompt_template

The template for the system prompt.

visible_message_channel_indices

The indices of the message channels visible to the agent.

visible_message_channel_mask

The mask for the message channels visible to the agent.

visible_message_channel_names

The names of the message channels visible to the agent.

agent_params

protocol_handler

Methods

async __call__(data: NestedArrayDict, environment: PureTextEnvironment) → NestedArrayDict[source]#

Run a forward pass through the agent, with some safety checks.

Parameters:
  • data (NestedArrayDict) – The input data.

  • environment (PureTextEnvironment) – The environment in which the agent is running.

Returns:

output (NestedArrayDict) – The output of the forward pass on the input.

__init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler | CodeValidationProtocolHandler)[source]#
_build_chat_messages_prompt(message_history: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round channel'], message_agent_id: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round channel'], raw_message_history: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round agent'], round_id: int, question: str, solution: str, prover_stance: int, seed: int, y: int | None = None, ensure_last_message_is_assistant: bool = False, replace_last_message_with_true_label: bool = False, allow_supervisor_message: bool = True) → list[PromptMessage][source]#

Construct the chat history ready to feed to the API.

Parameters:
  • message_history (String[NDArray, "round channel"]) – The array of messages in the chat history.

  • message_agent_id (String[NDArray, "round channel"]) – The id of the agent who messaged at a round-channel pair.

  • raw_message_history (String[NDArray, "round agent"]) – The raw message generated by each model in each timestep.

  • round_id (int) – The current round number.

  • question (str) – The problem text.

  • solution (str) – The proposed solution text.

  • prover_stance (int) – The verdict that the prover is arguing for, where 0 means “reject” and 1 means “accept”.

  • seed (int) – The per-environment seed.

  • y (int, optional) – The true label (0 for incorrect, 1 for correct). Only used if replace_last_message_with_true_label is set to True.

  • ensure_last_message_is_assistant (bool, default=False) – Whether to ensure the last message is from the assistant, by removing messages from the user.

  • replace_last_message_with_true_label (bool, default=False) – Whether to replace the last message with ‘Decision: accept’ or ‘Decision: reject’ based on the true label. If this is set to True, y must be provided. Only used if ensure_last_message_is_assistant is set to True.

  • allow_supervisor_message (bool, default=True) – Whether to allow the supervisor message in the chat history. The supervisor message is appended to the chat history before it is sent to the model, to provide additional context or instructions. If this is set to False, the supervisor message is never included; if True, it is included depending on the agent’s settings.

Returns:

chat_messages (list[PromptMessage]) – The chat messages ready to feed to the API.

Raises:

AgentNotActiveInChannelError – If ensure_last_message_is_assistant is set to True and the agent is not active in the channel (i.e. there would be no messages in the chat history).
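As a rough illustration of the ensure_last_message_is_assistant and replace_last_message_with_true_label behaviour described above (the helper and its error messages are hypothetical simplifications of the method's internal logic):

```python
def ensure_last_assistant(chat_messages, y=None, replace_with_true_label=False):
    """Trim trailing non-assistant messages; optionally overwrite the last one
    with the true-label decision text (hypothetical sketch)."""
    trimmed = list(chat_messages)
    # Remove messages from the end until the last one is from the assistant
    while trimmed and trimmed[-1]["role"] != "assistant":
        trimmed.pop()
    if not trimmed:
        # Corresponds to the AgentNotActiveInChannelError case
        raise ValueError("agent has no messages in this channel")
    if replace_with_true_label:
        if y is None:
            raise ValueError("y must be provided")
        trimmed[-1] = {
            "role": "assistant",
            "content": "Decision: accept" if y == 1 else "Decision: reject",
        }
    return trimmed
```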

_generate_dummy_response(chat_messages_prompt: list[PromptMessage]) → str[source]#

Generate a dummy response to a chat prompt.

Parameters:

chat_messages_prompt (list[PromptMessage]) – The chat messages prompt to generate a response to.

Returns:

response (str) – The dummy response generated.

async _generate_next_message_and_decision(message_history: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round channel'], message_agent_id: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round channel'], raw_message_history: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round agent'], round_id: int, question: str, solution: str, seed: int, prover_stance: int) → _ParsedChatCompletion[source]#

Generate the next message and decision for the agent, with retries.

This method takes a single message history, then builds and runs the API request to generate the next action, which can be a message or a decision.

If there is an error in the generation, this method will retry a number of times before raising an exception (detailed below).

Parameters:
  • message_history (String[NDArray, "round channel"]) – The array of messages in the chat history.

  • message_agent_id (String[NDArray, "round channel"]) – The id of the agent who messaged at a round-channel pair.

  • raw_message_history (String[NDArray, "round agent"]) – The raw message generated by each model in each timestep.

  • round_id (int) – The current round number.

  • question (str) – The problem text.

  • solution (str) – The proposed solution text.

  • seed (int) – The per-environment seed.

  • prover_stance (int) – The verdict that the prover is arguing for, where 0 means “reject” and 1 means “accept”. (Currently ignored.)

Returns:

parsed_chat_completion (_ParsedChatCompletion) – The parsed chat completion output. See the _ParsedChatCompletion class for details.

Raises:
_handle_chat_completion_error(completion: ChatCompletion)[source]#

Handle any errors in a chat completion, raising exceptions as necessary.

The OpenRouter API indicates an error using the (non-standard) “error” attribute in the completion object. This method checks for this and raises an appropriate exception if necessary.
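A minimal sketch of this check, assuming a hypothetical exception class (the real method's exception types are not shown here):

```python
class ChatCompletionRuntimeError(Exception):
    """Raised when the provider reports an error inside the completion body
    (hypothetical exception name)."""

def check_completion_for_error(completion) -> None:
    # OpenRouter signals failure via a non-standard "error" attribute on an
    # otherwise well-formed completion object
    error = getattr(completion, "error", None)
    if error is not None:
        raise ChatCompletionRuntimeError(str(error))
```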

Parameters:

completion (OpenAiChatCompletion) – The chat completion object to check for errors.

Raises:
async _make_generation_api_call(chat_messages_prompt: list[PromptMessage]) → tuple[str | None, Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call']][source]#

Call the OpenAI API to generate the next message.

Parameters:

chat_messages_prompt (list[PromptMessage]) – The message history to feed to the API. A list of dicts with keys “role” and “content”.

Returns:

  • completion_text (str | None) – The text of the completion generated by the API.

  • finish_reason (Literal[“stop”, “length”, “tool_calls”, “content_filter”, “function_call”]) – The reason for finishing the generation.
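The two return values correspond to the first choice in the chat completion. A minimal sketch of the extraction, assuming the standard chat-completion response shape (written dict-style here; the real method works on the SDK's ChatCompletion objects):

```python
def extract_completion(completion: dict) -> tuple:
    """Pull the completion text and finish reason out of a chat completion
    (hypothetical helper)."""
    # The generated text and stop reason live on the first choice
    choice = completion["choices"][0]
    return choice["message"]["content"], choice["finish_reason"]
```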

build_fine_tune_dataset(timesteps: NestedArrayDict, ensure_last_message_is_assistant: bool = True, replace_verifier_guess_with_true_label: bool = False, use_next_message_history: bool = True) → list[SupervisedDatasetItem][source]#

Build the dataset for fine-tuning the agent given sampled timesteps.

A ‘timestep’ is just a single step in the environment. This holds all information about the environment state, including past messages. Generally, one wants to fine-tune on the final timestep of a rollout, but we can fine-tune on any timestep.

This method generates a dataset of examples ready to pass to the fine-tune API.

Parameters:
  • timesteps (NestedArrayDict) –

    The sampled timesteps. This is usually the final step in each rollout, but it can be any step if we want to fine-tune on an initial segment of the rollout. A nested dictionary of arrays with keys:

    • “round” (batch): The current round number.

    • “message_history” (batch round channel): The history of messages exchanged between the agents in each channel, up to the current round.

    • “message_agent_id” (batch round channel): The id of the agent who messaged at a round-channel pair, up to the current round.

    • “raw_message_history” (batch round agent): The raw message generated by each model in each timestep, up to the current round.

    • (“next”, “message_history”) (batch round channel): The history of messages exchanged between the agents in each channel, up to the next round.

    • (“next”, “message_agent_id”) (batch round channel): The id of the agent who messaged at a round-channel pair, up to the next round.

    • (“next”, “raw_message_history”) (batch round agent): The raw message generated by each model in each timestep, up to the next round.

    • “question” (batch): The problem text.

    • “solution” (batch): The proposed solution text.

    • “prover_stance” (batch): When randomizing the prover stance, the verdict that the prover is arguing for, where 0 means “reject” and 1 means “accept”.

  • ensure_last_message_is_assistant (bool, default=True) – Make sure the last message in each rollout comes from the assistant, by removing messages from the end until this is true. This makes sense when we want to do supervised fine-tuning, as the last message should be the assistant’s message.

  • replace_verifier_guess_with_true_label (bool, default=False) – Whether to replace the verifier’s guess with the true label. If this is set to True, and the agent is a verifier, the verifier’s guess will be replaced with either ‘Decision: accept’ or ‘Decision: reject’ based on the true label. Only makes sense when ensure_last_message_is_assistant is set to True.

  • use_next_message_history (bool, default=True) – Whether to use the message history from the next timestep (located in the “next” sub-dictionary). This is the message history which includes the messages sent in the current round. If this is set to False, the message history up to the current round will be used.

Returns:

fine_tune_dataset (list[SupervisedDatasetItem]) – The dataset for fine-tuning the agent. This is a list of examples, where each example is a dictionary with key “messages”, whose value is a list of dictionaries with keys “role” and “content”.
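For illustration, a minimal sketch of producing one example in this shape (the helper name is hypothetical; the real method derives the transcripts from the timestep arrays):

```python
def to_supervised_item(chat_messages: list) -> dict:
    """Convert a chat transcript into one fine-tuning example with the
    {"messages": [{"role": ..., "content": ...}, ...]} shape described above."""
    return {
        "messages": [
            {"role": m["role"], "content": m["content"]} for m in chat_messages
        ]
    }
```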

async forward(data: NestedArrayDict, environment: PureTextEnvironment) → NestedArrayDict[source]#

Forward pass through the agent policy head.

Parameters:
  • data (NestedArrayDict) –

    The input data. A nested dictionary of arrays with keys:

    • “round” (…): The current round number.

    • “seed” (…): The per-environment seed.

    • “message_history” (… round channel): The history of messages exchanged between the agents in each channel.

    • “message_agent_id” (… round channel): The id of the agent who messaged at a round-channel pair.

    • “raw_message_history” (… round agent): The raw message generated by each model in each timestep.

    • “question” (…): The problem text.

    • “solution” (…): The proposed solution text.

    • “prover_stance” (…): When randomizing the prover stance, the verdict that the prover is arguing for, where 0 means “reject” and 1 means “accept”.

  • environment (PureTextEnvironment) – The environment in which the agent is running.

Returns:

output_data (NestedArrayDict) – The output of running the agent policy head. This is a nested dictionary of arrays with keys:

  • “message” (… channel): The message output by the agent per channel.

  • “raw_message” (…): The raw message generated by the model, before parsing.

  • “prompt” (… message component): The prompt used to generate the message. Each prompt is a sequence of chat messages composed of various components, as defined in the PromptMessage class. Here we store these components in a 2D array for each batch item, where the first dimension is the chat message index and the second dimension is the component index.

  • “decision” (…): The discrete decision from the verifier model, with the following meanings:

    • 0: reject

    • 1: accept

    • 2: no decision

    • 3: end with neither accept nor reject

  • “continuous_decision” (…): The continuous decision made by the agent. This is a number between -1 and 1, where -1 is “reject” and 1 is “accept”.

  • “raw_decision” (…): The raw decision text from the verifier model. This is the text which appears after “Decision: ” in the completion text.

  • “valid_response” (…): A boolean array indicating whether the generation was valid.

  • “retry_count” (… channel): The number of retries before the generation was successful.

  • “token_limit” (… channel): A boolean array indicating whether the generation was stopped because the maximum number of tokens was reached.
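As a rough sketch of how a completion might be mapped onto the discrete decision codes above (the parse_decision helper and its exact fallback rules are hypothetical; the real parsing lives inside the agent):

```python
def parse_decision(completion_text: str):
    """Return (raw_decision, decision_code) for a completion (hypothetical).

    The raw decision is the text after "Decision: "; the integer codes follow
    the meanings documented for the "decision" output.
    """
    marker = "Decision: "
    if marker not in completion_text:
        return None, 2  # no decision yet
    raw = completion_text.split(marker, 1)[1].strip()
    if raw.lower().startswith("accept"):
        return raw, 1
    if raw.lower().startswith("reject"):
        return raw, 0
    return raw, 3  # ended with neither accept nor reject
```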

get_state_dict() → dict[source]#

Get the state of the agent part as a dict.

This method should be implemented by subclasses capable of saving their state.

Returns:

state_dict (dict) – The state of the agent part.

set_state(checkpoint: AgentState)[source]#

Set the state of the agent from a checkpoint.

This method should be overridden by subclasses to restore the state of the agent from a checkpoint.

Parameters:

checkpoint (AgentState) – The checkpoint to restore the state from.