nip.scenario_base.agents.PureTextWholeAgent
- class nip.scenario_base.agents.PureTextWholeAgent(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)
Base class for whole agents which process text input and call APIs.
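Because forward and build_fine_tune_dataset are abstract, a concrete agent must override both before it can be instantiated. The following is a minimal, self-contained sketch of that pattern using Python's abc module. It does not import nip. TextWholeAgentSketch, EchoAgent and the "message_history" key are hypothetical names, and a plain dict stands in for NestedArrayDict.

```python
from abc import ABC, abstractmethod
from typing import Any

# Hypothetical stand-in: a plain dict replaces nip's NestedArrayDict, and
# `environment` is typed as object instead of PureTextEnvironment.
Data = dict[str, Any]


class TextWholeAgentSketch(ABC):
    """Mirrors the two abstract methods documented for PureTextWholeAgent."""

    def __init__(self, agent_name: str) -> None:
        self.agent_name = agent_name

    @abstractmethod
    def forward(self, data: Data, environment: object) -> Data:
        """Produce the agent's output for one step."""

    @abstractmethod
    def build_fine_tune_dataset(self, rollouts: Data) -> list:
        """Turn sampled rollouts into fine-tuning examples."""


class EchoAgent(TextWholeAgentSketch):
    """Toy concrete subclass: repeats the last message it was shown."""

    def forward(self, data: Data, environment: object) -> Data:
        # "message_history" is a hypothetical input key.
        return {"message": data["message_history"][-1]}

    def build_fine_tune_dataset(self, rollouts: Data) -> list:
        # A real agent would convert rollouts into API-ready examples.
        return []
```

Calling TextWholeAgentSketch("verifier") directly raises TypeError because its abstract methods are unimplemented. EchoAgent("prover") can be constructed because it overrides both.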
Methods Summary
- __call__(data, environment): Call self as a function.
- __init__(hyper_params, settings, agent_name, ...)
- build_fine_tune_dataset(rollouts): Build a dataset for fine-tuning the agent from sampled rollouts.
- forward(data, environment): Forward pass through the agent.
- get_state_dict(): Get the state of the agent part as a dict.
- set_state(checkpoint): Set the state of the agent from a checkpoint.
Attributes
- agent_id: The ID of the agent.
- agent_level_in_keys
- agent_level_out_keys
- env_level_in_keys
- env_level_out_keys
- in_keys: The keys required by the module.
- is_prover: Whether the agent is a prover.
- is_verifier: Whether the agent is a verifier.
- max_message_rounds: The maximum number of message rounds in the protocol.
- num_visible_message_channels: The number of message channels visible to the agent.
- out_keys: The keys produced by the module.
- required_pretrained_models: The pretrained models used by the agent.
- shared_model_group
- visible_message_channel_indices: The indices of the message channels visible to the agent.
- visible_message_channel_mask: The mask for the message channels visible to the agent.
- visible_message_channel_names: The names of the message channels visible to the agent.
- agent_params
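The visible_message_channel_* attributes and num_visible_message_channels presumably describe the same subset of the protocol's channels in different forms. The sketch below shows one consistent way they could relate. It assumes the mask is a boolean list over all protocol channels, in protocol order. That layout is an assumption, and the channel names are hypothetical.

```python
# Hypothetical protocol with three channels; this agent sees the first two.
protocol_channel_names = ["channel_a", "channel_b", "channel_c"]
visible_message_channel_mask = [True, True, False]

# Indices: positions where the mask is True.
visible_message_channel_indices = [
    i for i, visible in enumerate(visible_message_channel_mask) if visible
]

# Names: the protocol's channel names at those positions.
visible_message_channel_names = [
    protocol_channel_names[i] for i in visible_message_channel_indices
]

# Count: how many channels the agent can see.
num_visible_message_channels = len(visible_message_channel_indices)
```

Under this assumption, the mask is the primary representation and the other three attributes can be derived from it.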
Methods
- __call__(data: NestedArrayDict, environment: PureTextEnvironment) -> NestedArrayDict
Call self as a function.
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)
- abstract build_fine_tune_dataset(rollouts: NestedArrayDict) -> list
Build a dataset for fine-tuning the agent from sampled rollouts.
This method generates a dataset of examples ready to pass to the fine-tune API.
- Parameters:
rollouts (NestedArrayDict) – The sampled rollouts.
- Returns:
fine_tune_dataset (list) – The dataset for fine-tuning the agent.
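This method's exact output format is not specified here. As an illustration only, the sketch below builds examples shaped like OpenAI chat fine-tuning records ({"messages": [{"role": ..., "content": ...}, ...]}). Whether nip targets that format is an assumption. The rollout structure, a list of (speaker, text) pairs per rollout, is also hypothetical, as is the function name build_fine_tune_dataset_sketch.

```python
from typing import TypedDict


class ChatMessage(TypedDict):
    role: str
    content: str


def build_fine_tune_dataset_sketch(
    rollouts: list[list[tuple[str, str]]],
    agent_name: str,
    system_prompt: str,
) -> list[dict[str, list[ChatMessage]]]:
    """Turn each rollout into one chat-format fine-tuning example.

    Hypothetical input: each rollout is a list of (speaker, text) pairs.
    Messages sent by `agent_name` become "assistant" turns (the targets
    to learn); everything else becomes "user" context.
    """
    dataset = []
    for rollout in rollouts:
        messages: list[ChatMessage] = [{"role": "system", "content": system_prompt}]
        for speaker, text in rollout:
            role = "assistant" if speaker == agent_name else "user"
            messages.append({"role": role, "content": text})
        # Skip rollouts in which the agent never spoke: nothing to learn from.
        if any(m["role"] == "assistant" for m in messages):
            dataset.append({"messages": messages})
    return dataset
```

A real implementation might also filter rollouts, for example by outcome, before building examples. That choice is left out of this sketch.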
- abstract forward(data: NestedArrayDict, environment: PureTextEnvironment) -> NestedArrayDict
Forward pass through the agent.
- Parameters:
data (NestedArrayDict) – The input to the agent.
environment (PureTextEnvironment) – The environment the agent is interacting with.
- Returns:
output (NestedArrayDict) – The output of the forward pass on the input.
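For an agent that "processes text input and calls APIs", a forward pass plausibly turns each item's conversation into a prompt, sends it to a model API, and collects the replies. The sketch below shows that flow under stated assumptions. The API client is injected as a plain callable (CompletionFn, hypothetical), and the "message_history" and "message" keys are hypothetical. The environment argument is omitted for brevity. fake_api is a local stub, so no network call is made.

```python
from typing import Callable

# Hypothetical stand-in for an API client: takes a prompt, returns a reply.
CompletionFn = Callable[[str], str]


def forward_sketch(
    data: dict[str, list[list[str]]], complete: CompletionFn
) -> dict[str, list[str]]:
    """Generate one new message for each item in the batch."""
    replies = []
    for history in data["message_history"]:
        # Flatten the conversation so far into a single text prompt.
        prompt = "\n".join(history)
        replies.append(complete(prompt))
    return {"message": replies}


def fake_api(prompt: str) -> str:
    """Local stub that replies to the last line of the prompt."""
    return f"reply to: {prompt.splitlines()[-1]}"


output = forward_sketch({"message_history": [["a", "b"], ["c"]]}, fake_api)
```

Injecting the client as a callable keeps the sketch testable offline. A production agent would more likely hold a client object configured from its settings.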
- get_state_dict() -> dict
Get the state of the agent part as a dict.
This method should be implemented by subclasses capable of saving their state.
- Returns:
state_dict (dict) – The state of the agent part.
- set_state(checkpoint: AgentState)
Set the state of the agent from a checkpoint.
This method should be overridden by subclasses to restore the state of the agent from a checkpoint.
- Parameters:
checkpoint (AgentState) – The checkpoint to restore the state from.
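The pair get_state_dict and set_state is meant to support a save/restore round trip. As documented, one returns a plain dict while the other accepts a state object. The sketch below shows how a stateful subclass might bridge the two. AgentStateSketch, StatefulAgentSketch and the fields model_name and fine_tune_job_id are hypothetical, chosen only to suggest what an API-backed agent could need in order to resume.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class AgentStateSketch:
    """Hypothetical checkpoint contents for an API-backed agent."""

    model_name: str
    fine_tune_job_id: Optional[str] = None


class StatefulAgentSketch:
    def __init__(self, model_name: str) -> None:
        self.model_name = model_name
        self.fine_tune_job_id: Optional[str] = None

    def get_state_dict(self) -> dict:
        # Serialize the current state as a plain dict.
        return asdict(AgentStateSketch(self.model_name, self.fine_tune_job_id))

    def set_state(self, checkpoint: AgentStateSketch) -> None:
        # Restore every field from the checkpoint object.
        self.model_name = checkpoint.model_name
        self.fine_tune_job_id = checkpoint.fine_tune_job_id


# Round trip: save one agent's state, then restore it into a fresh agent.
agent = StatefulAgentSketch("base-model")
agent.fine_tune_job_id = "job-123"
saved = agent.get_state_dict()

restored = StatefulAgentSketch("base-model")
restored.set_state(AgentStateSketch(**saved))
```

Rebuilding the state object from the saved dict (AgentStateSketch(**saved)) is what reconciles the dict-returning getter with the object-accepting setter in this sketch.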