nip.scenario_base.agents.PureTextWholeAgent
- class nip.scenario_base.agents.PureTextWholeAgent(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)
Base class for whole agents which process text input and call APIs.
Methods Summary
__call__(data, environment) – Run a forward pass through the agent, with some safety checks.
__init__(hyper_params, settings, agent_name, ...)
build_fine_tune_dataset(rollouts) – Build a dataset for fine-tuning the agent from sampled rollouts.
forward(data, environment) – Forward pass through the agent.
get_state_dict() – Get the state of the agent part as a dict.
set_state(checkpoint) – Set the state of the agent from a checkpoint.
Attributes
agent_id – The ID of the agent.
agent_level_in_keys
agent_level_out_keys
env_level_in_keys
env_level_out_keys
in_keys – The keys required by the module.
is_prover – Whether the agent is a prover.
is_verifier – Whether the agent is a verifier.
max_message_rounds – The maximum number of message rounds in the protocol.
num_visible_message_channels – The number of message channels visible to the agent.
out_keys – The keys produced by the module.
required_pretrained_models – The pretrained models used by the agent.
shared_model_group
visible_message_channel_indices – The indices of the message channels visible to the agent.
visible_message_channel_mask – The mask for the message channels visible to the agent.
visible_message_channel_names – The names of the message channels visible to the agent.
agent_params
Methods
- async __call__(data: NestedArrayDict, environment: PureTextEnvironment) → NestedArrayDict
Run a forward pass through the agent, with some safety checks.
- Parameters:
data (NestedArrayDict) – The input to the agent.
environment (PureTextEnvironment) – The environment the agent is interacting with.
- Returns:
output (NestedArrayDict) – The output of the forward pass on the input.
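The docstring does not specify which safety checks __call__ performs. The sketch below assumes two plausible ones: checking the input against in_keys before calling forward, and checking the output against out_keys afterwards. SketchAgent, its key names and its return values are hypothetical, and plain dicts stand in for NestedArrayDict so the example runs without nip installed.

```python
import asyncio
from typing import Any


class SketchAgent:
    """Hypothetical stand-in; plain dicts replace NestedArrayDict."""

    in_keys = ("message_history", "round")  # hypothetical key names
    out_keys = ("message", "decision")      # hypothetical key names

    async def forward(self, data: dict[str, Any], environment: Any) -> dict[str, Any]:
        # Placeholder: a real agent would call a text-generation API here.
        return {"message": f"round {data['round']}", "decision": 0}

    async def __call__(self, data: dict[str, Any], environment: Any) -> dict[str, Any]:
        # Assumed check 1: every required input key is present.
        missing_in = [key for key in self.in_keys if key not in data]
        if missing_in:
            raise KeyError(f"Missing input keys: {missing_in}")

        output = await self.forward(data, environment)

        # Assumed check 2: forward produced every declared output key.
        missing_out = [key for key in self.out_keys if key not in output]
        if missing_out:
            raise KeyError(f"Missing output keys: {missing_out}")
        return output


result = asyncio.run(SketchAgent()({"message_history": [], "round": 1}, environment=None))
print(result)  # {'message': 'round 1', 'decision': 0}
```

Keeping the checks in __call__ and the model logic in forward lets every subclass inherit the same validation without repeating it.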
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)
- abstract build_fine_tune_dataset(rollouts: NestedArrayDict) → list
Build a dataset for fine-tuning the agent from sampled rollouts.
This method generates a dataset of examples ready to pass to the fine-tune API.
- Parameters:
rollouts (NestedArrayDict) – The sampled rollouts.
- Returns:
fine_tune_dataset (list) – The dataset for fine-tuning the agent.
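As a minimal sketch of what build_fine_tune_dataset might return, the function below turns rollouts into chat-format examples of the form {"messages": [...]}, the layout used by OpenAI's chat fine-tuning API. The rollout fields (prompt, response, success) and the rule of keeping only successful rollouts are assumptions for illustration, and parallel lists replace NestedArrayDict.

```python
import json
from typing import Any


def build_fine_tune_dataset(rollouts: dict[str, list[Any]]) -> list[dict[str, Any]]:
    """Hypothetical sketch: rollouts are parallel lists instead of a NestedArrayDict."""
    dataset = []
    for prompt, response, success in zip(
        rollouts["prompt"], rollouts["response"], rollouts["success"]
    ):
        # Assumed filtering rule: keep only rollouts where the agent succeeded.
        if not success:
            continue
        dataset.append(
            {
                "messages": [
                    {"role": "user", "content": prompt},
                    {"role": "assistant", "content": response},
                ]
            }
        )
    return dataset


rollouts = {
    "prompt": ["Is 7 prime?", "Is 9 prime?"],
    "response": ["Yes: its only divisors are 1 and 7.", "Yes."],
    "success": [True, False],
}
examples = build_fine_tune_dataset(rollouts)
# One JSON object per line (JSONL) is a common upload format for fine-tuning data.
print("\n".join(json.dumps(example) for example in examples))
```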
- abstract async forward(data: NestedArrayDict, environment: PureTextEnvironment) → NestedArrayDict
Forward pass through the agent.
- Parameters:
data (NestedArrayDict) – The input to the agent.
environment (PureTextEnvironment) – The environment the agent is interacting with.
- Returns:
output (NestedArrayDict) – The output of the forward pass on the input.
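Because forward is abstract and asynchronous, each concrete agent supplies its own coroutine, typically one that awaits an API call. In this sketch, TextAgentBase, EchoClient and ProverSketch are hypothetical names, the client fakes the API call, and the prompt format (history lines joined by newlines) is an assumption.

```python
import asyncio
from abc import ABC, abstractmethod
from typing import Any


class TextAgentBase(ABC):
    """Hypothetical stand-in for the abstract interface of PureTextWholeAgent."""

    @abstractmethod
    async def forward(self, data: dict[str, Any], environment: Any) -> dict[str, Any]:
        ...


class EchoClient:
    """Hypothetical API client; a real one would send the prompt to a model."""

    async def complete(self, prompt: str) -> str:
        await asyncio.sleep(0)  # Stands in for network latency.
        return f"Reply to: {prompt}"


class ProverSketch(TextAgentBase):
    def __init__(self, client: EchoClient) -> None:
        self.client = client

    async def forward(self, data: dict[str, Any], environment: Any) -> dict[str, Any]:
        # Flatten the message history into a single prompt string (assumed format).
        prompt = "\n".join(data["message_history"])
        return {"message": await self.client.complete(prompt)}


agent = ProverSketch(EchoClient())
output = asyncio.run(agent.forward({"message_history": ["Claim: 7 is prime."]}, None))
print(output)  # {'message': 'Reply to: Claim: 7 is prime.'}
```

Declaring forward with abc.abstractmethod means the base class itself cannot be instantiated, which mirrors the abstract marker in the signature above.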
- get_state_dict() → dict
Get the state of the agent part as a dict.
This method should be implemented by subclasses capable of saving their state.
- Returns:
state_dict (dict) – The state of the agent part.
- set_state(checkpoint: AgentState)
Set the state of the agent from a checkpoint.
This method should be overridden by subclasses to restore the state of the agent from a checkpoint.
- Parameters:
checkpoint (AgentState) – The checkpoint to restore the state from.
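The two state methods are meant to work as a pair: whatever get_state_dict saves, set_state should be able to restore. The fields of AgentState are not documented here, so the sketch below uses a hypothetical SketchAgentState dataclass holding a model name and a fine-tune job ID, and round-trips the state through JSON. StatefulAgentSketch and the model and job names are likewise illustrative.

```python
import json
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SketchAgentState:
    """Hypothetical checkpoint type; the real AgentState fields may differ."""

    model_name: str
    fine_tune_job_id: Optional[str] = None


class StatefulAgentSketch:
    def __init__(self, model_name: str) -> None:
        self.model_name = model_name
        self.fine_tune_job_id: Optional[str] = None

    def get_state_dict(self) -> dict[str, Any]:
        # Only plain, JSON-serializable values, so the dict can be written to disk.
        return {"model_name": self.model_name, "fine_tune_job_id": self.fine_tune_job_id}

    def set_state(self, checkpoint: SketchAgentState) -> None:
        # Restore exactly the fields that get_state_dict saves.
        self.model_name = checkpoint.model_name
        self.fine_tune_job_id = checkpoint.fine_tune_job_id


original = StatefulAgentSketch("base-model")
original.model_name = "base-model-ft-1"  # e.g. after a fine-tuning run
original.fine_tune_job_id = "job-123"

saved = json.dumps(original.get_state_dict())
restored = StatefulAgentSketch("base-model")
restored.set_state(SketchAgentState(**json.loads(saved)))
print(restored.model_name, restored.fine_tune_job_id)  # base-model-ft-1 job-123
```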