- class nip.code_validation.agents.OpenAiWholeAgent(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler | CodeValidationProtocolHandler)
The whole agent for code validation, using OpenAI’s API.
Methods Summary
__call__(data, environment) – Call self as a function.
__getstate__() – Get the state of the object for pickling.
__init__(hyper_params, settings, agent_name, ...)
_build_chat_messages_prompt(message_history, ...) – Construct the chat history ready to feed to the API.
_generate_dummy_response(chat_messages_prompt) – Generate a dummy response to a chat prompt.
_generate_next_message_and_decision(message_history, ...) – Generate the next message and decision for the agent, with retries.
_make_generation_api_call(chat_messages_prompt) – Call the OpenAI API to generate the next message.
build_fine_tune_dataset(timesteps[, ...]) – Build the dataset for fine-tuning the agent given sampled timesteps.
forward(data, environment) – Forward pass through the agent policy head.
get_state_dict() – Get the state of the agent part as a dict.
set_state(checkpoint) – Set the state of the agent from a checkpoint.
Attributes Summary
The ID of the agent.
… for the agent.
base_model_name – The base OpenAI model name, before any fine-tuning.
The OpenAI client to use for interacting with the OpenAI API.
The keys required by the module.
Whether the agent is a prover.
Whether the agent is a verifier.
The maximum number of message rounds in the protocol.
The OpenAI model name, including any fine-tuning.
The number of message channels visible to the agent.
The keys produced by the module.
The pretrained models used by the agent.
The template for the system prompt.
The indices of the message channels visible to the agent.
The mask for the message channels visible to the agent.
The names of the message channels visible to the agent.
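The last three attributes describe the same subset of message channels in three forms. As a rough illustration (a sketch, not the library's code), the snippet below assumes a hypothetical list of all channel names and a boolean mask over them, derives the indices and names from the mask, and uses the indices to select the visible columns of a "round channel" message history held as nested Python lists. Under these assumptions the number of visible channels is simply `len(indices)`; the helper names and channel names are hypothetical.

```python
from typing import Sequence


def visible_channels(
    all_channel_names: Sequence[str], mask: Sequence[bool]
) -> tuple[list[int], list[str]]:
    """Derive visible channel indices and names from a boolean mask (hypothetical helper)."""
    if len(mask) != len(all_channel_names):
        raise ValueError("mask needs exactly one entry per channel")
    indices = [i for i, visible in enumerate(mask) if visible]
    names = [all_channel_names[i] for i in indices]
    return indices, names


def select_visible(history: list[list[str]], indices: list[int]) -> list[list[str]]:
    """Keep only the visible channel columns of a [round][channel] history."""
    return [[row[i] for i in indices] for row in history]


# Hypothetical setup: three channels, of which the agent sees the first two.
indices, names = visible_channels(["ch0", "ch1", "ch2"], [True, True, False])
print(indices, names)  # [0, 1] ['ch0', 'ch1']
print(select_visible([["a", "b", "c"], ["d", "e", "f"]], indices))  # [['a', 'b'], ['d', 'e']]
```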
- __call__(data: NestedArrayDict, environment: PureTextEnvironment) → NestedArrayDict
Call self as a function.
- __getstate__() → dict[str, Any]
Get the state of the object for pickling.
We don’t pickle the OpenAI client, as it is not picklable.
- Returns:
state (dict[str, Any]) – The state of the object.
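Excluding an unpicklable client from the pickled state is a standard use of the `__getstate__`/`__setstate__` hooks. The class below is a hypothetical stand-in, not `OpenAiWholeAgent` itself: a `threading.Lock` plays the role of the unpicklable client, and recreating the client lazily on first access after unpickling is an assumption of this sketch.

```python
import pickle
import threading
from typing import Any


class ClientHolder:
    """Hypothetical stand-in for an agent that owns an unpicklable API client."""

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name
        self._client: Any = None  # created lazily on first access

    @property
    def client(self) -> Any:
        # Stand-in for constructing an API client; a lock cannot be pickled.
        if self._client is None:
            self._client = threading.Lock()
        return self._client

    def __getstate__(self) -> dict[str, Any]:
        state = self.__dict__.copy()
        state["_client"] = None  # drop the client instead of trying to pickle it
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        self.__dict__.update(state)


holder = ClientHolder("example-model")
_ = holder.client  # the client now exists on the original object
restored = pickle.loads(pickle.dumps(holder))
print(restored.model_name, restored._client)  # example-model None
```

Without the `__getstate__` override, `pickle.dumps(holder)` would fail here with a `TypeError`, because the lock stored in `_client` cannot be pickled.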
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler | CodeValidationProtocolHandler)
- _build_chat_messages_prompt(message_history: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round channel'], message_agent_id: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round channel'], raw_message_history: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round agent'], round_id: int, question: str, solution: str, prover_stance: int, seed: int, y: int | None = None, ensure_last_message_is_assistant: bool = False, replace_last_message_with_true_label: bool = False) → list[PromptMessage]
Construct the chat history ready to feed to the API.
- Parameters:
message_history (String[NDArray, "round channel"]) – The array of messages in the chat history.
message_agent_id (String[NDArray, "round channel"]) – The id of the agent who messaged at a round-channel pair.
raw_message_history (String[NDArray, "round agent"]) – The raw message generated by each model in each timestep.
round_id (int) – The current round number.
question (str) – The problem text.
solution (str) – The proposed solution text.
prover_stance (int) – The verdict that the prover is arguing for, where 0 means “reject” and 1 means “accept”.
seed (int) – The per-environment seed.
y (int, optional) – The true label (0 for incorrect, 1 for correct). Only used if replace_last_message_with_true_label is set to True.
ensure_last_message_is_assistant (bool, default=False) – Whether to ensure the last message is from the assistant, by removing trailing messages from the user.
replace_last_message_with_true_label (bool, default=False) – Whether to replace the last message with ‘Decision: accept’ or ‘Decision: reject’ based on the true label. If this is set to True, y must be provided. Only used if ensure_last_message_is_assistant is set to True.
- Returns:
chat_messages (list[PromptMessage]) – The chat messages ready to feed to the API.
- Raises:
AgentNotActiveInChannelError – If ensure_last_message_is_assistant is set to True and the agent is not active in the channel (i.e. there would be no messages in the chat history).
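A simplified sketch of the documented behaviour for a single channel follows. Assumptions: the history is held as nested Python lists indexed `[round][channel]`, an empty string marks a round with no message, only rounds before `round_id` are included, a message counts as the assistant's when its sender equals `agent_name`, and `PromptMessage` and `AgentNotActiveInChannelError` are local stand-ins for the library's types. The system prompt, raw message history and prover stance are ignored here, and the function name is hypothetical.

```python
from typing import Optional, TypedDict


class PromptMessage(TypedDict):
    role: str
    content: str


class AgentNotActiveInChannelError(Exception):
    """Local stand-in for the library's exception of the same name."""


def build_chat_messages(
    message_history: list[list[str]],
    message_agent_id: list[list[str]],
    channel: int,
    round_id: int,
    agent_name: str,
    y: Optional[int] = None,
    ensure_last_message_is_assistant: bool = False,
    replace_last_message_with_true_label: bool = False,
) -> list[PromptMessage]:
    messages: list[PromptMessage] = []
    # Walk the rounds before the current one, skipping empty slots.
    for r in range(round_id):
        text = message_history[r][channel]
        if not text:
            continue
        role = "assistant" if message_agent_id[r][channel] == agent_name else "user"
        messages.append({"role": role, "content": text})

    if ensure_last_message_is_assistant:
        # Drop trailing user messages so the history ends on the agent's turn.
        while messages and messages[-1]["role"] != "assistant":
            messages.pop()
        if not messages:
            raise AgentNotActiveInChannelError(agent_name)
        if replace_last_message_with_true_label:
            if y is None:
                raise ValueError("y must be provided to replace the last message")
            decision = "accept" if y == 1 else "reject"
            messages[-1] = {"role": "assistant", "content": f"Decision: {decision}"}
    return messages


history = [["Here is my argument that the code is correct."], ["Decision: reject"], ["Please reconsider."]]
senders = [["prover"], ["verifier"], ["prover"]]
print(build_chat_messages(
    history, senders, channel=0, round_id=3, agent_name="verifier", y=1,
    ensure_last_message_is_assistant=True, replace_last_message_with_true_label=True,
))
```

In the usage example, the trailing prover message is dropped and the verifier's incorrect guess is replaced by `Decision: accept`, because the true label is 1.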
- _generate_dummy_response(chat_messages_prompt: list[PromptMessage]) → str
Generate a dummy response to a chat prompt.
- Parameters:
chat_messages_prompt (list[PromptMessage]) – The chat messages prompt to generate a response to.
- Returns:
response (str) – The dummy response generated.
- _generate_next_message_and_decision(message_history: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round channel'], message_agent_id: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round channel'], raw_message_history: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round agent'], round_id: int, question: str, solution: str, seed: int, prover_stance: int) → _ParsedChatCompletion
Generate the next message and decision for the agent, with retries.
This method takes a single message history, then builds and runs the API request to generate the next action, which can be a message or a decision.
If there is an error in the generation, this method will retry a number of times before raising an exception (detailed below).
- Parameters:
message_history (String[NDArray, "round channel"]) – The array of messages in the chat history.
message_agent_id (String[NDArray, "round channel"]) – The id of the agent who messaged at a round-channel pair.
raw_message_history (String[NDArray, "round agent"]) – The raw message generated by each model in each timestep.
round_id (int) – The current round number.
question (str) – The problem text.
solution (str) – The proposed solution text.
seed (int) – The per-environment seed.
prover_stance (int) – The verdict that the prover is arguing for, where 0 means “reject” and 1 means “accept”. (Currently ignored.)
- Returns:
parsed_chat_completion (_ParsedChatCompletion) – The parsed chat completion output. See the _ParsedChatCompletion class for details.
- Raises:
ContentFilterError – If the agent’s response is blocked by a content filter.
UnknownFinishReasonError – If the agent finishes generating for an unknown reason.
InvalidResponseError – If the agent generates a response in an invalid format.
InvalidDecisionError – If the agent generates an invalid decision (i.e. not accept or reject).
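The retry behaviour can be sketched as a loop that calls a generation function, validates the result and retries on any of the documented errors, re-raising the last one once the retries run out. This is an illustration under stated assumptions: the exception classes are local stand-ins sharing a hypothetical `GenerationError` base, `max_retries` is a hypothetical parameter, a response is treated as valid if it is a non-empty message or a decision of exactly accept or reject, and the retry count is returned alongside the text, in the spirit of the `retry_count` output of `forward`.

```python
from typing import Callable, Optional


class GenerationError(Exception):
    """Hypothetical common base for the stand-in errors below."""


class ContentFilterError(GenerationError):
    """Response blocked by a content filter."""


class UnknownFinishReasonError(GenerationError):
    """Generation finished for an unexpected reason."""


class InvalidResponseError(GenerationError):
    """Response is not in a usable format."""


class InvalidDecisionError(GenerationError):
    """Decision is neither accept nor reject."""


def check_completion(text: Optional[str], finish_reason: str) -> str:
    """Validate one completion, raising the matching stand-in error."""
    if finish_reason == "content_filter":
        raise ContentFilterError("response blocked by a content filter")
    if finish_reason not in ("stop", "length"):
        raise UnknownFinishReasonError(finish_reason)
    if not text:
        raise InvalidResponseError("empty completion")
    if text.startswith("Decision:"):
        verdict = text.removeprefix("Decision:").strip().lower()
        if verdict not in ("accept", "reject"):
            raise InvalidDecisionError(verdict)
    return text


def generate_with_retries(
    generate: Callable[[], tuple[Optional[str], str]], max_retries: int = 3
) -> tuple[str, int]:
    """Return the first valid completion and the number of retries it took."""
    last_error: Optional[GenerationError] = None
    for retry_count in range(max_retries + 1):
        text, finish_reason = generate()
        try:
            return check_completion(text, finish_reason), retry_count
        except GenerationError as error:
            last_error = error  # remember the failure and try again
    assert last_error is not None
    raise last_error


responses = iter([(None, "stop"), ("Decision: maybe", "stop"), ("Decision: accept", "stop")])
print(generate_with_retries(lambda: next(responses)))  # ('Decision: accept', 2)
```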
- _make_generation_api_call(chat_messages_prompt: list[PromptMessage]) → tuple[str | None, Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call']]
Call the OpenAI API to generate the next message.
- Parameters:
chat_messages_prompt (list[PromptMessage]) – The message history to feed to the API. A list of dicts with keys “role” and “content”.
- Returns:
completion_text (str | None) – The text of the completion generated by the API.
finish_reason (Literal[“stop”, “length”, “tool_calls”, “content_filter”, “function_call”]) – The reason for finishing the generation.
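A call with the same inputs and outputs could be written against the official `openai` Python SDK (v1 and later), where `client.chat.completions.create(model=..., messages=...)` returns a completion whose first choice carries `message.content` and `finish_reason`. This is a sketch, not the method's actual body: the standalone function, its explicit `client` and `model` arguments, and the fake client used to run it offline are assumptions.

```python
from types import SimpleNamespace
from typing import Any, Optional


def make_generation_api_call(
    client: Any, model: str, chat_messages_prompt: list[dict[str, str]]
) -> tuple[Optional[str], str]:
    """Request one chat completion and return its text and finish reason."""
    completion = client.chat.completions.create(
        model=model, messages=chat_messages_prompt
    )
    choice = completion.choices[0]
    return choice.message.content, choice.finish_reason


class FakeCompletions:
    """Offline stand-in mimicking the shape of the SDK's response objects."""

    def create(self, model: str, messages: list[dict[str, str]]) -> Any:
        reply = SimpleNamespace(content=f"echo: {messages[-1]['content']}")
        return SimpleNamespace(
            choices=[SimpleNamespace(message=reply, finish_reason="stop")]
        )


fake_client = SimpleNamespace(chat=SimpleNamespace(completions=FakeCompletions()))
prompt = [{"role": "user", "content": "Is this code correct?"}]
print(make_generation_api_call(fake_client, "example-model", prompt))
# ('echo: Is this code correct?', 'stop')
```

With the real SDK, `client` would be an `openai.OpenAI()` instance, which reads its API key from the `OPENAI_API_KEY` environment variable by default.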
- build_fine_tune_dataset(timesteps: NestedArrayDict, ensure_last_message_is_assistant: bool = True, replace_verifier_guess_with_true_label: bool = False) → list[dict[Literal['messages'], list[PromptMessage]]]
Build the dataset for fine-tuning the agent given sampled timesteps.
A ‘timestep’ is just a single step in the environment. This holds all information about the environment state, including past messages. Generally, one wants to fine-tune on the final timestep of a rollout, but we can fine-tune on any timestep.
This method generates a dataset of examples ready to pass to the fine-tune API.
- Parameters:
timesteps (NestedArrayDict) –
The sampled timesteps. This is usually the final step in each rollout, but it can be any step if we want to fine-tune on an initial segment of the rollout. A nested dictionary of arrays with keys:
“round” (batch): The current round number.
(“next”, “message_history”) (batch round channel): The history of messages exchanged between the agents in each channel.
(“next”, “message_agent_id”) (batch round channel): The id of the agent who messaged at a round-channel pair.
(“next”, “raw_message_history”) (batch round agent): The raw message generated by each model in each timestep.
“question” (batch): The problem text.
“solution” (batch): The proposed solution text.
“prover_stance” (batch): When randomizing the prover stance, the verdict that the prover is arguing for, where 0 means “reject” and 1 means “accept”.
ensure_last_message_is_assistant (bool, default=True) – Make sure the last message in each rollout comes from the assistant, by removing messages from the end until this is true. This makes sense when we want to do supervised fine-tuning, as the last message should be the assistant’s message.
replace_verifier_guess_with_true_label (bool, default=False) – Whether to replace the verifier’s guess with the true label. If this is set to True and the agent is a verifier, the verifier’s guess will be replaced with either ‘Decision: accept’ or ‘Decision: reject’ based on the true label. Only makes sense when ensure_last_message_is_assistant is set to True.
- Returns:
fine_tune_dataset (list[dict[Literal[“messages”], list[PromptMessage]]]) – The dataset for fine-tuning the agent. This is a list of examples, where each example is a dictionary with key “messages”, whose value is a list of dictionaries with keys “role” and “content”.
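The returned list maps directly onto the JSON Lines format that OpenAI's chat fine-tuning accepts: one `{"messages": [...]}` object per line. The sketch below only serialises such a dataset; the helper name and the example conversation are hypothetical, and uploading the file to the API is not shown.

```python
import json
from typing import Literal

PromptMessage = dict[str, str]
FineTuneExample = dict[Literal["messages"], list[PromptMessage]]


def to_jsonl(dataset: list[FineTuneExample]) -> str:
    """Serialise fine-tuning examples as JSON Lines, one example per line."""
    return "".join(json.dumps(example, ensure_ascii=False) + "\n" for example in dataset)


dataset: list[FineTuneExample] = [
    {
        "messages": [
            {"role": "system", "content": "You are a verifier for code solutions."},
            {"role": "user", "content": "Does this function return the sum of its inputs?"},
            {"role": "assistant", "content": "Decision: accept"},
        ]
    }
]
print(to_jsonl(dataset), end="")
```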
- forward(data: NestedArrayDict, environment: PureTextEnvironment) → NestedArrayDict
Forward pass through the agent policy head.
- Parameters:
data (NestedArrayDict) –
The input data. A nested dictionary of arrays with keys:
“round” (…): The current round number.
“seed” (…): The per-environment seed.
“message_history” (… round channel): The history of messages exchanged between the agents in each channel.
“message_agent_id” (… round channel): The id of the agent who messaged at a round-channel pair.
“raw_message_history” (… round agent): The raw message generated by each model in each timestep.
“question” (…): The problem text.
“solution” (…): The proposed solution text.
“prover_stance” (…): When randomizing the prover stance, the verdict that the prover is arguing for, where 0 means “reject” and 1 means “accept”.
environment (PureTextEnvironment) – The environment in which the agent is running.
- Returns:
output_data (NestedArrayDict) – The output of running the agent policy head. This is a nested dictionary of arrays with keys:
“message” (… channel): The message output by the agent per channel.
“raw_message” (…): The raw message generated by the model, before …
“decision” (…): The decision made by the agent. This is either 0 (reject), 1 (accept) or 2 (no decision).
“retry_count” (… channel): The number of retries before the generation was successful.
“token_limit” (… channel): A boolean array indicating whether the generation was stopped because the maximum number of tokens was reached.
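The integer encoding of “decision” can be illustrated with a small parser. Assumptions: decisions appear as exactly `Decision: accept` or `Decision: reject` (the phrasing used elsewhere on this page), any other text counts as an ordinary message with no decision, and the helper and constant names are hypothetical.

```python
REJECT, ACCEPT, NO_DECISION = 0, 1, 2


def parse_decision(raw_message: str) -> int:
    """Map a model output to 0 (reject), 1 (accept) or 2 (no decision)."""
    text = raw_message.strip()
    if text == "Decision: accept":
        return ACCEPT
    if text == "Decision: reject":
        return REJECT
    return NO_DECISION


for raw in ["Decision: accept", "Decision: reject", "The loop bound looks off by one."]:
    print(repr(raw), "->", parse_decision(raw))
```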
- get_state_dict() → dict
Get the state of the agent part as a dict.
This method should be implemented by subclasses capable of saving their state.
- Returns:
state_dict (dict) – The state of the agent part.
- set_state(checkpoint: AgentState)
Set the state of the agent from a checkpoint.
This method should be overridden by subclasses to restore the state of the agent from a checkpoint.
- Parameters:
checkpoint (AgentState) – The checkpoint to restore the state from.
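For an agent whose model is hosted behind an API, the state worth saving is mostly identifiers rather than weights. The following is a hypothetical sketch of a get/set pair. It assumes the only state is an optional fine-tuned model ID layered over `base_model_name` (mirroring the distinction between `base_model_name` and the model name “including any fine-tuning”) and that a checkpoint is a plain dict. The real `AgentState` type may hold more or different fields.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ApiAgentState:
    """Hypothetical agent state: a base model plus an optional fine-tuned model ID."""

    base_model_name: str
    fine_tuned_model_name: Optional[str] = None

    @property
    def model_name(self) -> str:
        # Use the fine-tuned model once one exists, otherwise the base model.
        return self.fine_tuned_model_name or self.base_model_name

    def get_state_dict(self) -> dict[str, Any]:
        # Only the identifier is saved; the weights stay with the API provider.
        return {"fine_tuned_model_name": self.fine_tuned_model_name}

    def set_state(self, checkpoint: dict[str, Any]) -> None:
        self.fine_tuned_model_name = checkpoint.get("fine_tuned_model_name")


agent = ApiAgentState("base-model")
agent.fine_tuned_model_name = "my-fine-tuned-model"
restored = ApiAgentState("base-model")
restored.set_state(agent.get_state_dict())
print(restored.model_name)  # my-fine-tuned-model
```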