nip.graph_isomorphism.agents.GraphIsomorphismAgentPolicyHead

class nip.graph_isomorphism.agents.GraphIsomorphismAgentPolicyHead(*args, **kwargs)

Agent policy head for the graph isomorphism task.

Takes as input the output of the agent body and outputs a policy distribution over the actions. Both agents select a node to send as a message, and the verifier also decides whether to guess that the graphs are isomorphic, guess that they are not isomorphic, or continue exchanging messages.

Shapes

Input:

  • “graph_level_repr” (… 2 d_representation): The output graph-level representations.

  • “node_level_repr” (… 2 max_nodes d_representation): The output node-level representations.

  • “round” (optional) (…): The current round number.

Output:

  • “node_selected_logits” (… channel position 2*max_nodes): A logit for each node, indicating the probability that this node should be sent as a message.

  • “decision_logits” (optional) (… 3): A logit for each of the three options: guess that the graphs are isomorphic, guess that the graphs are not isomorphic, or continue exchanging messages. Set to zeros when the decider is not present.

  • “linear_message_selected_logits” (optional) (… channel position linear_message): A logit for each linear message, indicating the probability that this linear message should be sent as a message.

Parameters:
  • hyper_params (HyperParameters) – The parameters of the experiment.

  • settings (ExperimentSettings) – The settings of the experiment.

  • agent_name (str) – The name of the agent.

  • protocol_handler (ProtocolHandler) – The protocol handler for the experiment.

  • device (TorchDevice, optional) – The device to use for this agent part. If not given, the CPU is used.
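To make the node-selection shapes above concrete, here is a minimal NumPy sketch (not the library's PyTorch implementation) of scoring each node and flattening the (… 2 max_nodes d_representation) node-level representations into a (… 2*max_nodes) vector of logits. It assumes a single batch dimension, omits the channel and position dimensions, uses a hypothetical linear scoring vector `w_node`, and assumes that the nodes of the first graph come first in the flattened order.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def flat_node_logits(node_level_repr: np.ndarray, w_node: np.ndarray) -> np.ndarray:
    """Score every node, then merge the pair and node axes.

    node_level_repr: (batch, 2, max_nodes, d_representation)
    w_node: (d_representation,), hypothetical linear scoring weights
    returns: (batch, 2 * max_nodes)
    """
    scores = node_level_repr @ w_node  # (batch, 2, max_nodes)
    return scores.reshape(scores.shape[0], -1)  # graph 0's nodes, then graph 1's


rng = np.random.default_rng(0)
batch, max_nodes, d_representation = 4, 5, 8
node_level_repr = rng.normal(size=(batch, 2, max_nodes, d_representation))
w_node = rng.normal(size=d_representation)

logits = flat_node_logits(node_level_repr, w_node)  # (4, 10)
probs = softmax(logits)  # one distribution over all 2*max_nodes nodes per item

# Recover (graph index, node index within that graph) from a flat index.
flat_index = int(probs[0].argmax())
graph_idx, node_idx = divmod(flat_index, max_nodes)
```

Under this assumed ordering, `divmod` recovers which graph the selected node belongs to and its index within that graph.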

Methods Summary

__init__(hyper_params, settings, agent_name, ...)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_build_decider([d_out, include_round])

Build the module which produces a graph-pair level output.

_build_graph_level_mlp(d_in, d_hidden, ...)

Build an MLP which acts on the graph-level representations.

_build_linear_message_selector()

Build the module which selects which linear message to send.

_build_node_level_mlp(d_in, d_hidden, d_out, ...)

Build an MLP which acts on the node-level representations.

_build_node_selector()

Build the module which selects which node to send as a message.

_init_weights()

Initialise the module weights.

_run_manual_architecture(body_output[, hooks])

Run the manually specified algorithm for the agent and environment.

_run_masked_transformer(transformer, ...)

Run a transformer on graph and node representations, with masking.

_run_recorder_hook(hooks, hook_name, output)

forward(body_output[, hooks])

Run the policy head on the given body output.

get_state_dict()

Get the state of the agent part as a dict.

set_state(checkpoint)

Set the state of the agent from a checkpoint.

to([device])

Move the agent head to the given device.

Attributes

T_destination

agent_id

The ID of the agent.

agent_level_in_keys

agent_level_out_keys

The agent-level output keys.

call_super_init

dump_patches

env_level_in_keys

The environment-level input keys.

env_level_out_keys

has_decider

Whether the policy head has an output yielding a decision.

in_keys

The keys required by the module.

is_prover

Whether the agent is a prover.

is_verifier

Whether the agent is a verifier.

max_message_rounds

The maximum number of message rounds in the protocol.

num_visible_message_channels

The number of message channels visible to the agent.

out_keys

The keys produced by the module.

out_keys_source

required_pretrained_models

The pretrained models used by the agent.

visible_message_channel_indices

The indices of the message channels visible to the agent.

visible_message_channel_mask

The mask for the message channels visible to the agent.

visible_message_channel_names

The names of the message channels visible to the agent.

agent_params

training

Methods

__init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_build_decider(d_out: int = 3, include_round: bool | None = None) → TensorDictModule

Build the module which produces a graph-pair level output.

By default it is used to decide whether to continue exchanging messages. In this case it outputs a single triple of logits for the three options: guess that the graphs are not isomorphic, guess that the graphs are isomorphic, or continue exchanging messages.

Parameters:
  • d_out (int, default=3) – The dimensionality of the output.

  • include_round (bool, optional) – Whether to include the round number as a (one-hot encoded) input to the MLP. If not given, the value from the agent parameters is used.

Returns:

decider (TensorDictModule) – The decider module.
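The decider's inputs can be sketched in NumPy as follows, assuming the two graph-level representations are concatenated and, when `include_round` is set, a zero-based round number is appended as a one-hot vector of length `max_message_rounds`. A single linear layer with hypothetical weights `w` and `b` stands in for the MLP, and the sketch deliberately does not assign meanings to the three output indices.

```python
import numpy as np


def one_hot_round(round_idx: np.ndarray, max_message_rounds: int) -> np.ndarray:
    """One-hot encode integer round numbers: (batch,) -> (batch, max_message_rounds)."""
    return np.eye(max_message_rounds)[round_idx]


def decider_logits(
    graph_level_repr: np.ndarray,
    round_idx: np.ndarray,
    w: np.ndarray,
    b: np.ndarray,
    include_round: bool,
    max_message_rounds: int,
) -> np.ndarray:
    """(batch, 2, d) graph-level representations -> (batch, d_out) logits."""
    batch = graph_level_repr.shape[0]
    features = graph_level_repr.reshape(batch, -1)  # concatenate the pair: (batch, 2*d)
    if include_round:
        features = np.concatenate(
            [features, one_hot_round(round_idx, max_message_rounds)], axis=-1
        )
    return features @ w + b  # hypothetical single linear layer in place of the MLP


rng = np.random.default_rng(0)
batch, d, max_message_rounds, d_out = 3, 6, 4, 3
graph_level_repr = rng.normal(size=(batch, 2, d))
round_idx = np.array([0, 2, 3])
w = rng.normal(size=(2 * d + max_message_rounds, d_out))
b = np.zeros(d_out)

logits = decider_logits(graph_level_repr, round_idx, w, b, True, max_message_rounds)
```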

_build_graph_level_mlp(d_in: int, d_hidden: int, d_out: int, num_layers: int, include_round: bool = False, out_key: str = 'graph_level_mlp_output', squeeze: bool = False) → TensorDictSequential

Build an MLP which acts on the graph-level representations.

Shapes

Input:

  • “graph_level_repr”: (… 2 d_in)

Output:

  • “graph_level_mlp_output”: (… d_out)

Parameters:
  • d_in (int) – The dimensionality of the graph-level representations. This will be multiplied by two, as the MLP takes as input the concatenation of the two graph-level representations.

  • d_hidden (int) – The dimensionality of the hidden layers.

  • d_out (int) – The dimensionality of the output.

  • num_layers (int) – The number of hidden layers in the MLP.

  • include_round (bool, default=False) – Whether to include the round number as a (one-hot encoded) input to the MLP.

  • out_key (str, default="graph_level_mlp_output") – The tensordict key to use for the output of the MLP.

  • squeeze (bool, default=False) – Whether to squeeze the output dimension. Only use this if the output dimension is 1.

Returns:

graph_level_mlp (TensorDictSequential) – The graph-level MLP.
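A NumPy sketch of the shape flow through the graph-level MLP: the (… 2 d_in) input is flattened to 2*d_in features, passed through `num_layers` hidden layers of width `d_hidden`, and projected to `d_out`, with `squeeze` dropping a size-1 output axis. The ReLU activation and random initialisation are assumptions, and the round-number input is omitted for brevity.

```python
import numpy as np


def init_graph_level_mlp(d_in: int, d_hidden: int, d_out: int, num_layers: int, seed: int = 0):
    """Random (weight, bias) pairs; the first layer sees 2*d_in features."""
    rng = np.random.default_rng(seed)
    dims = [2 * d_in] + [d_hidden] * num_layers + [d_out]
    return [
        (rng.normal(scale=0.1, size=(n_in, n_out)), np.zeros(n_out))
        for n_in, n_out in zip(dims[:-1], dims[1:])
    ]


def graph_level_mlp(graph_level_repr: np.ndarray, layers, squeeze: bool = False) -> np.ndarray:
    """(..., 2, d_in) -> (..., d_out), or (...) when squeeze=True."""
    x = graph_level_repr.reshape(*graph_level_repr.shape[:-2], -1)  # (..., 2*d_in)
    for i, (w, b) in enumerate(layers):
        x = x @ w + b
        if i < len(layers) - 1:
            x = np.maximum(x, 0.0)  # assumed ReLU on hidden layers only
    if squeeze:
        if x.shape[-1] != 1:
            raise ValueError("squeeze requires d_out == 1")
        x = x[..., 0]
    return x


layers = init_graph_level_mlp(d_in=4, d_hidden=16, d_out=1, num_layers=2)
reprs = np.ones((5, 2, 4))
out = graph_level_mlp(reprs, layers)  # (5, 1)
out_squeezed = graph_level_mlp(reprs, layers, squeeze=True)  # (5,)
```

With `num_layers` hidden layers, the sketch uses `num_layers + 1` linear maps in total, which is why the example has three weight matrices.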

_build_linear_message_selector() → TensorDictModule

Build the module which selects which linear message to send.

Returns:

linear_message_selector (TensorDictModule) – The linear message selector module.

_build_node_level_mlp(d_in: int, d_hidden: int, d_out: int, num_layers: int, out_key: str = 'node_level_mlp_output') → TensorDictModule

Build an MLP which acts on the node-level representations.

Shapes

Input:

  • “node_level_repr”: (… 2 max_nodes d_in)

Output:

  • “node_level_mlp_output”: (… 2 max_nodes d_out)

Parameters:
  • d_in (int) – The dimensionality of the input.

  • d_hidden (int) – The dimensionality of the hidden layers.

  • d_out (int) – The dimensionality of the output.

  • num_layers (int) – The number of hidden layers in the MLP.

  • out_key (str, default="node_level_mlp_output") – The tensordict key to use for the output of the MLP.

Returns:

node_level_mlp (TensorDictModule) – The node-level MLP.
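Given the shapes above, a natural reading is that one MLP is applied to every node independently, keeping the (… 2 max_nodes) leading axes and transforming only the last one. This NumPy sketch shows that behaviour: matrix multiplication against 2-D weights acts on the final axis only. The ReLU activation and the random weights are illustrative assumptions, not the library's choices.

```python
import numpy as np


def node_level_mlp(node_level_repr: np.ndarray, layers) -> np.ndarray:
    """Apply the same MLP to every node independently.

    node_level_repr: (..., 2, max_nodes, d_in) -> (..., 2, max_nodes, d_out)
    """
    x = node_level_repr
    for i, (w, b) in enumerate(layers):
        x = x @ w + b  # acts on the last axis; leading axes are preserved
        if i < len(layers) - 1:
            x = np.maximum(x, 0.0)  # assumed ReLU on hidden layers only
    return x


rng = np.random.default_rng(1)
d_in, d_hidden, d_out, num_layers = 8, 16, 4, 2
dims = [d_in] + [d_hidden] * num_layers + [d_out]
layers = [
    (rng.normal(size=(n_in, n_out)), np.zeros(n_out))
    for n_in, n_out in zip(dims[:-1], dims[1:])
]

nodes = rng.normal(size=(3, 2, 5, d_in))  # (batch, 2, max_nodes, d_in)
out = node_level_mlp(nodes, layers)  # (3, 2, 5, 4)
```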

_build_node_selector() → TensorDictModule

Build the module which selects which node to send as a message.

Returns:

node_selector (TensorDictModule) – The node selector module.

_init_weights()

Initialise the module weights.

Should be called at the end of __init__.

_run_manual_architecture(body_output: TensorDict, hooks: GraphIsomorphismAgentHooks | None = None) → TensorDict

Run the manually specified algorithm for the agent and environment.

The verifier waits until the last round, selecting nodes at random. In the last round it guesses that the graphs are isomorphic if the graph-level representations are close enough, and that they are not isomorphic otherwise. When the round number is not provided, it guesses with probability 0.5.

Without shared reward, the prover selects a node based on the similarity of its representation to the representation of the node selected by the verifier in the previous round.

With shared reward, the prover does the same when its graph-level representations are close (in which case it believes the graphs are isomorphic). When its graph-level representations are far apart (in which case it believes the graphs are not isomorphic), it instead selects the node whose representation is most dissimilar to the representation of the node selected by the verifier in the previous round.

Parameters:
  • body_output (TensorDict) –

    The output of the body module. A tensor dict with keys:

    • “graph_level_repr” (… 2 d_representation): The output graph-level representations.

    • “node_level_repr” (… 2 max_nodes d_representation): The output node-level representations.

    • “message” (… channel position 2 max_nodes): The most recent message in the channel.

    • “round” (optional) (…): The current round number.

  • hooks (GraphIsomorphismAgentHooks, optional) – Hooks to run at various points in the agent forward pass.

Returns:

out (TensorDict) – A tensor dict with keys:

  • “node_selected_logits” (… channel position 2*max_nodes): A logit for each node, indicating the probability that this node should be sent as a message.

  • “decision_logits” (… 3): A logit for each of the three options: guess that the graphs are isomorphic, guess that the graphs are not isomorphic, or continue exchanging messages. Set to zeros when the decider is not present.
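The manual decision rules described above can be sketched in NumPy for a single graph pair. The Euclidean-distance threshold for "close enough", the dot-product similarity, the use of argmax/argmin, and the assumption that the prover answers with a node from the graph the verifier did not pick are all illustrative choices, not details confirmed by this documentation; `threshold`, `verifier_step`, and `prover_select_node` are hypothetical names.

```python
import numpy as np


def graphs_look_isomorphic(graph_level_repr: np.ndarray, threshold: float) -> bool:
    """graph_level_repr: (2, d). 'Close' is taken to mean Euclidean distance < threshold."""
    return bool(np.linalg.norm(graph_level_repr[0] - graph_level_repr[1]) < threshold)


def verifier_step(graph_level_repr, max_nodes, is_last_round, threshold, rng):
    """Before the last round, pick a random (graph, node); in the last round, guess."""
    if not is_last_round:
        return ("message", int(rng.integers(2)), int(rng.integers(max_nodes)))
    if graphs_look_isomorphic(graph_level_repr, threshold):
        return ("guess", "isomorphic")
    return ("guess", "not isomorphic")


def prover_select_node(node_level_repr, graph_level_repr, verifier_graph, verifier_node,
                       shared_reward, threshold):
    """node_level_repr: (2, max_nodes, d). Returns a node index in the other graph."""
    target = node_level_repr[verifier_graph, verifier_node]
    candidates = node_level_repr[1 - verifier_graph]  # (max_nodes, d)
    similarity = candidates @ target  # dot-product similarity, shape (max_nodes,)
    if shared_reward and not graphs_look_isomorphic(graph_level_repr, threshold):
        return int(similarity.argmin())  # believes non-isomorphic: most dissimilar node
    return int(similarity.argmax())  # otherwise: most similar node
```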

classmethod _run_masked_transformer(transformer: TransformerEncoder, transformer_input: Float[Tensor, '... 2+2*node d_transformer'], node_mask: Float[Tensor, '... pair node']) → Float[Tensor, '... 2+2*node d_transformer']

Run a transformer on graph and node representations, with masking.

The input is expected to be the concatenation of the two graph-level representations and the node-level representations.

Attention is masked so that each node attends only to nodes in the other graph and to the pooled (graph-level) representations. Positions corresponding to non-existent nodes are also masked out, so that the transformer attends only to actual nodes and the pooled representations.

Parameters:
  • transformer (torch.nn.TransformerEncoder) – The transformer module.

  • transformer_input (Float[Tensor, "... 2+2*node d_transformer"]) – The input to the transformer. This is expected to be the concatenation of the two graph-level representations and the node-level representations.

  • node_mask (Float[Tensor, "... pair node"]) – Which nodes actually exist.

Returns:

transformer_output_flatter (Float[Tensor, "... 2+2*node d_transformer"]) – The output of the transformer.
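The masking rule can be sketched as a boolean "may attend" matrix over the 2+2*node sequence for a single graph pair, assuming the layout [pooled graph 0, pooled graph 1, nodes of graph 0, nodes of graph 1] and a boolean `node_mask` (a float mask could be converted with `node_mask > 0`). The documentation does not say what the pooled positions attend to; this sketch assumes they may attend to every existing position. `build_attention_mask` is a hypothetical helper.

```python
import numpy as np


def build_attention_mask(node_mask: np.ndarray) -> np.ndarray:
    """Boolean 'may attend' mask for one graph pair.

    Sequence layout (assumed): [pooled_0, pooled_1, graph-0 nodes..., graph-1 nodes...].
    node_mask: (2, max_nodes) bool, True where the node exists.
    Returns (L, L) bool with L = 2 + 2 * max_nodes; entry [q, k] is True
    when query position q may attend to key position k.
    """
    max_nodes = node_mask.shape[1]
    # Graph membership of each position; -1 marks the two pooled positions.
    graph_of = np.concatenate([[-1, -1], np.repeat([0, 1], max_nodes)])
    # Pooled positions always exist; node positions exist according to node_mask.
    exists = np.concatenate([[True, True], node_mask.reshape(-1)])

    q_graph = graph_of[:, None]
    k_graph = graph_of[None, :]
    # Node queries: attend to pooled keys and to nodes of the other graph only.
    node_rule = (k_graph == -1) | (k_graph != q_graph)
    # Assumption: pooled queries may attend to every position.
    allowed = np.where(q_graph == -1, True, node_rule)
    # No query may attend to a non-existent (padding) node.
    return allowed & exists[None, :]


node_mask = np.array([[True, True, False], [True, False, False]])
allowed = build_attention_mask(node_mask)  # (8, 8)
```

For a boolean `attn_mask`, PyTorch's `nn.MultiheadAttention` treats `True` as "not allowed to attend", so a mask in this form would be passed inverted (`~allowed` after conversion to a tensor). Because the pooled positions always exist and every query may attend to them, no row is fully masked, which avoids NaNs in the attention softmax.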

_run_recorder_hook(hooks: AgentHooks | None, hook_name: str, output: Tensor | None)

forward(body_output: TensorDict, hooks: GraphIsomorphismAgentHooks | None = None) → TensorDict

Run the policy head on the given body output.

Runs the node selector module and, if present, the decider module.

Parameters:
  • body_output (TensorDict) –

    The output of the body module. A tensor dict with keys:

    • “graph_level_repr” (… 2 d_representation): The output graph-level representations.

    • “node_level_repr” (… 2 max_nodes d_representation): The output node-level representations.

    • “message” (…): The most recent message from the other agent.

    • “round” (optional) (…): The current round number.

  • hooks (GraphIsomorphismAgentHooks, optional) – Hooks to run at various points in the agent forward pass.

Returns:

out (TensorDict) – A tensor dict with keys:

  • “node_selected_logits” (… channel position 2*max_nodes): A logit for each node, indicating the probability that this node should be sent as a message.

  • “decision_logits” (… 3): A logit for each of the three options: guess that the graphs are isomorphic, guess that the graphs are not isomorphic, or continue exchanging messages. Set to zeros when the decider is not present.

  • “linear_message_selected_logits” (optional) (… channel position linear_message): A logit for each linear message, indicating the probability that this linear message should be sent as a message.
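The control flow of `forward` can be sketched with a plain dict standing in for a `TensorDict` and hypothetical callables standing in for the node selector and decider modules. The only behaviour it illustrates is the documented fallback: when there is no decider, `decision_logits` is filled with zeros. `policy_head_forward` and `flat_node_scores` are hypothetical names.

```python
from typing import Callable, Optional

import numpy as np


def policy_head_forward(
    body_output: dict,
    node_selector: Callable[[dict], np.ndarray],
    decider: Optional[Callable[[dict], np.ndarray]] = None,
) -> dict:
    """Hypothetical sketch of the forward pass over a dict of arrays."""
    out = {"node_selected_logits": node_selector(body_output)}
    batch_shape = body_output["graph_level_repr"].shape[:-2]  # drop the (2, d) axes
    if decider is not None:
        out["decision_logits"] = decider(body_output)
    else:
        out["decision_logits"] = np.zeros((*batch_shape, 3))  # documented zero fallback
    return out


def flat_node_scores(body: dict) -> np.ndarray:
    """Hypothetical selector: sum each node's features, flatten to (batch, 2*max_nodes)."""
    scores = body["node_level_repr"].sum(axis=-1)
    return scores.reshape(scores.shape[0], -1)


body_output = {
    "graph_level_repr": np.ones((4, 2, 8)),
    "node_level_repr": np.ones((4, 2, 5, 8)),
}
out = policy_head_forward(body_output, flat_node_scores)
```

Under a softmax, all-zero logits give a uniform distribution over the three options.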

get_state_dict() → dict

Get the state of the agent part as a dict.

This method should be implemented by subclasses capable of saving their state.

Returns:

state_dict (dict) – The state of the agent part.

set_state(checkpoint: AgentState)

Set the state of the agent from a checkpoint.

This method should be overridden by subclasses to restore the state of the agent from a checkpoint.

Parameters:

checkpoint (AgentState) – The checkpoint to restore the state from.

to(device: device | str | int | None = None)

Move the agent head to the given device.

Parameters:

device (TorchDevice, optional) – The device to move the agent head to. If not given, the CPU is used.