nip.graph_isomorphism.agents.GraphIsomorphismAgentBody#
- class nip.graph_isomorphism.agents.GraphIsomorphismAgentBody(*args, **kwargs)[source]#
Agent body for the graph isomorphism task.
Takes as input a pair of graphs, the message history and the most recent message, and outputs node-level and graph-level representations.
Shapes
Input:
“x” (… round channel position pair node): The graph node features (message history)
“adjacency” (… pair node node): The graph adjacency matrices
“message” (… channel position pair node), optional: The most recent message from the other agent
“node_mask” (… pair node): Which nodes actually exist
“ignore_message” (…), optional: Whether to ignore any values in “message”. For example, in the first round there is no message, and the “message” field is set to a dummy value.
“linear_message_history” (… round channel position linear_message), optional: The linear message history, if using one.
Output:
“graph_level_repr” (… 2 d_representation): The output graph-level representations.
“node_level_repr” (… 2 max_nodes d_representation): The output node-level representations.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
settings (ExperimentSettings) – The settings of the experiment.
agent_name (str) – The name of the agent.
protocol_handler (ProtocolHandler) – The protocol handler for the experiment.
device (TorchDevice, optional) – The device to use for this agent part. If not given, the CPU is used.
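As a rough illustration of the documented input shapes, the following sketch builds a dummy input with made-up dimension sizes (the actual TensorDict construction is handled by the experiment framework, so this is only a shape reference, not real usage):

```python
import torch

# Hypothetical sizes: batch of 4, 8 rounds, 2 channels, 1 position,
# a pair of 2 graphs, up to 6 nodes per graph.
batch, rounds, channels, positions, pair, max_nodes = 4, 8, 2, 1, 2, 6

data = {
    # "x" (... round channel position pair node): message history over node features
    "x": torch.zeros(batch, rounds, channels, positions, pair, max_nodes),
    # "adjacency" (... pair node node): one adjacency matrix per graph in the pair
    "adjacency": torch.zeros(batch, pair, max_nodes, max_nodes),
    # "message" (... channel position pair node): the most recent message
    "message": torch.zeros(batch, channels, positions, pair, max_nodes),
    # "node_mask" (... pair node): True where a node actually exists
    "node_mask": torch.ones(batch, pair, max_nodes, dtype=torch.bool),
    # "ignore_message" (...): True where "message" holds a dummy value
    "ignore_message": torch.ones(batch, dtype=torch.bool),
}

# The body maps these to:
#   "graph_level_repr" (... 2 d_representation)
#   "node_level_repr"  (... 2 max_nodes d_representation)
```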
Methods Summary
__init__(hyper_params, settings, agent_name, ...) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
_build_global_pooling() – Build a pooling layer which computes the graph-level representation.
_build_gnn() – Build the GNN module for an agent.
_build_gnn_transformer_encoder() – Build the encoder layer which translates the GNN output to transformer input.
_build_representation_encoder(d_input) – Build the encoder layer which translates to the representation space.
_build_transformer() – Build the transformer module for an agent.
Initialise the module weights.
_run_manual_architecture(data[, hooks]) – Run the body part of the manual architecture.
_run_masked_transformer(transformer, ...) – Run a transformer on graph and node representations, with masking.
_run_recorder_hook(hooks, hook_name, output)
forward(data[, hooks]) – Obtain graph-level and node-level representations by running components.
get_state_dict() – Get the state of the agent part as a dict.
set_state(checkpoint) – Set the state of the agent from a checkpoint.
to([device]) – Move the agent body to a new device.
Attributes
T_destination
agent_id – The ID of the agent.
agent_level_in_keys
agent_level_out_keys
call_super_init
d_gnn_out – The dimensionality of the GNN output after the stream and feature dims.
dump_patches
env_level_in_keys – The environment-level input keys for the agent body.
env_level_out_keys
in_keys – The keys required by the module.
is_prover – Whether the agent is a prover.
is_verifier – Whether the agent is a verifier.
max_message_rounds – The maximum number of message rounds in the protocol.
num_visible_message_channels – The number of message channels visible to the agent.
out_keys – The keys produced by the module.
out_keys_source
required_pretrained_models – The pretrained models used by the agent.
visible_message_channel_indices – The indices of the message channels visible to the agent.
visible_message_channel_mask – The mask for the message channels visible to the agent.
visible_message_channel_names – The names of the message channels visible to the agent.
agent_params
training
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- _build_global_pooling() Sequential [source]#
Build a pooling layer which computes the graph-level representation.
The module consists of a global sum pooling layer, an optional batch norm layer, a paired Gaussian noise layer and an optional pair invariant pooling layer.
Shapes
Input:
“gnn_repr” (… pair node feature*stream): The input graph node features
Output:
“pooled_gnn_output” (… pair feature*stream): The output graph-level representation
- Returns:
global_pooling (torch.nn.Sequential) – The global pooling module.
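The core of this module, masked global sum pooling over the node dimension, can be sketched in plain torch (a simplified illustration only; the actual module additionally applies the optional batch norm, paired Gaussian noise and pair-invariant pooling layers described above):

```python
import torch

def global_sum_pool(gnn_repr: torch.Tensor, node_mask: torch.Tensor) -> torch.Tensor:
    """Sum node features over the node dimension, ignoring padding nodes.

    gnn_repr: (... pair node feature), node_mask: (... pair node)
    returns:  (... pair feature)
    """
    masked = gnn_repr * node_mask.unsqueeze(-1)  # zero out non-existent nodes
    return masked.sum(dim=-2)                    # sum over the node dimension

h = torch.ones(3, 2, 5, 8)                       # batch=3, pair=2, node=5, feature=8
mask = torch.ones(3, 2, 5, dtype=torch.bool)
mask[..., -1] = False                            # last node is padding
pooled = global_sum_pool(h, mask)                # shape (3, 2, 8); each entry sums 4 nodes
```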
- _build_gnn() TensorDictSequential [source]#
Build the GNN module for an agent.
Shapes
Input:
“gnn_repr” (… stream pair node feature): The input graph node features
“adjacency” (… stream pair node node): The graph adjacency matrices
Output:
“gnn_repr” (… stream pair node feature): The output graph node features
- Returns:
gnn (TensorDictSequential) – The GNN module, which takes as input a TensorDict with keys “gnn_repr”, “adjacency” and “node_mask”.
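A single message-passing step on these shapes might look like the following (a minimal sketch in plain torch, not the library's actual GNN layer; the real module also carries the stream dimension, which broadcasts through the same operations):

```python
import torch

def gnn_step(gnn_repr, adjacency, node_mask, weight):
    """One masked graph-convolution step.

    gnn_repr:  (... pair node feature)
    adjacency: (... pair node node)
    node_mask: (... pair node)
    weight:    (feature, feature_out)
    """
    agg = adjacency @ gnn_repr               # aggregate neighbour features
    out = torch.relu(agg @ weight)           # linear transform + nonlinearity
    return out * node_mask.unsqueeze(-1)     # zero out padding nodes

h = torch.randn(2, 2, 4, 8)                  # batch=2, pair=2, node=4, feature=8
adj = torch.eye(4).expand(2, 2, 4, 4)        # identity adjacency for the sketch
mask = torch.ones(2, 2, 4, dtype=torch.bool)
w = torch.randn(8, 8)
out = gnn_step(h, adj, mask, w)              # shape (2, 2, 4, 8)
```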
- _build_gnn_transformer_encoder() Linear [source]#
Build the encoder layer which translates the GNN output to transformer input.
This is a simple linear layer, where the number of input features is normally d_gnn + 3. The extra features encode which graph-level representation the token is, if any, and whether a node is in the most recent message from the other agent. When using a linear message space, the number of input features is increased by the number of rounds times the number of message features.
- Returns:
gnn_transformer_encoder (torch.nn.Linear) – The encoder module
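The d_gnn + 3 input size can be illustrated by concatenating three flag features onto the GNN output before the linear layer (a hypothetical sketch with made-up sizes; the real feature layout and d_transformer value may differ):

```python
import torch

d_gnn, num_tokens = 16, 10   # hypothetical sizes
gnn_features = torch.randn(num_tokens, d_gnn)

# Three extra flag features per token: is it the first graph-level token,
# is it the second, and (for node tokens) is the node in the last message?
is_graph_repr_0 = torch.zeros(num_tokens, 1)
is_graph_repr_1 = torch.zeros(num_tokens, 1)
in_last_message = torch.zeros(num_tokens, 1)

encoder_input = torch.cat(
    [gnn_features, is_graph_repr_0, is_graph_repr_1, in_last_message], dim=-1
)
encoder = torch.nn.Linear(d_gnn + 3, 32)   # maps to d_transformer (here 32)
out = encoder(encoder_input)               # shape (10, 32)
```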
- _build_representation_encoder(d_input: int) Linear [source]#
Build the encoder layer which translates to the representation space.
This is a simple linear layer which ensures that the number of output features is hyper_params.d_representation.
- Parameters:
d_input (int) – The number of input features.
- Returns:
representation_encoder (torch.nn.Linear) – The encoder module
- _build_transformer() TransformerEncoder [source]#
Build the transformer module for an agent.
- Returns:
transformer (torch.nn.TransformerEncoder) – The transformer module.
- _run_manual_architecture(data: TensorDictBase, hooks: GraphIsomorphismAgentHooks | None = None) TensorDict [source]#
Run the body part of the manual architecture.
The verifier symmetrises the message history so that the information in round 2i is the same as the information in round 2i + 1, while the prover ignores the message history completely. The message history is then run through a linear layer to make it the right size, then through a GNN to get the node-level representations. The graph-level representations are then obtained by summing the node-level representations.
- Parameters:
data (TensorDictBase) –
The data to run the GNN and transformer on. A TensorDictBase with the following keys (or a GraphIsomorphism data object):
”x_rearranged” (… pair node round): The graph node features (message history), rearranged to put the round dimension at the end
”adjacency” (… pair node node): The graph adjacency matrices
”message” (…): The most recent message from the other agent
”node_mask” (… pair node): Which nodes actually exist
”ignore_message” (…): Whether to ignore any values in “message”. For example, in the first round there is no message, and the “message” field is set to a dummy value.
hooks (GraphIsomorphismAgentHooks, optional) – Hooks to run at various points in the agent forward pass.
- Returns:
out (TensorDict) – A tensor dict with keys:
”graph_level_repr” (… pair d_representation): The output graph-level representations.
”node_level_repr” (… pair max_nodes d_representation): The output node-level representations.
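The verifier's symmetrisation described above can be sketched as follows, assuming an even number of rounds and (as an assumption, since the docstring does not specify the combining operation) merging each round pair by summing:

```python
import torch

def symmetrise_history(x_rearranged: torch.Tensor) -> torch.Tensor:
    """Make round 2i carry the same information as round 2i + 1.

    x_rearranged: (... pair node round), with an even number of rounds.
    """
    *lead, num_rounds = x_rearranged.shape
    paired = x_rearranged.reshape(*lead, num_rounds // 2, 2)
    summed = paired.sum(dim=-1, keepdim=True)    # combine each pair of rounds
    return summed.expand_as(paired).reshape(*lead, num_rounds)

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])         # one node, 4 rounds
sym = symmetrise_history(x)                      # rounds 0/1 and 2/3 now agree
```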
- classmethod _run_masked_transformer(transformer: TransformerEncoder, transformer_input: Float[Tensor, '... 2+2*node d_transformer'], node_mask: Float[Tensor, '... pair node']) Float[Tensor, '... 2+2*node d_transformer'] [source]#
Run a transformer on graph and node representations, with masking.
The input is expected to be the concatenation of the two graph-level representations and the node-level representations.
Attention is masked so that nodes only attend to nodes in the other graph and to the pooled representations. We also make sure that the transformer only attends to the actual nodes (and the pooled representations).
- Parameters:
transformer (torch.nn.TransformerEncoder) – The transformer module.
transformer_input (Float[Tensor, "... 2+2*node d_transformer"]) – The input to the transformer. This is expected to be the concatenation of the two graph-level representations and the node-level representations.
node_mask (Float[Tensor, "... pair node"]) – Which nodes actually exist.
- Returns:
transformer_output_flatter (Float[Tensor, “… 2+2*node d_transformer”]) – The output of the transformer.
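The masking pattern described above (nodes attend only to nodes in the other graph and to the pooled representations, and only real nodes are ever attended to) might be constructed like this. This is a sketch over a single unbatched pair; the library's actual mask layout may differ:

```python
import torch

def build_attention_mask(node_mask: torch.Tensor) -> torch.Tensor:
    """Boolean mask over the token sequence [graph0, graph1, nodes0..., nodes1...].

    node_mask: (pair, node), True where a node exists.
    Returns (seq, seq), True where attention is allowed.
    """
    pair, node = node_mask.shape
    seq = 2 + pair * node
    allowed = torch.zeros(seq, seq, dtype=torch.bool)
    allowed[:, :2] = True                            # everyone sees the pooled tokens
    exists = node_mask.reshape(-1)                   # flatten (pair, node) to tokens
    graph_of = torch.arange(pair).repeat_interleave(node)
    cross = graph_of.unsqueeze(0) != graph_of.unsqueeze(1)  # other-graph pairs only
    allowed[2:, 2:] = cross & exists.unsqueeze(0)    # nodes see real other-graph nodes
    allowed[:2, 2:] = exists.unsqueeze(0)            # pooled tokens see real nodes
    return allowed

mask = build_attention_mask(torch.ones(2, 3, dtype=torch.bool))  # shape (8, 8)
```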
- forward(data: TensorDictBase, hooks: GraphIsomorphismAgentHooks | None = None) TensorDict [source]#
Obtain graph-level and node-level representations by running components.
Runs the GNN, pools the output, puts everything through a linear encoder, then runs the transformer on this.
- Parameters:
data (TensorDictBase) –
The data to run the GNN and transformer on. A TensorDictBase with the following keys (or a GraphIsomorphism data object):
”x” (… round channel position pair node): The graph node features (message history)
”adjacency” (… pair node node): The graph adjacency matrices
”message” (… channel position pair node), optional: The most recent message from the other agent
”node_mask” (… pair node), optional: Which nodes actually exist
”ignore_message” (…), optional: Whether to ignore any values in “message”. For example, in the first round there is no message, and the “message” field is set to a dummy value.
”linear_message_history” (… round channel position linear_message), optional: The linear message history, if using one.
hooks (GraphIsomorphismAgentHooks, optional) – Hooks to run at various points in the agent forward pass.
- Returns:
out (TensorDict) – A tensor dict with keys:
”graph_level_repr” (… pair d_representation): The output graph-level representations.
”node_level_repr” (… pair max_nodes d_representation): The output node-level representations.
- get_state_dict() dict [source]#
Get the state of the agent part as a dict.
This method should be implemented by subclasses capable of saving their state.
- Returns:
state_dict (dict) – The state of the agent part.
- set_state(checkpoint: AgentState)[source]#
Set the state of the agent from a checkpoint.
This method should be overridden by subclasses to restore the state of the agent from a checkpoint.
- Parameters:
checkpoint (AgentState) – The checkpoint to restore the state from.