nip.code_validation.rollout_analysis.ProverRoleConformanceAnalyser
- class nip.code_validation.rollout_analysis.ProverRoleConformanceAnalyser(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, model_name: str, *, use_dummy_api: bool = False)
A watchdog to evaluate how well the prover(s) are conforming to their roles.
The watchdog uses a language model to evaluate the message histories.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
settings (ExperimentSettings) – The instance-specific settings of the experiment, like device, logging, etc.
protocol_handler (ProtocolHandler) – The protocol handler for the experiment.
model_name (str) – The name of the language model used to evaluate the rollouts.
use_dummy_api (bool) – Whether to use a dummy API instead of making real API calls.
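A minimal usage sketch, assuming the hyper-parameters, settings, protocol handler, and sampled rollouts are already available from the surrounding experiment code; the model name is a placeholder, not a value prescribed by this class:

```python
# Minimal usage sketch. `hyper_params`, `settings`, `protocol_handler` and
# `rollouts` are assumed to come from the surrounding experiment code.
from nip.code_validation.rollout_analysis import ProverRoleConformanceAnalyser

analyser = ProverRoleConformanceAnalyser(
    hyper_params,
    settings,
    protocol_handler,
    model_name="gpt-4o-mini",  # placeholder OpenAI model name
    use_dummy_api=False,       # set True to avoid real API calls
)

# Evaluate how well each prover stuck to its role in the sampled rollouts.
evaluations = analyser.forward(rollouts, use_tqdm=True)
```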
Methods Summary
__init__(hyper_params, settings, ...[, ...])
_build_chat_messages_prompt(message_history, ...) – Construct the chat history ready to feed to the API.
_generate_evaluation(message_history, ...) – Generate an evaluation for a rollout.
_make_generation_api_call(chat_messages_prompt) – Call the OpenAI API to generate the evaluation.
forward(rollouts[, use_tqdm]) – Classify the rollouts by running a language model on the message history.
get_classification_from_response(response) – Get the binary classification from the language model response.
Get the relevant agents and channels for the analysis.
Attributes
client – The OpenAI client to use for interacting with the OpenAI API.
max_generation_retries
name
supervisor_question
system_prompt_template_filename
protocol_handler
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, model_name: str, *, use_dummy_api: bool = False)
- _build_chat_messages_prompt(message_history: ndarray[Any, dtype[_ScalarType_co]], message_agent_id: ndarray[Any, dtype[_ScalarType_co]], agent_name: str, channel_name: str, question: str, solution: str) → list[dict[str, str]]
Construct the chat history ready to feed to the API.
- Parameters:
message_history (NDArray) – The list of messages in the chat history.
message_agent_id (NDArray) – The agent ID of the agent which sent each message in the message history.
agent_name (str) – The name of the agent being evaluated.
channel_name (str) – The name of the message channel.
question (str) – The coding question.
solution (str) – The proposed solution to the coding question.
- Returns:
chat_messages (list[dict[str, str]]) – The chat messages ready to feed to the API.
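The exact wording comes from the system prompt template and the recorded transcript; the return value is an OpenAI-style chat history. A purely illustrative sketch of the shape (the content strings and agent name are placeholders, not the real template):

```python
# Illustrative shape only: the real system prompt is built from
# system_prompt_template_filename and the real transcript from message_history.
chat_messages = [
    {"role": "system", "content": "You are judging whether agent 'prover0' kept to its role ..."},
    {"role": "user", "content": "Question:\n...\n\nProposed solution:\n...\n\nTranscript:\n..."},
]
```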
- _generate_evaluation(message_history: ndarray[Any, dtype[_ScalarType_co]], message_agent_id: ndarray[Any, dtype[_ScalarType_co]], agent_name: str, channel_name: str, question: str, solution: str) → int | None
Generate an evaluation for a rollout.
- Parameters:
message_history (NDArray) – The history of messages exchanged between the agents in the channel.
message_agent_id (NDArray) – The agent ID of the agent which sent each message in the message history.
agent_name (str) – The name of the agent being evaluated.
channel_name (str) – The name of the message channel.
question (str) – The coding question.
solution (str) – The proposed solution to the coding question.
- Returns:
evaluation (int | None) – The evaluation. None indicates that the evaluation could not be generated.
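A plausible outline of the flow, inferred from the other methods documented on this page (prompt construction, the API call, response parsing, and max_generation_retries); this is a sketch, not the actual implementation:

```python
# Inferred flow: build the prompt, call the API up to max_generation_retries
# times, and parse the first usable response.
# InvalidResponseError is defined in the nip codebase (import path not shown here).
def _generate_evaluation_sketch(analyser, message_history, message_agent_id,
                                agent_name, channel_name, question, solution):
    prompt = analyser._build_chat_messages_prompt(
        message_history, message_agent_id, agent_name, channel_name, question, solution
    )
    for _ in range(analyser.max_generation_retries):
        completion_text, finish_reason = analyser._make_generation_api_call(prompt)
        if finish_reason != "stop":
            continue  # truncated or filtered completions are retried
        try:
            return analyser.get_classification_from_response(completion_text)
        except InvalidResponseError:  # raised when the response cannot be parsed
            continue
    return None  # signal that no evaluation could be generated
```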
- _make_generation_api_call(chat_messages_prompt: list[dict[Literal['role', 'content', 'user'], str]]) → tuple[str, Literal['stop', 'content_filter', 'length']]
Call the OpenAI API to generate the evaluation.
- Parameters:
chat_messages_prompt (list[dict[str, str]]) – The message history to feed to the API. A list of dicts with keys “role” and “content”.
- Returns:
completion_text (str) – The text of the completion generated by the API.
finish_reason (Literal[“stop”, “content_filter”, “length”]) – The reason for finishing the generation.
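Equivalently, a standalone chat-completions request with the openai Python client looks like the sketch below; the analyser itself uses its client attribute and configured model rather than these hard-coded values:

```python
# Standalone equivalent of the call, using the openai v1 client; the model
# name is a placeholder and the client construction is an assumption.
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
response = client.chat.completions.create(
    model="gpt-4o-mini",            # placeholder model name
    messages=chat_messages_prompt,  # the list of {"role", "content"} dicts
)
completion_text = response.choices[0].message.content
finish_reason = response.choices[0].finish_reason  # "stop", "content_filter" or "length"
```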
- forward(rollouts: NestedArrayDict, use_tqdm: bool = False) → dict[tuple[str, str], MaskedArray]
Classify the rollouts by running a language model on the message history.
Evaluations are either 0 or 1.
- Parameters:
rollouts (NestedArrayDict) – The sampled rollouts. A nested dictionary of arrays with keys:
"round" (… round): The current round number.
"message_history" (… round round channel): The history of messages exchanged between the agents in each channel.
"question" (… round): The coding question.
"solution" (… round): The proposed solution to the coding question.
"prover_stance" (…): When randomizing the prover stance, the verdict that the prover is arguing for, where 0 means “reject” and 1 means “accept”.
- Returns:
evaluations (dict[tuple[str, str], ma.MaskedArray]) – The evaluations. A dictionary indexed by agent name and channel name, where evaluations[agent_name, channel_name] is a 0-1 array of evaluations of shape (…).
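A hedged sketch of consuming the result; the agent and channel names below are placeholders, and entries are presumably masked where no evaluation could be generated (cf. _generate_evaluation returning None):

```python
# "prover0" and "main" are placeholder names; use the agent and channel names
# defined by your protocol handler.
evaluations = analyser.forward(rollouts, use_tqdm=True)
conformance = evaluations["prover0", "main"]  # 0-1 MaskedArray with the batch shape
print(conformance.mean())                     # masked (failed) evaluations are ignored
```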
- get_classification_from_response(response: str) → int
Get the binary classification from the language model response.
- Parameters:
response (str) – The response from the language model.
- Returns:
classification (int) – The binary classification. Either 0 or 1.
- Raises:
InvalidResponseError – If the response cannot be parsed as a valid classification.