nip.code_validation.rollout_analysis.OutOfTenRolloutAnalyser
- class nip.code_validation.rollout_analysis.OutOfTenRolloutAnalyser(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, model_name: str, *, use_dummy_api: bool = False)
Base class for rollout analysers which score from 0 to 10.
Each rollout is analysed by a language model to generate a score out of 10. The model is given the system prompt, then the message history, and finally a question posed by the “supervisor” agent.
Methods Summary
__init__(hyper_params, settings, ...[, ...])
_build_chat_messages_prompt(message_history, ...) – Construct the chat history ready to feed to the API.
_generate_dummy_response() – Generate a dummy response for the rollout analyser.
_generate_evaluation(message_history, ...) – Generate an evaluation for a rollout.
_get_score_from_response(response) – Get the score out of ten from language model response.
_make_generation_api_call(chat_messages_prompt) – Call the OpenAI API to generate the evaluation.
forward(rollouts[, use_tqdm]) – Score the rollouts by running a language model on the message history.
Return an iterator over agent names and channel names to be analysed.
Attributes
client – The OpenAI client to use for interacting with the OpenAI API.
max_generation_retries – The number of times to retry if the model generates an invalid response.
supervisor_question – The question asked by the supervisor agent.
system_prompt_template_filename – The filename of the system prompt template.
text_to_int
name
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, model_name: str, *, use_dummy_api: bool = False)
- _build_chat_messages_prompt(message_history: ndarray[Any, dtype[_ScalarType_co]], message_agent_id: ndarray[Any, dtype[_ScalarType_co]], agent_name: str, channel_name: str, question: str, solution: str) -> list[dict[str, str]]
Construct the chat history ready to feed to the API.
- Parameters:
message_history (NDArray) – The list of messages in the chat history.
message_agent_id (NDArray) – The agent ID of the agent which sent each message in the message history.
agent_name (str) – The name of the agent being evaluated.
channel_name (str) – The name of the message channel.
question (str) – The coding question.
solution (str) – The proposed solution to the coding question.
- Returns:
chat_messages (list[dict[str, str]]) – The chat messages ready to feed to the API.
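The prompt layout described for this class (system prompt, then message history, then the supervisor's question) can be sketched as follows. This is a minimal illustration, not the library's implementation: the function name `build_chat_messages`, the use of plain lists in place of NDArrays, the `None` marker for an empty slot, and the choice to send every history message with the "user" role are all assumptions.

```python
from typing import Optional


def build_chat_messages(
    system_prompt: str,
    message_history: list[Optional[str]],
    message_agent_names: list[str],
    supervisor_question: str,
) -> list[dict[str, str]]:
    """Lay out the system prompt, then the history, then the supervisor question."""
    chat_messages = [{"role": "system", "content": system_prompt}]
    for agent_name, message in zip(message_agent_names, message_history):
        if message is None:  # assumed marker for a slot with no message
            continue
        # Assumption: each message is prefixed with the name of its sender.
        chat_messages.append({"role": "user", "content": f"{agent_name}: {message}"})
    # The supervisor's question always comes last.
    chat_messages.append(
        {"role": "user", "content": f"Supervisor: {supervisor_question}"}
    )
    return chat_messages


# Hypothetical usage with made-up messages.
messages = build_chat_messages(
    "You are assessing a discussion about a code solution.",
    ["The solution handles empty input.", None, "I am not convinced."],
    ["prover", "verifier", "verifier"],
    "Out of 10, how well did the prover argue?",
)
```

Each entry is a dict with "role" and "content" keys, which is the shape that `_make_generation_api_call` documents for its input.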
- _generate_dummy_response()
Generate a dummy response for the rollout analyser.
This is used in place of a real API call when the use_dummy_api flag is set to True.
- Returns:
response (str) – The dummy response.
- _generate_evaluation(message_history: ndarray[Any, dtype[_ScalarType_co]], message_agent_id: ndarray[Any, dtype[_ScalarType_co]], agent_name: str, channel_name: str, question: str, solution: str) -> Real | None
Generate an evaluation for a rollout.
- Parameters:
message_history (NDArray) – The history of messages exchanged between the agents in the channel.
message_agent_id (NDArray) – The agent ID of the agent which sent each message in the message history.
agent_name (str) – The name of the agent being evaluated.
channel_name (str) – The name of the message channel.
question (str) – The coding question.
solution (str) – The proposed solution to the coding question.
- Returns:
evaluation (Real | None) – The evaluation. None indicates that the evaluation could not be generated.
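One plausible way the None return value and the max_generation_retries attribute fit together is a bounded retry loop, sketched below. The helper names `generate_evaluation`, `call_model` and `parse_score` are hypothetical, the exception class is a local stand-in for the library's, and treating max_generation_retries as the total number of attempts is an assumption.

```python
from typing import Callable, Optional


class UnparsableResponseError(Exception):
    """Local stand-in for the library's exception of the same name."""


def generate_evaluation(
    call_model: Callable[[], str],
    parse_score: Callable[[str], int],
    max_generation_retries: int,
) -> Optional[int]:
    """Query the model until a response parses, or give up and return None."""
    for _ in range(max_generation_retries):
        response = call_model()
        try:
            return parse_score(response)
        except UnparsableResponseError:
            continue  # invalid response: ask the model again
    return None  # no valid score within the allowed attempts


def parse_digits(response: str) -> int:
    """Toy parser used only for this example."""
    if not response.isdigit():
        raise UnparsableResponseError(response)
    return int(response)


# Hypothetical usage: the first response is invalid, the second parses.
responses = iter(["no idea", "7"])
score = generate_evaluation(lambda: next(responses), parse_digits, 3)

# A model that never produces a valid response yields None.
failed = generate_evaluation(lambda: "n/a", parse_digits, 2)
```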
- _get_score_from_response(response: str) -> int
Get the score out of ten from language model response.
- Parameters:
response (str) – The response from the language model.
- Returns:
classification (int) – The score out of ten. This is an integer between 0 and 10.
- Raises:
UnparsableResponseError – If a score out of ten cannot be parsed from the response.
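A minimal sketch of this kind of parsing is shown below, under stated assumptions: the text_to_int attribute is undocumented here, so the `TEXT_TO_INT` mapping from number words to integers is a guess at its purpose, the exception class is a local stand-in, and the rule "take the first number found" is an illustrative choice rather than the library's behaviour.

```python
import re


class UnparsableResponseError(Exception):
    """Local stand-in for the library's exception of the same name."""


# Assumed mapping from number words to integers.
TEXT_TO_INT = {
    "zero": 0, "one": 1, "two": 2, "three": 3, "four": 4, "five": 5,
    "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10,
}


def get_score_from_response(response: str) -> int:
    """Extract an integer score between 0 and 10 from free-form text."""
    text = response.strip().lower()
    digit_match = re.search(r"\b(\d+)\b", text)
    if digit_match:
        score = int(digit_match.group(1))  # first run of digits wins
    else:
        # Fall back to the first number word, if there is one.
        words = re.findall(r"[a-z]+", text)
        found = [TEXT_TO_INT[word] for word in words if word in TEXT_TO_INT]
        if not found:
            raise UnparsableResponseError(f"No score found in {response!r}")
        score = found[0]
    if not 0 <= score <= 10:
        raise UnparsableResponseError(f"Score {score} is outside 0 to 10")
    return score
```

Raising a dedicated exception rather than returning a sentinel lets the caller decide whether to retry the generation or record the evaluation as missing.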
- _make_generation_api_call(chat_messages_prompt: list[dict[Literal['role', 'content', 'user'], str]]) -> tuple[str, Literal['stop', 'content_filter', 'length']]
Call the OpenAI API to generate the evaluation.
- Parameters:
chat_messages_prompt (list[dict[str, str]]) – The message history to feed to the API. A list of dicts with keys “role” and “content”.
- Returns:
completion_text (str) – The text of the completion generated by the API.
finish_reason (Literal[“stop”, “content_filter”, “length”]) – The reason for finishing the generation.
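The sketch below shows how a (completion_text, finish_reason) pair can be read from a chat completion. It assumes the client follows the OpenAI Python client's `client.chat.completions.create(...)` interface; whether the library makes exactly this call is not stated here. The free function `make_generation_api_call` and the stub client are hypothetical, and the stub exists only so the example runs without network access.

```python
from types import SimpleNamespace


def make_generation_api_call(client, model_name, chat_messages_prompt):
    """Request one completion and return its text and finish reason."""
    completion = client.chat.completions.create(
        model=model_name,
        messages=chat_messages_prompt,
    )
    choice = completion.choices[0]  # only the first choice is used
    return choice.message.content, choice.finish_reason


def _fake_create(model, messages):
    """Stub that mimics the shape of a chat completion response."""
    return SimpleNamespace(
        choices=[
            SimpleNamespace(
                message=SimpleNamespace(content="7"),
                finish_reason="stop",
            )
        ]
    )


# Stub client exposing client.chat.completions.create, for illustration only.
stub_client = SimpleNamespace(
    chat=SimpleNamespace(completions=SimpleNamespace(create=_fake_create))
)

completion_text, finish_reason = make_generation_api_call(
    stub_client,
    "some-model",  # placeholder model name
    [{"role": "user", "content": "Score out of 10?"}],
)
```

A finish_reason of "length" or "content_filter" signals that the text may be truncated or withheld, so a caller would likely want to check it before parsing a score.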
- forward(rollouts: NestedArrayDict, use_tqdm: bool = False) -> dict[tuple[str, str], MaskedArray]
Score the rollouts by running a language model on the message history.
- Parameters:
rollouts (NestedArrayDict) –
The sampled rollouts. A nested dictionary of arrays with keys:
“round” (… round): The current round number.
“message_history” (… round round channel): The history of messages exchanged between the agents in each channel.
“question” (… round): The coding question.
“solution” (… round): The proposed solution to the coding question.
“prover_stance” (…): When randomizing the prover stance, the verdict that the prover is arguing for, where 0 means “reject” and 1 means “accept”.
- Returns:
evaluations (dict[tuple[str, str], ma.MaskedArray]) – The evaluations. A dictionary indexed by agent name and channel name, where evaluations[agent_name, channel_name] is an array of scores of shape (…).