- class nip.scenario_base.agents.AgentPart(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)
Base class for all agent parts: bodies and heads.
The input and output keys are split into agent-level and environment-level keys. Agent-level keys are nested under “agents” in the environment’s state dict, while environment-level keys sit at the top level.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
settings (ExperimentSettings) – The settings of the experiment.
agent_name (str) – The name of the agent.
protocol_handler (ProtocolHandler) – The protocol handler for the experiment.
device (TorchDevice, optional) – The device to use for this agent part. If not given, the CPU is used.
- Class attributes:
agent_level_in_keys (Iterable[NestedKey]) – The keys required by the agent part whose values are per-agent (so in the environment’s state dict will be nested under “agents”).
env_level_in_keys (Iterable[NestedKey]) – The keys required by the agent part whose values are per-environment (so in the environment’s state dict will be at the top level).
agent_level_out_keys (Iterable[NestedKey]) – The keys produced by the agent part whose values are per-agent (so in the environment’s state dict will be nested under “agents”).
env_level_out_keys (Iterable[NestedKey]) – The keys produced by the agent part whose values are per-environment (so in the environment’s state dict will be at the top level).
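The agent-level/environment-level split can be illustrated with plain dicts. The sketch below is a hypothetical illustration, not the library's code: it assumes a nested key is either a string or a tuple of strings (the TensorDict `NestedKey` convention), and the helpers `as_tuple`, `full_keys` and `lookup` are invented names. Agent-level keys are prefixed with "agents", while environment-level keys stay at the top level.

```python
from typing import Iterable, Union

# Assumption: a nested key is a plain string or a tuple of strings,
# mirroring the TensorDict ``NestedKey`` convention.
NestedKey = Union[str, tuple[str, ...]]


def as_tuple(key: NestedKey) -> tuple[str, ...]:
    """Normalise a nested key to tuple form."""
    return (key,) if isinstance(key, str) else tuple(key)


def full_keys(
    agent_level_keys: Iterable[NestedKey],
    env_level_keys: Iterable[NestedKey],
) -> list[tuple[str, ...]]:
    """Hypothetical helper: nest agent-level keys under "agents",
    keep environment-level keys at the top level."""
    agent_keys = [("agents", *as_tuple(k)) for k in agent_level_keys]
    env_keys = [as_tuple(k) for k in env_level_keys]
    return agent_keys + env_keys


def lookup(state: dict, key: tuple[str, ...]):
    """Follow a tuple key through nested dicts."""
    value = state
    for part in key:
        value = value[part]
    return value


# Toy state dict: per-agent values under "agents", per-environment values at the top.
state = {
    "agents": {"hidden": [0.1, 0.2]},
    "message": [1, 0, 1],
}
keys = full_keys(agent_level_keys=["hidden"], env_level_keys=["message"])
print(keys)  # [('agents', 'hidden'), ('message',)]
print([lookup(state, k) for k in keys])  # [[0.1, 0.2], [1, 0, 1]]
```

In the real environment the values under "agents" would typically carry an extra agent dimension. The toy dict leaves that out so the key layout stays easy to see.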
Methods Summary
__init__(hyper_params, settings, agent_name, ...)
forward(data) – Forward pass through the agent part.
get_state_dict() – Get the state of the agent part as a dict.
set_state(checkpoint) – Set the state of the agent from a checkpoint.
Attributes Summary
The ID of the agent.
The keys required by the module.
Whether the agent is a prover.
Whether the agent is a verifier.
The maximum number of message rounds in the protocol.
The number of message channels visible to the agent.
The keys produced by the module.
The pretrained models used by the agent.
The indices of the message channels visible to the agent.
The mask for the message channels visible to the agent.
The names of the message channels visible to the agent.
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)
- abstract forward(data: Any) → Any
Forward pass through the agent part.
- Parameters:
data (Any) – The input to the agent part.
- Returns:
output (Any) – The output of the forward pass on the input.
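Because `forward` is abstract, every concrete body or head must provide its own implementation. The sketch below uses a hypothetical stand-in base class, `AgentPartSketch`, built with Python's `abc` module rather than the real `AgentPart`, which needs the experiment objects to construct. It shows that the base cannot be instantiated and that a toy subclass, `ScaleHead` (also invented here), becomes usable once it defines `forward`.

```python
from abc import ABC, abstractmethod
from typing import Any


class AgentPartSketch(ABC):
    """Hypothetical stand-in for AgentPart, reduced to the abstract forward pass."""

    def __init__(self, agent_name: str):
        self.agent_name = agent_name

    @abstractmethod
    def forward(self, data: Any) -> Any:
        """Forward pass through the agent part."""


class ScaleHead(AgentPartSketch):
    """Toy concrete part: multiplies every input value by a fixed factor."""

    def __init__(self, agent_name: str, factor: float):
        super().__init__(agent_name)
        self.factor = factor

    def forward(self, data: dict) -> dict:
        # Apply the same scaling to each entry of the input mapping.
        return {key: value * self.factor for key, value in data.items()}


head = ScaleHead("verifier", factor=2.0)
print(head.forward({"logits": 1.5}))  # {'logits': 3.0}
```

Calling `AgentPartSketch("verifier")` directly raises `TypeError`, since the abstract `forward` has no implementation.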
- get_state_dict() → dict
Get the state of the agent part as a dict.
This method should be implemented by subclasses capable of saving their state.
- Returns:
state_dict (dict) – The state of the agent part.
- set_state(checkpoint: AgentState)
Set the state of the agent from a checkpoint.
This method should be overridden by subclasses to restore the state of the agent from a checkpoint.
- Parameters:
checkpoint (AgentCheckpoint) – The checkpoint to restore the state from.
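A subclass that can save its state implements `get_state_dict` and `set_state` as a matching pair, so that a restored part behaves like the one that was saved. The sketch below is a hypothetical toy, `CounterPart`, and does not derive from `AgentPart`. It also assumes the checkpoint is a plain dict, whereas the real method takes an `AgentState`/`AgentCheckpoint` whose fields are not described here.

```python
import copy
from typing import Any


class CounterPart:
    """Hypothetical agent part whose state is a call counter and a weight list.

    Assumption: the checkpoint is a plain dict, not nip's AgentState.
    """

    def __init__(self):
        self.calls = 0
        self.weights = [0.0, 0.0]

    def forward(self, data: Any) -> Any:
        # Count calls so there is some mutable state worth saving.
        self.calls += 1
        return data

    def get_state_dict(self) -> dict:
        # Deep-copy so later updates to the part do not mutate the saved state.
        return {"calls": self.calls, "weights": copy.deepcopy(self.weights)}

    def set_state(self, checkpoint: dict) -> None:
        # Restore every saved field; copy again so the checkpoint stays reusable.
        self.calls = checkpoint["calls"]
        self.weights = copy.deepcopy(checkpoint["weights"])


part = CounterPart()
part.weights = [0.5, -1.0]
part.forward("x")
saved = part.get_state_dict()

fresh = CounterPart()
fresh.set_state(saved)
print(fresh.calls, fresh.weights)  # 1 [0.5, -1.0]
```

Copying on both save and restore keeps the checkpoint independent of either object, so one checkpoint can seed several fresh parts.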