nip.code_validation.rollout_analysis.BinaryRolloutAnalyser
- class nip.code_validation.rollout_analysis.BinaryRolloutAnalyser(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, model_name: str, *, use_dummy_api: bool = False)
Base class for rollout analysers which yield a binary classification.
Each rollout is analysed by a language model to generate a binary classification. The model is first given the system prompt, then the message history, and finally a question posed by the “supervisor” agent.
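Concrete analysers subclass this base class and supply the supervisor question, the system prompt template and the response-parsing rule. Below is a minimal sketch of what such a subclass might look like; the class name, question text and template filename are illustrative assumptions, and the parsing method is only stubbed here (one possible implementation is sketched at the end of this page).

```python
from nip.code_validation.rollout_analysis import BinaryRolloutAnalyser

# Hypothetical subclass: the class name, question text and template filename
# are illustrative assumptions, not values taken from the library.
class HonestyRolloutAnalyser(BinaryRolloutAnalyser):

    name = "honesty"

    # The question posed by the "supervisor" agent after the message history
    supervisor_question = (
        "Did the agent argue honestly for its position? Answer 'Yes' or 'No'."
    )

    # The filename of the system prompt template (assumed value)
    system_prompt_template_filename = "honesty_analyser_system_prompt.txt"

    def get_classification_from_response(self, response: str) -> int:
        # Map the model's free-text answer to 0 or 1; see the sketch at the
        # end of this page for one possible implementation.
        ...
```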
Methods Summary
- __init__(hyper_params, settings, ...[, ...])
- _build_chat_messages_prompt(message_history, ...): Construct the chat history ready to feed to the API.
- _generate_evaluation(message_history, ...): Generate an evaluation for a rollout.
- _make_generation_api_call(chat_messages_prompt): Call the OpenAI API to generate the evaluation.
- forward(rollouts[, use_tqdm]): Classify the rollouts by running a language model on the message history.
- get_classification_from_response(response): Get the binary classification from a language model response.
- Return an iterator over agent names and channel names to be analysed.
Attributes
- client: The OpenAI client to use for interacting with the OpenAI API.
- max_generation_retries
- supervisor_question: The question asked by the supervisor agent.
- system_prompt_template_filename: The filename of the system prompt template.
- name
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, model_name: str, *, use_dummy_api: bool = False)
- _build_chat_messages_prompt(message_history: ndarray[Any, dtype[_ScalarType_co]], message_agent_id: ndarray[Any, dtype[_ScalarType_co]], agent_name: str, channel_name: str, question: str, solution: str) → list[dict[str, str]]
Construct the chat history ready to feed to the API.
- Parameters:
message_history (NDArray) – The list of messages in the chat history.
message_agent_id (NDArray) – The agent ID of the agent which sent each message in the message history.
agent_name (str) – The name of the agent being evaluated.
channel_name (str) – The name of the message channel.
question (str) – The coding question.
solution (str) – The proposed solution to the coding question.
- Returns:
chat_messages (list[dict[str, str]]) – The chat messages ready to feed to the API.
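For illustration, the return value is a list of role/content dictionaries of the kind accepted by chat completion APIs, built from the system prompt, the message history and the supervisor question. The exact strings below are assumptions; only the shape follows from the documented return type.

```python
# Illustrative shape of the value returned by _build_chat_messages_prompt.
# The actual text comes from the system prompt template, the rollout's
# message history and the supervisor question, so these strings are assumed.
chat_messages_prompt = [
    {"role": "system", "content": "You are evaluating a coding interaction ..."},
    {"role": "user", "content": "Question: ...\nProposed solution: ..."},
    {"role": "user", "content": "[prover] The solution is correct because ..."},
    {"role": "user", "content": "[verifier] Can you justify the loop bound?"},
    {"role": "user", "content": "Did the agent argue honestly? Answer 'Yes' or 'No'."},
]
```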
- _generate_evaluation(message_history: ndarray[Any, dtype[_ScalarType_co]], message_agent_id: ndarray[Any, dtype[_ScalarType_co]], agent_name: str, channel_name: str, question: str, solution: str) → int | None
Generate an evaluation for a rollout.
- Parameters:
message_history (NDArray) – The history of messages exchanged between the agents in the channel.
message_agent_id (NDArray) – The agent ID of the agent which sent each message in the message history.
agent_name (str) – The name of the agent being evaluated.
channel_name (str) – The name of the message channel.
question (str) – The coding question.
solution (str) – The proposed solution to the coding question.
- Returns:
evaluation (int | None) – The evaluation. None indicates that the evaluation could not be generated.
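The documented pieces suggest a composition along the following lines. This is only a plausible sketch: the retry loop is an assumption based on the max_generation_retries attribute, not the actual control flow.

```python
# Plausible sketch of how the documented methods could be composed
# (assumed control flow, written as a method of the hypothetical subclass).
def _generate_evaluation_sketch(self, message_history, message_agent_id,
                                agent_name, channel_name, question, solution):
    # Build the prompt from the message history and the rollout metadata
    chat_messages = self._build_chat_messages_prompt(
        message_history, message_agent_id, agent_name, channel_name,
        question, solution,
    )
    # Retry the generation a bounded number of times; give up with None if
    # no valid classification can be extracted (assumed retry behaviour).
    for _ in range(self.max_generation_retries):
        completion_text, finish_reason = self._make_generation_api_call(chat_messages)
        if finish_reason != "stop":
            continue
        try:
            return self.get_classification_from_response(completion_text)
        except InvalidResponseError:  # exception named in the Raises section below
            continue
    return None
```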
- _make_generation_api_call(chat_messages_prompt: list[dict[Literal['role', 'content', 'user'], str]]) → tuple[str, Literal['stop', 'content_filter', 'length']]
Call the OpenAI API to generate the evaluation.
- Parameters:
chat_messages_prompt (list[dict[str, str]]) – The message history to feed to the API. A list of dicts with keys “role” and “content”.
- Returns:
completion_text (str) – The text of the completion generated by the API.
finish_reason (Literal[“stop”, “content_filter”, “length”]) – The reason for finishing the generation.
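For illustration, this is the kind of chat completion call the method wraps, assuming the official openai Python client; the model name is a placeholder, and the class normally holds a configured client in its client attribute.

```python
from openai import OpenAI

client = OpenAI()  # the class stores a configured client in `client`

chat_messages_prompt = [
    {"role": "system", "content": "You are evaluating a coding interaction."},
    {"role": "user", "content": "Did the agent argue honestly? Answer 'Yes' or 'No'."},
]

# Minimal sketch of a chat completion call returning the documented pair
# (completion_text, finish_reason); "gpt-4o-mini" is a placeholder model.
completion = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=chat_messages_prompt,
)
completion_text = completion.choices[0].message.content
finish_reason = completion.choices[0].finish_reason  # e.g. "stop", "length" or "content_filter"
```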
- forward(rollouts: NestedArrayDict, use_tqdm: bool = False) → dict[tuple[str, str], MaskedArray]
Classify the rollouts by running a language model on the message history.
Evaluations are either 0 or 1.
- Parameters:
rollouts (NestedArrayDict) – The sampled rollouts. A nested dictionary of arrays with keys:
“round” (… round): The current round number.
“message_history” (… round round channel): The history of messages exchanged between the agents in each channel.
“question” (… round): The coding question.
“solution” (… round): The proposed solution to the coding question.
“prover_stance” (…): When randomizing the prover stance, the verdict that the prover is arguing for, where 0 means “reject” and 1 means “accept”.
- Returns:
evaluations (dict[tuple[str, str], ma.MaskedArray]) – The evaluations. A dictionary indexed by agent name and channel name, where evaluations[agent_name, channel_name] is a 0-1 array of evaluations of shape (…).
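A usage sketch, assuming analyser is an instance of a concrete subclass and rollouts is a NestedArrayDict with the keys listed above. Treating masked entries as rollouts that could not be evaluated is an assumption consistent with _generate_evaluation returning None.

```python
import numpy.ma as ma

# `analyser` and `rollouts` are assumed to have been constructed elsewhere.
evaluations = analyser.forward(rollouts, use_tqdm=True)

for (agent_name, channel_name), scores in evaluations.items():
    # `scores` is a 0-1 masked array over the batch dimensions; masked
    # entries are taken to be rollouts that could not be evaluated.
    positive_rate = ma.mean(scores)
    num_unscored = ma.count_masked(scores)
    print(f"{agent_name}/{channel_name}: "
          f"positive rate {positive_rate:.2f}, {num_unscored} unscored")
```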
- get_classification_from_response(response: str) → int
Get the binary classification from the language model response.
- Parameters:
response (str) – The response from the language model.
- Returns:
classification (int) – The binary classification. Either 0 or 1.
- Raises:
InvalidResponseError – If the response is not a valid response.
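One illustrative implementation for the hypothetical subclass sketched at the top of this page, assuming the supervisor question asks for a yes/no answer; the parsing rule is an assumption rather than the library's behaviour.

```python
def get_classification_from_response(self, response: str) -> int:
    # Assumed parsing rule: accept a leading "yes"/"no" (case-insensitive)
    # and treat anything else as invalid. InvalidResponseError is the
    # exception named in the Raises section above.
    answer = response.strip().lower()
    if answer.startswith("yes"):
        return 1
    if answer.startswith("no"):
        return 0
    raise InvalidResponseError(f"Could not parse classification from: {response!r}")
```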