nip.graph_isomorphism.agents.GraphIsomorphismAgentValueHead#

class nip.graph_isomorphism.agents.GraphIsomorphismAgentValueHead(*args, **kwargs)[source]#

Value head for the graph isomorphism task.

Takes as input the output of the agent body and outputs a value function.

Shapes

Input:

  • “graph_level_repr” (… 2 d_representation): The output graph-level representations.

  • “round” (optional) (…): The current round number.

Output:

  • “value” (…): The estimated value for each batch item.
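The shapes above can be illustrated with a minimal stand-in for the value head. The MLP below is a hypothetical sketch, not the module's actual architecture: it only shows how a `(..., 2, d_representation)` graph-level input is reduced to one scalar value per batch item.

```python
import torch
import torch.nn as nn

batch_size, d_representation = 4, 16

# Stand-in for the body output: two graph-level representations per item.
graph_level_repr = torch.randn(batch_size, 2, d_representation)

# Hypothetical value MLP: the two graph-level representations are
# concatenated before being mapped to a single scalar per batch item.
value_mlp = nn.Sequential(
    nn.Linear(2 * d_representation, 32),
    nn.ReLU(),
    nn.Linear(32, 1),
)

value = value_mlp(graph_level_repr.flatten(-2)).squeeze(-1)
print(value.shape)  # torch.Size([4])
```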

Parameters:
  • hyper_params (HyperParameters) – The parameters of the experiment.

  • settings (ExperimentSettings) – The settings of the experiment.

  • agent_name (str) – The name of the agent.

  • protocol_handler (ProtocolHandler) – The protocol handler for the experiment.

  • device (TorchDevice, optional) – The device to use for this agent part. If not given, the CPU is used.

Methods Summary

__init__(hyper_params, settings, agent_name, ...)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_build_decider([d_out, include_round])

Build the module which produces a graph-pair level output.

_build_graph_level_mlp(d_in, d_hidden, ...)

Build an MLP which acts on the graph-level representations.

_build_mlp()

Build the module which computes the value function.

_build_node_level_mlp(d_in, d_hidden, d_out, ...)

Build an MLP which acts on the node-level representations.

_init_weights()

Initialise the module weights.

_run_masked_transformer(transformer, ...)

Run a transformer on graph and node representations, with masking.

_run_recorder_hook(hooks, hook_name, output)

forward(body_output[, hooks])

Run the value head on the given body output.

get_state_dict()

Get the state of the agent part as a dict.

set_state(checkpoint)

Set the state of the agent from a checkpoint.

to([device])

Move the agent to the given device.

Attributes

T_destination

agent_id

The ID of the agent.

agent_level_in_keys

agent_level_out_keys

call_super_init

dump_patches

env_level_in_keys

The environment-level input keys.

env_level_out_keys

in_keys

The keys required by the module.

is_prover

Whether the agent is a prover.

is_verifier

Whether the agent is a verifier.

max_message_rounds

The maximum number of message rounds in the protocol.

num_visible_message_channels

The number of message channels visible to the agent.

out_keys

The keys produced by the module.

out_keys_source

required_pretrained_models

The pretrained models used by the agent.

visible_message_channel_indices

The indices of the message channels visible to the agent.

visible_message_channel_mask

The mask for the message channels visible to the agent.

visible_message_channel_names

The names of the message channels visible to the agent.

agent_params

training

Methods

__init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)[source]#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_build_decider(d_out: int = 3, include_round: bool | None = None) TensorDictModule[source]#

Build the module which produces a graph-pair level output.

By default it is used to decide whether to continue exchanging messages. In this case it outputs a single triple of logits for the three options: guess that the graphs are not isomorphic, guess that the graphs are isomorphic, or continue exchanging messages.

Parameters:
  • d_out (int, default=3) – The dimensionality of the output.

  • include_round (bool, optional) – Whether to include the round number as a (one-hot encoded) input to the MLP. If not given, the value from the agent parameters is used.

Returns:

decider (TensorDictModule) – The decider module.
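With the default `d_out=3`, the decider's output can be read as a triple of logits over the three options described above. The snippet below is an illustrative sketch (the logit values and the index ordering shown are assumptions based on the description, not confirmed by the source):

```python
import torch

# Hypothetical decider logits for a batch of 2 graph pairs, assuming the
# order: [guess not isomorphic, guess isomorphic, continue exchanging].
decider_logits = torch.tensor([[2.0, 0.1, -1.0],
                               [0.0, 0.3, 1.5]])

decisions = decider_logits.argmax(dim=-1)
labels = ["guess not isomorphic", "guess isomorphic", "continue"]
print([labels[i] for i in decisions])  # ['guess not isomorphic', 'continue']
```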

_build_graph_level_mlp(d_in: int, d_hidden: int, d_out: int, num_layers: int, include_round: bool = False, out_key: str = 'graph_level_mlp_output', squeeze: bool = False) TensorDictSequential[source]#

Build an MLP which acts on the graph-level representations.

Shapes

Input:

  • “graph_level_repr”: (… 2 d_in)

Output:

  • “graph_level_mlp_output”: (… d_out)

Parameters:
  • d_in (int) – The dimensionality of the graph-level representations. This will be multiplied by two, as the MLP takes as input the concatenation of the two graph-level representations.

  • d_hidden (int) – The dimensionality of the hidden layers.

  • d_out (int) – The dimensionality of the output.

  • num_layers (int) – The number of hidden layers in the MLP.

  • include_round (bool, default=False) – Whether to include the round number as a (one-hot encoded) input to the MLP.

  • out_key (str, default="graph_level_mlp_output") – The tensordict key to use for the output of the MLP.

  • squeeze (bool, default=False) – Whether to squeeze the output dimension. Only use this if the output dimension is 1.

Returns:

graph_level_mlp (TensorDictSequential) – The graph-level MLP.
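The input assembly this builder describes can be sketched as follows. This is a simplified stand-in, assuming a hypothetical `max_message_rounds` for the one-hot encoding; it only shows how the two graph-level representations are flattened and, with `include_round=True`, extended by a one-hot round encoding before entering the MLP.

```python
import torch
import torch.nn.functional as F

batch_size, d_in, max_message_rounds = 4, 8, 5  # illustrative sizes

graph_level_repr = torch.randn(batch_size, 2, d_in)
round_number = torch.tensor([0, 1, 3, 4])

# Concatenate the two graph-level representations of each pair ...
mlp_input = graph_level_repr.flatten(-2)  # (batch, 2 * d_in)

# ... and, when include_round=True, append a one-hot round encoding.
round_one_hot = F.one_hot(round_number, num_classes=max_message_rounds).float()
mlp_input = torch.cat([mlp_input, round_one_hot], dim=-1)

print(mlp_input.shape)  # torch.Size([4, 21]), i.e. 2 * d_in + max_message_rounds
```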

_build_mlp() TensorDictModule[source]#

Build the module which computes the value function.

Returns:

value_mlp (TensorDictModule) – The value module.

_build_node_level_mlp(d_in: int, d_hidden: int, d_out: int, num_layers: int, out_key: str = 'node_level_mlp_output') TensorDictModule[source]#

Build an MLP which acts on the node-level representations.

Shapes

Input:

  • “node_level_repr”: (… 2 max_nodes d_in)

Output:

  • “node_level_mlp_output”: (… 2 max_nodes d_out)

Parameters:
  • d_in (int) – The dimensionality of the input.

  • d_hidden (int) – The dimensionality of the hidden layers.

  • d_out (int) – The dimensionality of the output.

  • num_layers (int) – The number of hidden layers in the MLP.

  • out_key (str, default="node_level_mlp_output") – The tensordict key to use for the output of the MLP.

Returns:

node_level_mlp (TensorDictModule) – The node-level MLP.

_init_weights()[source]#

Initialise the module weights.

Should be called at the end of __init__.

classmethod _run_masked_transformer(transformer: TransformerEncoder, transformer_input: Float[Tensor, '... 2+2*node d_transformer'], node_mask: Float[Tensor, '... pair node']) Float[Tensor, '... 2+2*node d_transformer'][source]#

Run a transformer on graph and node representations, with masking.

The input is expected to be the concatenation of the two graph-level representations and the node-level representations.

Attention is masked so that nodes only attend to nodes in the other graph and to the pooled representations. We also make sure that the transformer only attends to the actual nodes (and the pooled representations).

Parameters:
  • transformer (torch.nn.TransformerEncoder) – The transformer module.

  • transformer_input (Float[Tensor, "... 2+2*node d_transformer"]) – The input to the transformer. This is expected to be the concatenation of the two graph-level representations and the node-level representations.

  • node_mask (Float[Tensor, "... pair node"]) – Which nodes actually exist.

Returns:

transformer_output_flatter (Float[Tensor, “… 2+2*node d_transformer”]) – The output of the transformer.
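The masking rule described above can be sketched as an explicit attention mask. This is an illustrative reconstruction under stated assumptions, not the module's actual implementation: it assumes the sequence layout [pooled graph 0, pooled graph 1, nodes of graph 0, nodes of graph 1], and uses the `torch.nn.TransformerEncoder` convention that `True` entries are *disallowed*.

```python
import torch

max_nodes = 3
seq_len = 2 + 2 * max_nodes  # two pooled representations, then the nodes

# Which node slots hold real nodes (graph 1 here has one padding slot).
node_mask = torch.tensor([[True, True, True],
                          [True, True, False]])  # (pair, node)

# allowed[i, j] is True when position i may attend to position j.
allowed = torch.zeros(seq_len, seq_len, dtype=torch.bool)
allowed[:, :2] = True  # everything may attend to the pooled representations
allowed[:2, :] = True  # pooled representations may attend everywhere
allowed[2:2 + max_nodes, 2 + max_nodes:] = True  # graph-0 nodes -> graph-1 nodes
allowed[2 + max_nodes:, 2:2 + max_nodes] = True  # graph-1 nodes -> graph-0 nodes

# Only actual nodes (and the pooled representations) may be attended to.
exists = torch.cat([torch.ones(2, dtype=torch.bool), node_mask.flatten()])
allowed &= exists  # broadcasts over the key (last) dimension

# torch.nn.TransformerEncoder expects True where attention is disallowed.
attn_mask = ~allowed
print(attn_mask.shape)  # torch.Size([8, 8])
```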

_run_recorder_hook(hooks: AgentHooks | None, hook_name: str, output: Tensor | None)[source]#
forward(body_output: TensorDict, hooks: GraphIsomorphismAgentHooks | None = None) TensorDict[source]#

Run the value head on the given body output.

Parameters:
  • body_output (TensorDict) –

    The output of the body module. A tensor dict with keys:

    • ”graph_level_repr” (… 2 d_representation): The output graph-level representations.

  • hooks (GraphIsomorphismAgentHooks, optional) – Hooks to run at various points in the agent forward pass.

Returns:

value_out (TensorDict) – A tensor dict with keys:

  • ”value” (…): The estimated value for each batch item.

get_state_dict() dict[source]#

Get the state of the agent part as a dict.

This method should be implemented by subclasses capable of saving their state.

Returns:

state_dict (dict) – The state of the agent part.

set_state(checkpoint: AgentState)[source]#

Set the state of the agent from a checkpoint.

This method should be overridden by subclasses to restore the state of the agent from a checkpoint.

Parameters:

checkpoint (AgentState) – The checkpoint to restore the state from.

to(device: device | str | int | None = None)[source]#

Move the agent to the given device.

Parameters:

device (TorchDevice, optional) – The device to use. If not given, the CPU is used.