nip.scenario_base.agents.AgentPart

class nip.scenario_base.agents.AgentPart(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)

Base class for all agent parts: bodies and heads.

The input and output keys are split into agent-level and environment-level keys. Agent-level keys are nested under “agents” in the environment’s state dict, while environment-level keys sit at the top level.
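To illustrate the agent-level/environment-level split, the sketch below stands in for the environment's state container with plain nested Python dicts. The key names (“round”, “message”) and the helpers `resolve_key` and `lookup` are hypothetical and not part of nip; the only behaviour taken from the text above is that agent-level keys live under “agents” and environment-level keys live at the top level.

```python
from typing import Any, Tuple, Union

# A key is either a single string or a tuple of strings (a nested path).
NestedKey = Union[str, Tuple[str, ...]]


def resolve_key(key: NestedKey, agent_level: bool) -> Tuple[str, ...]:
    """Return the full path of a key; agent-level keys live under "agents"."""
    path = (key,) if isinstance(key, str) else tuple(key)
    return ("agents",) + path if agent_level else path


def lookup(state: dict, path: Tuple[str, ...]) -> Any:
    """Follow a path of keys through a nested dict."""
    value: Any = state
    for part in path:
        value = value[part]
    return value


# Hypothetical state: one environment-level entry and one agent-level entry.
state = {
    "round": 3,                             # environment-level: top level
    "agents": {"message": [[0.1], [0.2]]},  # agent-level: nested under "agents"
}

print(resolve_key("message", agent_level=True))                 # ('agents', 'message')
print(lookup(state, resolve_key("round", agent_level=False)))   # 3
print(lookup(state, resolve_key("message", agent_level=True)))  # [[0.1], [0.2]]
```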

Parameters:
  • hyper_params (HyperParameters) – The parameters of the experiment.

  • settings (ExperimentSettings) – The settings of the experiment.

  • agent_name (str) – The name of the agent.

  • protocol_handler (ProtocolHandler) – The protocol handler for the experiment.

  • device (TorchDevice, optional) – The device to use for this agent part. If not given, the CPU is used.

Class attributes:

  • agent_level_in_keys (Iterable[NestedKey]) – The keys required by the agent part whose values are per-agent (so in the environment’s state dict they are nested under “agents”).

  • env_level_in_keys (Iterable[NestedKey]) – The keys required by the agent part whose values are per-environment (so in the environment’s state dict they are at the top level).

  • agent_level_out_keys (Iterable[NestedKey]) – The keys produced by the agent part whose values are per-agent (so in the environment’s state dict they are nested under “agents”).

  • env_level_out_keys (Iterable[NestedKey]) – The keys produced by the agent part whose values are per-environment (so in the environment’s state dict they are at the top level).
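A subclass declares its keys through the four class attributes listed above. The sketch below shows one plausible way the combined in_keys and out_keys could be derived from them, by prefixing agent-level keys with “agents” and leaving environment-level keys as-is. `KeyedPartSketch`, `ToyHead` and the key names are hypothetical; this is an assumption about the bookkeeping, not nip's actual implementation.

```python
from typing import ClassVar, Tuple, Union

NestedKey = Union[str, Tuple[str, ...]]


def _as_path(key: NestedKey) -> Tuple[str, ...]:
    """Normalise a key to a tuple path."""
    return (key,) if isinstance(key, str) else tuple(key)


class KeyedPartSketch:
    """Hypothetical stand-in for AgentPart's key bookkeeping (not the nip code)."""

    agent_level_in_keys: ClassVar[Tuple[NestedKey, ...]] = ()
    env_level_in_keys: ClassVar[Tuple[NestedKey, ...]] = ()
    agent_level_out_keys: ClassVar[Tuple[NestedKey, ...]] = ()
    env_level_out_keys: ClassVar[Tuple[NestedKey, ...]] = ()

    @staticmethod
    def _combine(agent_keys, env_keys) -> Tuple[Tuple[str, ...], ...]:
        # Agent-level keys get the "agents" prefix; env-level keys stay top level.
        return tuple(("agents",) + _as_path(k) for k in agent_keys) + tuple(
            _as_path(k) for k in env_keys
        )

    @property
    def in_keys(self) -> Tuple[Tuple[str, ...], ...]:
        return self._combine(self.agent_level_in_keys, self.env_level_in_keys)

    @property
    def out_keys(self) -> Tuple[Tuple[str, ...], ...]:
        return self._combine(self.agent_level_out_keys, self.env_level_out_keys)


class ToyHead(KeyedPartSketch):
    """A hypothetical head reading a per-agent hidden state and the round number."""

    agent_level_in_keys = ("hidden",)
    env_level_in_keys = ("round",)
    agent_level_out_keys = ("logits",)


head = ToyHead()
print(head.in_keys)   # (('agents', 'hidden'), ('round',))
print(head.out_keys)  # (('agents', 'logits'),)
```

Declaring keys as class attributes keeps them inspectable without instantiating the part, which is convenient when wiring several parts together.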

Methods Summary

__init__(hyper_params, settings, agent_name, ...)

forward(data)

Forward pass through the agent part.

get_state_dict()

Get the state of the agent part as a dict.

set_state(checkpoint)

Set the state of the agent from a checkpoint.

Attributes

agent_id

The ID of the agent.

agent_level_in_keys

agent_level_out_keys

env_level_in_keys

env_level_out_keys

in_keys

The keys required by the module.

is_prover

Whether the agent is a prover.

is_verifier

Whether the agent is a verifier.

max_message_rounds

The maximum number of message rounds in the protocol.

num_visible_message_channels

The number of message channels visible to the agent.

out_keys

The keys produced by the module.

required_pretrained_models

The pretrained models used by the agent.

visible_message_channel_indices

The indices of the message channels visible to the agent.

visible_message_channel_mask

The mask for the message channels visible to the agent.

visible_message_channel_names

The names of the message channels visible to the agent.
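The visible_message_channel_mask, visible_message_channel_indices, visible_message_channel_names and num_visible_message_channels attributes appear to describe the same information in different forms. Assuming the mask holds one boolean per message channel in protocol order (the real attribute may well be a tensor rather than a list), the other three can be derived from it as sketched below. The function `visible_channel_views` and the channel names are hypothetical.

```python
from typing import List, Sequence, Tuple


def visible_channel_views(
    channel_names: Sequence[str], mask: Sequence[bool]
) -> Tuple[List[int], List[str], int]:
    """Derive visible indices, names and count from a per-channel boolean mask."""
    if len(channel_names) != len(mask):
        raise ValueError("mask must have one entry per message channel")
    indices = [i for i, visible in enumerate(mask) if visible]
    names = [channel_names[i] for i in indices]
    return indices, names, len(indices)


# Hypothetical protocol with three channels, of which the agent sees two.
channels = ["public", "prover_private", "verifier_private"]
mask = [True, False, True]

indices, names, count = visible_channel_views(channels, mask)
print(indices)  # [0, 2]
print(names)    # ['public', 'verifier_private']
print(count)    # 2
```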

Methods

__init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)

abstract forward(data: Any) → Any

Forward pass through the agent part.

Parameters:

data (Any) – The input to the agent part.

Returns:

output (Any) – The output of the forward pass on the input.
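Because forward is abstract, AgentPart itself cannot be instantiated and each concrete body or head must supply its own forward pass. The sketch below mirrors that contract with Python's `abc` module. `PartSketch` and `ScaleHead` are hypothetical, and the `__call__` that dispatches to forward imitates the PyTorch-module convention; whether AgentPart is called this way is an assumption.

```python
from abc import ABC, abstractmethod
from typing import Any


class PartSketch(ABC):
    """Hypothetical stand-in for AgentPart: forward() must be overridden."""

    @abstractmethod
    def forward(self, data: Any) -> Any:
        """Forward pass through the agent part."""

    def __call__(self, data: Any) -> Any:
        # Calling the part runs its forward pass (module-style convention).
        return self.forward(data)


class ScaleHead(PartSketch):
    """A toy concrete part that scales every value in a dict input."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def forward(self, data: Any) -> Any:
        return {key: value * self.factor for key, value in data.items()}


head = ScaleHead(2.0)
print(head({"x": 1.5}))  # {'x': 3.0}
```

Trying to instantiate `PartSketch` directly raises `TypeError`, which is how the abstract method enforces that subclasses implement forward.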

get_state_dict() → dict

Get the state of the agent part as a dict.

This method should be implemented by subclasses capable of saving their state.

Returns:

state_dict (dict) – The state of the agent part.
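A subclass that can save its state would override get_state_dict to return a snapshot of whatever it needs to restore later. In the sketch below, `CountingPart`, its `num_calls` counter and its `weights` list are hypothetical; the copy of `weights` ensures that later updates to the part do not silently change a snapshot already taken.

```python
from typing import Any, Dict


class CountingPart:
    """Hypothetical agent part that tracks how many forward passes it has run."""

    def __init__(self) -> None:
        self.num_calls = 0
        self.weights = [0.5, -0.25]

    def forward(self, data: float) -> float:
        self.num_calls += 1
        return sum(w * data for w in self.weights)

    def get_state_dict(self) -> Dict[str, Any]:
        # Copy mutable containers so later updates don't alter the snapshot.
        return {"num_calls": self.num_calls, "weights": list(self.weights)}


part = CountingPart()
part.forward(2.0)
snapshot = part.get_state_dict()
part.weights[0] = 99.0  # later mutation does not affect the snapshot
print(snapshot)  # {'num_calls': 1, 'weights': [0.5, -0.25]}
```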

set_state(checkpoint: AgentState)

Set the state of the agent from a checkpoint.

This method should be overridden by subclasses to restore the state of the agent from a checkpoint.

Parameters:

checkpoint (AgentState) – The checkpoint to restore the state from.
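A matching set_state override reads the saved fields back out of the checkpoint. The structure of nip's AgentState is not described here, so the sketch below uses a hypothetical dataclass `AgentStateSketch` that simply wraps a state dict; `RestorablePart` and its fields are likewise illustrative. The example round-trips state from one part into a fresh one.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class AgentStateSketch:
    """Hypothetical checkpoint container; the real AgentState may differ."""

    state_dict: Dict[str, Any] = field(default_factory=dict)


class RestorablePart:
    """Hypothetical agent part that can save and restore its state."""

    def __init__(self) -> None:
        self.num_calls = 0
        self.weights = [0.5, -0.25]

    def get_state_dict(self) -> Dict[str, Any]:
        return {"num_calls": self.num_calls, "weights": list(self.weights)}

    def set_state(self, checkpoint: AgentStateSketch) -> None:
        # Restore each saved field, copying lists so the checkpoint stays intact.
        self.num_calls = checkpoint.state_dict["num_calls"]
        self.weights = list(checkpoint.state_dict["weights"])


source = RestorablePart()
source.num_calls = 7
source.weights = [1.0, 2.0]
checkpoint = AgentStateSketch(state_dict=source.get_state_dict())

target = RestorablePart()
target.set_state(checkpoint)
print(target.num_calls, target.weights)  # 7 [1.0, 2.0]
```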