nip.graph_isomorphism.agents.GraphIsomorphismAgentPart

class nip.graph_isomorphism.agents.GraphIsomorphismAgentPart(*args, **kwargs)

Base class for all graph isomorphism agent parts.

Parameters:
  • hyper_params (HyperParameters) – The parameters of the experiment.

  • settings (ExperimentSettings) – The settings of the experiment.

  • agent_name (str) – The name of the agent.

  • protocol_handler (ProtocolHandler) – The protocol handler for the experiment.

  • device (TorchDevice, optional) – The device to use for this agent part. If not given, the CPU is used.

Methods Summary

__init__(hyper_params, settings, agent_name, ...)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_init_weights()

Initialise the module weights.

_run_masked_transformer(transformer, ...)

Run a transformer on graph and node representations, with masking.

_run_recorder_hook(hooks, hook_name, output)

forward(data)

Forward pass through the agent part.

get_state_dict()

Get the state of the agent part as a dict.

set_state(checkpoint)

Set the state of the agent from a checkpoint.

to(device)

Move the agent to the given device.

Attributes

T_destination

agent_id

The ID of the agent.

agent_level_in_keys

agent_level_out_keys

call_super_init

dump_patches

env_level_in_keys

env_level_out_keys

in_keys

The keys required by the module.

is_prover

Whether the agent is a prover.

is_verifier

Whether the agent is a verifier.

max_message_rounds

The maximum number of message rounds in the protocol.

num_visible_message_channels

The number of message channels visible to the agent.

out_keys

The keys produced by the module.

out_keys_source

required_pretrained_models

The pretrained models used by the agent.

visible_message_channel_indices

The indices of the message channels visible to the agent.

visible_message_channel_mask

The mask for the message channels visible to the agent.

visible_message_channel_names

The names of the message channels visible to the agent.

agent_params

training
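The visible message channel attributes above (indices, mask, names and count) describe the same information in different forms. The sketch below is a framework-free illustration, not the library's implementation: the `ChannelView` class, and the assumption that a protocol is described by an ordered list of channel names plus the set of names one agent can see, are both hypothetical.

```python
from dataclasses import dataclass


# Hypothetical stand-in: the protocol is described only by the ordered names
# of all its message channels, plus the set of channels one agent can see.
@dataclass
class ChannelView:
    all_channel_names: list[str]
    visible_names: set[str]

    @property
    def visible_message_channel_indices(self) -> list[int]:
        # Positions (in protocol order) of the channels this agent can see
        return [
            i
            for i, name in enumerate(self.all_channel_names)
            if name in self.visible_names
        ]

    @property
    def visible_message_channel_mask(self) -> list[bool]:
        # One boolean per protocol channel: True where the agent can see it
        return [name in self.visible_names for name in self.all_channel_names]

    @property
    def visible_message_channel_names(self) -> list[str]:
        # Names of the visible channels, in protocol order
        return [self.all_channel_names[i] for i in self.visible_message_channel_indices]

    @property
    def num_visible_message_channels(self) -> int:
        return len(self.visible_message_channel_indices)
```

For example, with hypothetical channels `["main", "private_a", "private_b"]` and visible set `{"main", "private_b"}`, the indices are `[0, 2]` and the mask is `[True, False, True]`. Deriving everything from a single source keeps the four attributes consistent with one another.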

Methods

__init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_init_weights()

Initialise the module weights.

This should be called at the end of __init__.
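The call belongs at the end because initialisation can only reach sub-layers that already exist. The following is a plain-Python analogy, not the library's code: the `Layer` and `SketchAgentPart` classes are hypothetical, and uniform initialisation in [-0.1, 0.1] is an arbitrary choice. Here `_init_weights` visits every layer attribute present at call time, so calling it earlier in `__init__` would leave later layers uninitialised.

```python
import random


class Layer:
    """Hypothetical layer holding a flat list of weights."""

    def __init__(self, size: int) -> None:
        self.weights = [0.0] * size
        self.initialised = False


class SketchAgentPart:
    """Hypothetical agent part illustrating the ordering constraint only."""

    def __init__(self, seed: int = 0) -> None:
        self._rng = random.Random(seed)
        # Build every sub-layer first...
        self.encoder = Layer(4)
        self.head = Layer(2)
        # ...and only then initialise, so that no layer is missed
        self._init_weights()

    def _layers(self) -> list[Layer]:
        # Collect the layers that exist on the instance right now
        return [value for value in vars(self).values() if isinstance(value, Layer)]

    def _init_weights(self) -> None:
        for layer in self._layers():
            layer.weights = [self._rng.uniform(-0.1, 0.1) for _ in layer.weights]
            layer.initialised = True
```

In PyTorch code this pattern is often written with `torch.nn.Module.apply`, which visits every registered submodule; whether this class does so is not stated here.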

classmethod _run_masked_transformer(transformer: TransformerEncoder, transformer_input: Float[Tensor, '... 2+2*node d_transformer'], node_mask: Float[Tensor, '... pair node']) → Float[Tensor, '... 2+2*node d_transformer']

Run a transformer on graph and node representations, with masking.

The input is expected to be the concatenation of the two graph-level representations and the node-level representations.

Attention is masked so that each node attends only to the nodes of the other graph and to the pooled representations. Positions corresponding to non-existent (padding) nodes, as given by node_mask, are also masked out, so the transformer attends only to actual nodes and the pooled representations.

Parameters:
  • transformer (torch.nn.TransformerEncoder) – The transformer module.

  • transformer_input (Float[Tensor, "... 2+2*node d_transformer"]) – The input to the transformer. This is expected to be the concatenation of the two graph-level representations and the node-level representations.

  • node_mask (Float[Tensor, "... pair node"]) – A mask indicating which nodes actually exist in each graph of the pair.

Returns:

transformer_output_flatter (Float[Tensor, "... 2+2*node d_transformer"]) – The output of the transformer.
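The masking rule described above can be made concrete for a single pair of graphs. This is a minimal, framework-free sketch under stated assumptions, not the method's actual code: it assumes the sequence is ordered as [pooled graph 0, pooled graph 1, nodes of graph 0, nodes of graph 1], treats node_mask as booleans, and assumes (the documentation does not say) that the pooled positions may attend to every existing position. The function name `build_attention_mask` is hypothetical.

```python
from typing import Optional


def build_attention_mask(node_mask: list[list[bool]]) -> list[list[bool]]:
    """Return allowed[q][k]: True if query position q may attend to key k.

    Assumed layout of the length 2 + 2*n sequence (n = nodes per graph):
      0, 1           pooled representations of graphs 0 and 1
      2 .. 1+n       nodes of graph 0
      2+n .. 1+2n    nodes of graph 1
    """
    n = len(node_mask[0])
    seq_len = 2 + 2 * n

    def graph_of(pos: int) -> Optional[int]:
        # None marks a pooled (graph-level) position
        return None if pos < 2 else (pos - 2) // n

    def exists(pos: int) -> bool:
        # Pooled representations always exist; nodes are looked up in node_mask
        if pos < 2:
            return True
        return node_mask[graph_of(pos)][(pos - 2) % n]

    allowed = []
    for q in range(seq_len):
        row = []
        for k in range(seq_len):
            if not exists(k):
                row.append(False)  # never attend to padding nodes
            elif graph_of(q) is None or graph_of(k) is None:
                row.append(True)  # pooled positions: unrestricted (assumption)
            else:
                row.append(graph_of(q) != graph_of(k))  # nodes: other graph only
        allowed.append(row)
    return allowed
```

To use such a matrix with `torch.nn.TransformerEncoder`, it would need to be inverted, since a boolean attention mask there uses True to mark a position that may not be attended to. Keeping the pooled positions always attendable also guarantees that no query row is fully masked, which would otherwise produce NaNs in the softmax.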

_run_recorder_hook(hooks: AgentHooks | None, hook_name: str, output: Tensor | None)

abstract forward(data: Any) → Any

Forward pass through the agent part.

Parameters:

data (Any) – The input to the agent part.

Returns:

output (Any) – The output of the forward pass on the input.
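Because forward is abstract, every concrete agent part must supply its own. The sketch below illustrates the pattern with Python's `abc` module rather than the real base class: `SketchAgentPartBase` and `DegreeCounter` are hypothetical, and the `__call__` method only mirrors the `torch.nn.Module` convention of invoking the module rather than calling forward directly.

```python
from abc import ABC, abstractmethod
from typing import Any


class SketchAgentPartBase(ABC):
    """Hypothetical stand-in for the abstract interface documented here."""

    @abstractmethod
    def forward(self, data: Any) -> Any:
        """Compute the output of this agent part for the given input."""

    def __call__(self, data: Any) -> Any:
        # Mirrors the nn.Module convention: call the module, not forward
        return self.forward(data)


class DegreeCounter(SketchAgentPartBase):
    """Hypothetical subclass: counts node degrees from an edge list."""

    def forward(self, data: list[tuple[int, int]]) -> dict[int, int]:
        degrees: dict[int, int] = {}
        for u, v in data:
            degrees[u] = degrees.get(u, 0) + 1
            degrees[v] = degrees.get(v, 0) + 1
        return degrees
```

Trying to instantiate `SketchAgentPartBase` directly raises `TypeError`, while `DegreeCounter()([(0, 1), (1, 2)])` returns `{0: 1, 1: 2, 2: 1}`.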

get_state_dict() → dict

Get the state of the agent part as a dict.

This method should be implemented by subclasses capable of saving their state.

Returns:

state_dict (dict) – The state of the agent part.

set_state(checkpoint: AgentState)

Set the state of the agent from a checkpoint.

This method should be overridden by subclasses to restore the state of the agent from a checkpoint.

Parameters:

checkpoint (AgentState) – The checkpoint to restore the state from.
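get_state_dict and set_state are meant to be overridden together by subclasses that can save their state. The following round-trip sketch shows one plausible way to pair them; `SketchCheckpoint` and `StatefulPart` are hypothetical, and the real checkpoint type may be structured quite differently.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class SketchCheckpoint:
    """Hypothetical checkpoint type; the real checkpoint class may differ."""

    state_dict: dict[str, Any] = field(default_factory=dict)


class StatefulPart:
    """Hypothetical agent part with one piece of learnable state."""

    def __init__(self) -> None:
        self.weights = [0.0, 0.0]

    def get_state_dict(self) -> dict:
        # Return a copy so later updates do not mutate the saved state
        return {"weights": list(self.weights)}

    def set_state(self, checkpoint: SketchCheckpoint) -> None:
        # Copy again so the restored part does not share lists with the checkpoint
        self.weights = list(checkpoint.state_dict["weights"])
```

Copying on both save and restore keeps a checkpoint unaffected by later training of either the original or the restored part.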

abstract to(device: device | str | int)

Move the agent to the given device.