nip.scenario_base.agents.AgentPart
- class nip.scenario_base.agents.AgentPart(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)
Base class for all agent parts: bodies and heads.
The in and out keys are split into agent-level and environment-level keys. Agent-level keys are nested under “agents” in the environment’s state dict, while environment-level keys are at the top level.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
settings (ExperimentSettings) – The settings of the experiment.
agent_name (str) – The name of the agent.
protocol_handler (ProtocolHandler) – The protocol handler for the experiment.
device (TorchDevice, optional) – The device to use for this agent part. If not given, the CPU is used.
- Class attributes:
agent_level_in_keys (Iterable[NestedKey]) – The keys required by the agent part whose values are per-agent (so in the environment’s state dict will be nested under “agents”).
env_level_in_keys (Iterable[NestedKey]) – The keys required by the agent part whose values are per-environment (so in the environment’s state dict will be at the top level).
agent_level_out_keys (Iterable[NestedKey]) – The keys produced by the agent part whose values are per-agent (so in the environment’s state dict will be nested under “agents”).
env_level_out_keys (Iterable[NestedKey]) – The keys produced by the agent part whose values are per-environment (so in the environment’s state dict will be at the top level).
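The split between agent-level and environment-level keys can be illustrated with plain nested dicts. The following is a minimal sketch, not the library's implementation: it assumes a nested key is either a string or a tuple of strings, and that agent-level keys are addressed by prefixing them with "agents". The helper names `to_full_keys` and `lookup` and the key names "message" and "round" are hypothetical.

```python
from typing import Iterable, Union

# Stand-in for the library's NestedKey type (assumption: str or tuple of str).
NestedKey = Union[str, tuple[str, ...]]


def to_full_keys(
    agent_level_keys: Iterable[NestedKey],
    env_level_keys: Iterable[NestedKey],
) -> list[NestedKey]:
    """Hypothetical helper: nest agent-level keys under "agents"."""
    full_keys: list[NestedKey] = []
    for key in agent_level_keys:
        parts = (key,) if isinstance(key, str) else tuple(key)
        full_keys.append(("agents", *parts))
    # Environment-level keys stay at the top level, unchanged.
    full_keys.extend(env_level_keys)
    return full_keys


def lookup(state: dict, key: NestedKey):
    """Follow a nested key through a plain nested dict."""
    parts = (key,) if isinstance(key, str) else key
    value = state
    for part in parts:
        value = value[part]
    return value


# A toy state dict: per-agent values under "agents", the rest at the top level.
state = {"agents": {"message": [0, 1]}, "round": 3}
full_in_keys = to_full_keys(agent_level_keys=["message"], env_level_keys=["round"])
```

Here `full_in_keys` addresses the per-agent value through the "agents" entry and the per-environment value directly.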
Methods Summary
__init__(hyper_params, settings, agent_name, ...)
forward(data)
Forward pass through the agent part.
get_state_dict()
Get the state of the agent part as a dict.
set_state(checkpoint)
Set the state of the agent from a checkpoint.
Attributes
agent_id
The ID of the agent.
agent_level_in_keys
agent_level_out_keys
env_level_in_keys
env_level_out_keys
in_keys
The keys required by the module.
is_prover
Whether the agent is a prover.
is_verifier
Whether the agent is a verifier.
max_message_rounds
The maximum number of message rounds in the protocol.
num_visible_message_channels
The number of message channels visible to the agent.
out_keys
The keys produced by the module.
required_pretrained_models
The pretrained models used by the agent.
visible_message_channel_indices
The indices of the message channels visible to the agent.
visible_message_channel_mask
The mask for the message channels visible to the agent.
visible_message_channel_names
The names of the message channels visible to the agent.
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)
- abstract forward(data: Any) → Any
Forward pass through the agent part.
- Parameters:
data (Any) – The input to the agent part.
- Returns:
output (Any) – The output of the forward pass on the input.
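Because forward is abstract, a concrete agent part has to override it before it can be instantiated. The sketch below shows that pattern with a stand-in base class; `AgentPartSketch` and `EchoBody` are hypothetical names, and the simplified constructor is an assumption that omits the hyper-parameters, settings and protocol handler the real class takes.

```python
from abc import ABC, abstractmethod
from typing import Any


class AgentPartSketch(ABC):
    """Stand-in for AgentPart (hypothetical, simplified constructor)."""

    def __init__(self, agent_name: str):
        self.agent_name = agent_name

    @abstractmethod
    def forward(self, data: Any) -> Any:
        """Forward pass through the agent part."""


class EchoBody(AgentPartSketch):
    """Hypothetical subclass that returns its input, tagged with the agent name."""

    def forward(self, data: Any) -> Any:
        return {"agent": self.agent_name, "output": data}


body = EchoBody("prover")
result = body.forward([1, 2, 3])
```

Instantiating `AgentPartSketch` directly raises a `TypeError`, which is the standard behaviour of Python abstract base classes.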
- get_state_dict() → dict
Get the state of the agent part as a dict.
This method should be implemented by subclasses capable of saving their state.
- Returns:
state_dict (dict) – The state of the agent part.
- set_state(checkpoint: AgentState)
Set the state of the agent from a checkpoint.
This method should be overridden by subclasses to restore the state of the agent from a checkpoint.
- Parameters:
checkpoint (AgentState) – The checkpoint to restore the state from.
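The save and restore pair described above can be sketched as a round trip. This is an illustration under stated assumptions, not the library's code: `CounterPartSketch` is a hypothetical class, and the checkpoint is a plain dict here, whereas the real set_state takes a checkpoint object whose structure is not described on this page.

```python
from typing import Any


class CounterPartSketch:
    """Hypothetical agent part holding one piece of state."""

    def __init__(self, agent_name: str):
        self.agent_name = agent_name
        self.steps_taken = 0

    def forward(self, data: Any) -> Any:
        # Count calls so there is some state worth saving.
        self.steps_taken += 1
        return data

    def get_state_dict(self) -> dict:
        """Return the state of this part as a dict."""
        return {"steps_taken": self.steps_taken}

    def set_state(self, checkpoint: dict) -> None:
        """Restore state (assumption: the checkpoint is a plain dict)."""
        self.steps_taken = checkpoint["steps_taken"]


original = CounterPartSketch("verifier")
original.forward("a")
original.forward("b")
saved = original.get_state_dict()

# A fresh instance picks up where the original left off.
restored = CounterPartSketch("verifier")
restored.set_state(saved)
```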