nip.code_validation.rollout_analysis.ProverRoleConformanceAnalyser#
- class nip.code_validation.rollout_analysis.ProverRoleConformanceAnalyser(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, model_name: str, *, use_dummy_api: bool = False)[source]#
A watchdog to evaluate how well the prover(s) are conforming to their roles.
The watchdog uses a language model to evaluate the message histories.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
settings (ExperimentSettings) – The instance-specific settings of the experiment, like device, logging, etc.
protocol_handler (ProtocolHandler) – The protocol handler for the experiment.
model_name (str) – The name of the language model to use for the evaluation.
use_dummy_api (bool) – Whether to generate dummy responses instead of calling the real API.
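Example (a minimal instantiation sketch; constructing the HyperParameters, ExperimentSettings and ProtocolHandler objects depends on the experiment configuration and is not shown here, and the model name is an arbitrary placeholder):

    from nip.code_validation.rollout_analysis import ProverRoleConformanceAnalyser

    analyser = ProverRoleConformanceAnalyser(
        hyper_params=hyper_params,          # an existing HyperParameters instance
        settings=settings,                  # an existing ExperimentSettings instance
        protocol_handler=protocol_handler,  # an existing ProtocolHandler instance
        model_name="gpt-4o-mini",           # placeholder OpenAI model name
        use_dummy_api=True,                 # avoid real API calls while testing
    )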
Methods Summary
- __init__(hyper_params, settings, ...[, ...])
- _build_chat_messages_prompt(message_history, ...) – Construct the chat history ready to feed to the API.
- _generate_dummy_response() – Generate a dummy response for the rollout analyser.
- _generate_evaluation(message_history, ...) – Generate an evaluation for a rollout.
- _get_score_from_response(response) – Get the binary classification from the language model response.
- _make_generation_api_call(chat_messages_prompt) – Call the OpenAI API to generate the evaluation.
- forward(rollouts[, use_tqdm]) – Score the rollouts by running a language model on the message history.
- relevant_agents_and_channels() – Get the relevant agents and channels for the analysis.
Attributes
- client – The OpenAI client to use for interacting with the OpenAI API.
- max_generation_retries – The number of times to retry if the model generates an invalid response.
- name
- supervisor_question
- system_prompt_template_filename
- protocol_handler
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, model_name: str, *, use_dummy_api: bool = False)[source]#
- _build_chat_messages_prompt(message_history: ndarray[Any, dtype[_ScalarType_co]], message_agent_id: ndarray[Any, dtype[_ScalarType_co]], agent_name: str, channel_name: str, question: str, solution: str) list[dict[str, str]] [source]#
Construct the chat history ready to feed to the API.
- Parameters:
message_history (NDArray) – The list of messages in the chat history.
message_agent_id (NDArray) – The agent ID of the agent which sent each message in the message history.
agent_name (str) – The name of the agent being evaluated.
channel_name (str) – The name of the message channel.
question (str) – The coding question.
solution (str) – The proposed solution to the coding question.
- Returns:
chat_messages (list[dict[str, str]]) – The chat messages ready to feed to the API.
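The return value is a standard OpenAI-style chat message list. The exact prompt wording comes from the analyser's system prompt template, so the content below is only an illustrative assumption of its shape:

    chat_messages = [
        {"role": "system",
         "content": "You are judging whether the agent kept to its role ..."},
        {"role": "user",
         "content": "Question: ...\nSolution: ...\nTranscript: ..."},
    ]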
- _generate_dummy_response()[source]#
Generate a dummy response for the rollout analyser.
This is used when the use_dummy_api flag is set to True, to produce a stand-in response instead of making a real API call.
- Returns:
response (str) – The dummy response.
- _generate_evaluation(message_history: ndarray[Any, dtype[_ScalarType_co]], message_agent_id: ndarray[Any, dtype[_ScalarType_co]], agent_name: str, channel_name: str, question: str, solution: str) Real | None [source]#
Generate an evaluation for a rollout.
- Parameters:
message_history (NDArray) – The history of messages exchanged between the agents in the channel.
message_agent_id (NDArray) – The agent ID of the agent which sent each message in the message history.
agent_name (str) – The name of the agent being evaluated.
channel_name (str) – The name of the message channel.
question (str) – The coding question.
solution (str) – The proposed solution to the coding question.
- Returns:
evaluation (Real | None) – The evaluation. None indicates that the evaluation could not be generated.
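A sketch of the plausible control flow, pieced together from max_generation_retries and the methods documented on this page (this is an assumption about the implementation, not a copy of it; the import path of UnparsableResponseError is also assumed):

    def generate_evaluation_sketch(analyser, message_history, message_agent_id,
                                   agent_name, channel_name, question, solution):
        # Assumed import location for the exception documented below.
        from nip.code_validation.rollout_analysis import UnparsableResponseError

        prompt = analyser._build_chat_messages_prompt(
            message_history, message_agent_id, agent_name, channel_name,
            question, solution,
        )
        for _ in range(analyser.max_generation_retries):
            text, finish_reason = analyser._make_generation_api_call(prompt)
            if finish_reason != "stop":
                continue  # truncated or filtered generation: retry
            try:
                return analyser._get_score_from_response(text)
            except UnparsableResponseError:
                continue  # unparsable answer: retry
        return None  # evaluation could not be generated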
- _get_score_from_response(response: str) Literal[0, 1] [source]#
Get the binary classification from the language model response.
- Parameters:
response (str) – The response from the language model.
- Returns:
classification (Literal[0, 1]) – The binary classification.
- Raises:
UnparsableResponseError – If the response cannot be parsed into a valid classification.
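A minimal parsing sketch; the answer format the model is asked to use is not specified on this page, so the "Yes"/"No" convention here (and the stand-in exception class) are assumptions:

    class UnparsableResponseError(Exception):
        # Stand-in for the library's exception, defined here for the sketch.
        pass

    def parse_binary_score(response: str) -> int:
        answer = response.strip().lower()
        if answer.startswith("yes"):
            return 1
        if answer.startswith("no"):
            return 0
        raise UnparsableResponseError(f"Unrecognised response: {response!r}")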
- _make_generation_api_call(chat_messages_prompt: list[dict[Literal['role', 'content', 'user'], str]]) tuple[str, Literal['stop', 'content_filter', 'length']] [source]#
Call the OpenAI API to generate the evaluation.
- Parameters:
chat_messages_prompt (list[dict[str, str]]) – The message history to feed to the API. A list of dicts with keys “role” and “content”.
- Returns:
completion_text (str) – The text of the completion generated by the API.
finish_reason (Literal[“stop”, “content_filter”, “length”]) – The reason for finishing the generation.
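The call presumably goes through the client attribute documented above. A sketch using the OpenAI Python client (v1 chat-completions API); the model_name attribute name is an assumption:

    completion = analyser.client.chat.completions.create(
        model=analyser.model_name,        # assumed attribute holding model_name
        messages=chat_messages_prompt,
    )
    choice = completion.choices[0]
    completion_text = choice.message.content
    finish_reason = choice.finish_reason  # "stop", "content_filter" or "length"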
- forward(rollouts: NestedArrayDict, use_tqdm: bool = False) dict[tuple[str, str], MaskedArray] [source]#
Score the rollouts by running a language model on the message history.
- Parameters:
rollouts (NestedArrayDict) –
The sampled rollouts. A nested dictionary of arrays with keys:
“round” (… round): The current round number.
“message_history” (… round round channel): The history of messages exchanged between the agents in each channel.
“question” (… round): The coding question.
“solution” (… round): The proposed solution to the coding question.
“prover_stance” (…): When randomizing the prover stance, the verdict that the prover is arguing for, where 0 means “reject” and 1 means “accept”.
- Returns:
evaluations (dict[tuple[str, str], ma.MaskedArray]) – The evaluations. A dictionary indexed by agent name and channel name, where evaluations[agent_name, channel_name] is an array of scores of shape (…).
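Usage sketch: the agent and channel names below (“prover”, “main”) are placeholders for whatever the protocol handler defines:

    evaluations = analyser.forward(rollouts, use_tqdm=True)
    scores = evaluations["prover", "main"]  # MaskedArray; failed evaluations are masked
    print("Mean role conformance:", scores.mean())  # mean over unmasked entries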