nip.code_validation.agents.OpenAiWholeAgent#
- class nip.code_validation.agents.OpenAiWholeAgent(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler | CodeValidationProtocolHandler)[source]#
The whole agent for code validation, using OpenAI’s API.
Methods Summary
__call__(data, environment) – Call self as a function.
__getstate__() – Get the state of the object for pickling.
__init__(hyper_params, settings, agent_name, ...)
_build_chat_messages_prompt(message_history, ...) – Construct the chat history ready to feed to the API.
_generate_dummy_response(chat_messages_prompt) – Generate a dummy response to a chat prompt.
_generate_next_message_and_decision(...) – Generate the next message and decision for the agent, with retries.
_make_generation_api_call(chat_messages_prompt) – Call the OpenAI API to generate the next message.
build_fine_tune_dataset(timesteps[, ...]) – Build the dataset for fine-tuning the agent given sampled timesteps.
forward(data, environment) – Forward pass through the agent policy head.
get_state_dict() – Get the state of the agent part as a dict.
set_state(checkpoint) – Set the state of the agent from a checkpoint.
Attributes
agent_id – The ID of the agent.
agent_level_in_keys
agent_level_out_keys
agent_spec – The CodeValidationAgentSpec for the agent.
base_model_name – The base OpenAI model name, before any fine-tuning.
client – The OpenAI client to use for interacting with the OpenAI API.
env_level_in_keys
env_level_out_keys
in_keys – The keys required by the module.
is_prover – Whether the agent is a prover.
is_verifier – Whether the agent is a verifier.
max_message_rounds – The maximum number of message rounds in the protocol.
model_name – The OpenAI model name, including any fine-tuning.
num_visible_message_channels – The number of message channels visible to the agent.
out_keys – The keys produced by the module.
required_pretrained_models – The pretrained models used by the agent.
shared_model_group
system_prompt_template – The template for the system prompt.
visible_message_channel_indices – The indices of the message channels visible to the agent.
visible_message_channel_mask – The mask for the message channels visible to the agent.
visible_message_channel_names – The names of the message channels visible to the agent.
agent_params
protocol_handler
Methods
- __call__(data: NestedArrayDict, environment: PureTextEnvironment) NestedArrayDict [source]#
Call self as a function.
- __getstate__() dict[str, Any] [source]#
Get the state of the object for pickling.
We don’t pickle the OpenAI client, as it is not picklable.
- Returns:
state (dict[str, any]) – The state of the object.
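As a minimal illustrative sketch of this pattern (not the actual implementation; the class and attribute values below are placeholders, but the idea of dropping the client attribute follows the note above):

import copy

class PicklableAgentSketch:
    """Toy class showing how an unpicklable API client can be excluded from pickling."""

    def __init__(self):
        self.client = object()           # stands in for the OpenAI client, which cannot be pickled
        self.model_name = "gpt-4o-mini"  # placeholder model name

    def __getstate__(self) -> dict:
        state = copy.copy(self.__dict__)
        state.pop("client", None)  # drop the client; it has to be recreated after unpickling
        return state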
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler | CodeValidationProtocolHandler)[source]#
- _build_chat_messages_prompt(message_history: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round channel'], message_agent_id: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round channel'], raw_message_history: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round agent'], round_id: int, question: str, solution: str, prover_stance: int, seed: int, y: int | None = None, ensure_last_message_is_assistant: bool = False, replace_last_message_with_true_label: bool = False) list[PromptMessage] [source]#
Construct the chat history ready to feed to the API.
- Parameters:
message_history (String[NDArray, "round channel"]) – The array of messages in the chat history.
message_agent_id (String[NDArray, "round channel"]) – The id of the agent who messaged at a round-channel pair.
raw_message_history (String[NDArray, "round agent"]) – The raw message generated by each model in each timestep.
round_id (int) – The current round number.
channel_name (str) – The name of the message channel.
question (str) – The problem text.
solution (str) – The proposed solution text.
seed (int) – The per-environment seed.
y (int, optional) – The true label (0 for incorrect, 1 for correct). Only used if replace_last_message_with_true_label is set to True.
ensure_last_message_is_assistant (bool, default=False) – Whether to ensure the last message is from the assistant, by removing messages from the user.
replace_last_message_with_true_label (bool, default=False) – Whether to replace the last message with ‘Decision: accept’ or ‘Decision: reject’ based on the true label. If this is set to True, y must be provided. Only used if ensure_last_message_is_assistant is set to True.
- Returns:
chat_messages (list[PromptMessage]) – The chat messages ready to feed to the API.
- Raises:
AgentNotActiveInChannelError – If ensure_last_message_is_assistant is set to True and the agent is not active in the channel (i.e. there would be no messages in the chat history).
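For orientation, the returned list follows the standard chat format of dictionaries with "role" and "content" keys (see _make_generation_api_call below). A hypothetical return value, with placeholder contents rather than the actual prompt wording, might look like:

chat_messages = [
    {"role": "system", "content": "You are verifying a proposed solution to a coding problem."},
    {"role": "user", "content": "Problem: ...\nProposed solution: ..."},
    {"role": "assistant", "content": "Decision: accept"},
]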
- _generate_dummy_response(chat_messages_prompt: list[PromptMessage]) str [source]#
Generate a dummy response to a chat prompt.
- Parameters:
chat_messages_prompt (list[PromptMessage]) – The chat messages prompt to generate a response to.
- Returns:
response (str) – The dummy response generated.
- _generate_next_message_and_decision(message_history: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round channel'], message_agent_id: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round channel'], raw_message_history: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'round agent'], round_id: int, question: str, solution: str, seed: int, prover_stance: int) _ParsedChatCompletion [source]#
Generate the next message and decision for the agent, with retries.
This method takes a single message history and builds and runs the API request to generate the next action, which can be a message or a decision.
If there is an error in the generation, this method will retry a number of times before raising an exception (detailed below).
- Parameters:
message_history (String[NDArray, "round channel"]) – The array of messages in the chat history.
message_agent_id (String[NDArray, "round channel"]) – The id of the agent who messaged at a round-channel pair.
raw_message_history (String[NDArray, "round agent"]) – The raw message generated by each model in each timestep.
round_id (int) – The current round number.
channel_name (str) – The name of the message channel.
question (str) – The problem text.
solution (str) – The proposed solution text.
seed (int) – The per-environment seed.
prover_stance (int) – The verdict that the prover is arguing for, where 0 means “reject” and 1 means “accept”. (Currently ignored.)
- Returns:
parsed_chat_completion (_ParsedChatCompletion) – The parsed chat completion output. See the _ParsedChatCompletion class for details.
- Raises:
ContentFilterError – If the agent’s response is blocked by a content filter.
UnknownFinishReasonError – If the agent finishes generating for an unknown reason.
InvalidResponseError – If the agent generates a response in an invalid format.
InvalidDecisionError – If the agent generates an invalid decision (i.e. not accept or reject).
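A rough, self-contained sketch of the retry behaviour described above (illustrative only: the error class, retry limit, and which errors are retried are assumptions, not the actual implementation):

class GenerationFailedError(Exception):
    """Stand-in for the generation errors listed above."""

def generate_with_retries(make_api_call, chat_messages_prompt, max_retries=3):
    # Try the generation a fixed number of times, then re-raise the last error.
    last_error = None
    for _ in range(max_retries):
        try:
            return make_api_call(chat_messages_prompt)
        except GenerationFailedError as error:
            last_error = error
    raise last_error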
- _make_generation_api_call(chat_messages_prompt: list[PromptMessage]) tuple[str | None, Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call']] [source]#
Call the OpenAI API to generate the next message.
- Parameters:
chat_messages_prompt (list[PromptMessage]) – The message history to feed to the API. A list of dicts with keys “role” and “content”.
- Returns:
completion_text (str | None) – The text of the completion generated by the API.
finish_reason (Literal["stop", "length", "tool_calls", "content_filter", "function_call"]) – The reason for finishing the generation.
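As a minimal sketch of what such a call looks like with the openai Python package (illustrative; the model name and message contents are placeholders, and any extra request parameters the agent passes are not shown):

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
chat_messages_prompt = [
    {"role": "system", "content": "You are verifying a proposed solution."},
    {"role": "user", "content": "Problem: ...\nSolution: ..."},
]
completion = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder; the agent uses its model_name attribute
    messages=chat_messages_prompt,
)
completion_text = completion.choices[0].message.content
finish_reason = completion.choices[0].finish_reason  # e.g. "stop", "length", "content_filter"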
- build_fine_tune_dataset(timesteps: NestedArrayDict, ensure_last_message_is_assistant: bool = True, replace_verifier_guess_with_true_label: bool = False) list[dict[Literal['messages'], list[PromptMessage]]] [source]#
Build the dataset for fine-tuning the agent given sampled timesteps.
A ‘timestep’ is just a single step in the environment. This holds all information about the environment state, including past messages. Generally, one wants to fine-tune on the final timestep of a rollout, but we can fine-tune on any timestep.
This method generates a dataset of examples ready to pass to the fine-tune API.
- Parameters:
timesteps (NestedArrayDict) –
The sampled timesteps. This is usually the final step in each rollout, but it can be any step if we want to fine-tune on an initial segment of the rollout. A nested dictionary of arrays with keys:
- "round" (batch): The current round number.
- ("next", "message_history") (batch round channel): The history of messages exchanged between the agents in each channel.
- ("next", "message_agent_id") (batch round channel): The id of the agent who messaged at a round-channel pair.
- ("next", "raw_message_history") (batch round agent): The raw message generated by each model in each timestep.
- "question" (batch): The problem text.
- "solution" (batch): The proposed solution text.
- "prover_stance" (batch): When randomizing the prover stance, the verdict that the prover is arguing for, where 0 means "reject" and 1 means "accept".
ensure_last_message_is_assistant (bool, default=True) – Make sure the last message in each rollout comes from the assistant, by removing messages from the end until this is true. This makes sense when we want to do supervised fine-tuning, as the last message should be the assistant’s message.
replace_verifier_guess_with_true_label (bool, default=False) – Whether to replace the verifier’s guess with the true label. If this is set to True, and the agent is a verifier, the verifier’s guess will be replaced with either ‘Decision: accept’ or ‘Decision: reject’ based on the true label. Only makes sense when ensure_last_message_is_assistant is set to True.
- Returns:
fine_tune_dataset (list[dict[Literal[“messages”], list[PromptMessage]]]) – The dataset for fine-tuning the agent. This is a list of examples, where each example is a dictionary with key “messages”, whose value is a list of dictionaries with keys “role” and “content”.
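For reference, each example in the returned dataset has the shape described above, which matches the OpenAI chat fine-tuning format (message contents below are placeholders):

fine_tune_dataset = [
    {
        "messages": [
            {"role": "system", "content": "You are verifying a proposed solution."},
            {"role": "user", "content": "Problem: ...\nSolution: ..."},
            {"role": "assistant", "content": "Decision: accept"},
        ]
    },
    # ... one dictionary per sampled timestep
]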
- forward(data: NestedArrayDict, environment: PureTextEnvironment) NestedArrayDict [source]#
Forward pass through the agent policy head.
- Parameters:
data (NestedArrayDict) –
The input data. A nested dictionary of arrays with keys:
- "round" (…): The current round number.
- "seed" (…): The per-environment seed.
- "message_history" (… round channel): The history of messages exchanged between the agents in each channel.
- "message_agent_id" (… round channel): The id of the agent who messaged at a round-channel pair.
- "raw_message_history" (… round agent): The raw message generated by each model in each timestep.
- "question" (…): The problem text.
- "solution" (…): The proposed solution text.
- "prover_stance" (…): When randomizing the prover stance, the verdict that the prover is arguing for, where 0 means "reject" and 1 means "accept".
environment (PureTextEnvironment) – The environment in which the agent is running.
- Returns:
output_data (NestedArrayDict) – The output of running the agent policy head. This is a nested dictionary of arrays with keys:
- "message" (… channel): The message output by the agent per channel.
- "raw_message" (…): The raw message generated by the model, before parsing.
- "decision" (…): The decision made by the agent. This is either 0 (reject), 1 (accept) or 2 (no decision).
- "retry_count" (… channel): The number of retries before the generation was successful.
- "token_limit" (… channel): A boolean array indicating whether the generation was stopped because the maximum number of tokens was reached.
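A hypothetical usage sketch, assuming data and environment have already been constructed and that NestedArrayDict supports dictionary-style indexing on the listed keys:

output_data = agent(data, environment)  # __call__ has the same signature as forward
decisions = output_data["decision"]     # 0 = reject, 1 = accept, 2 = no decision
messages = output_data["message"]       # one message per visible channel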
- get_state_dict() dict [source]#
Get the state of the agent part as a dict.
This method should be implemented by subclasses capable of saving their state.
- Returns:
state_dict (dict) – The state of the agent part.
- set_state(checkpoint: AgentState)[source]#
Set the state of the agent from a checkpoint.
This method should be overridden by subclasses to restore the state of the agent from a checkpoint.
- Parameters:
checkpoint (AgentCheckpoint) – The checkpoint to restore the state from.