nip.graph_isomorphism.agents.GraphIsomorphismSoloAgentHead#
- class nip.graph_isomorphism.agents.GraphIsomorphismSoloAgentHead(*args, **kwargs)[source]#
Solo agent head for the graph isomorphism task.
Solo agents try to solve the task on their own, without interacting with other agents (see the usage sketch after the parameter list below).
Shapes
Input:
“graph_level_repr” (… 2 d_representation): The output graph-level representations.
Output:
“decision_logits” (… 2): A logit for each of the two options: guess that the graphs are isomorphic, or guess that the graphs are not isomorphic.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
settings (ExperimentSettings) – The settings of the experiment.
agent_name (str) – The name of the agent.
protocol_handler (ProtocolHandler) – The protocol handler for the experiment.
device (TorchDevice, optional) – The device to use for this agent part. If not given, the CPU is used.
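A minimal usage sketch (not taken from the package itself): it assumes hyper_params, settings and protocol_handler have already been constructed elsewhere in the experiment setup, and the agent name, batch size and d_representation below are illustrative values.

import torch
from tensordict import TensorDict
from nip.graph_isomorphism.agents import GraphIsomorphismSoloAgentHead

# Assumed to exist already: hyper_params, settings, protocol_handler.
head = GraphIsomorphismSoloAgentHead(
    hyper_params,
    settings,
    agent_name="solo_verifier",  # hypothetical agent name
    protocol_handler=protocol_handler,
)

batch_size, d_representation = 4, 16  # illustrative; d_representation must match hyper_params
body_output = TensorDict(
    {"graph_level_repr": torch.randn(batch_size, 2, d_representation)},
    batch_size=[batch_size],
)

out = head(body_output)
print(out["decision_logits"].shape)  # expected: torch.Size([4, 2])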
Methods Summary
__init__(hyper_params, settings, agent_name, ...) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
_build_decider([d_out, include_round]) – Build the module which produces a graph-pair level output.
_build_graph_level_mlp(d_in, d_hidden, ...) – Build an MLP which acts on the graph-level representations.
_build_node_level_mlp(d_in, d_hidden, d_out, ...) – Build an MLP which acts on the node-level representations.
Initialise the module weights.
_run_masked_transformer(transformer, ...) – Run a transformer on graph and node representations, with masking.
_run_recorder_hook(hooks, hook_name, output)
forward(body_output[, hooks]) – Run the solo agent head on the given body output.
get_state_dict() – Get the state of the agent part as a dict.
set_state(checkpoint) – Set the state of the agent from a checkpoint.
to([device]) – Move the agent to the given device.
Attributes
T_destination
agent_id – The ID of the agent.
agent_level_in_keys
agent_level_out_keys
call_super_init
dump_patches
env_level_in_keys
env_level_out_keys
in_keys – The keys required by the module.
is_prover – Whether the agent is a prover.
is_verifier – Whether the agent is a verifier.
max_message_rounds – The maximum number of message rounds in the protocol.
num_visible_message_channels – The number of message channels visible to the agent.
out_keys – The keys produced by the module.
out_keys_source
required_pretrained_models – The pretrained models used by the agent.
visible_message_channel_indices – The indices of the message channels visible to the agent.
visible_message_channel_mask – The mask for the message channels visible to the agent.
visible_message_channel_names – The names of the message channels visible to the agent.
agent_params
training
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- _build_decider(d_out: int = 3, include_round: bool | None = None) TensorDictModule [source]#
Build the module which produces a graph-pair level output.
By default it is used to decide whether to continue exchanging messages. In this case it outputs a single triple of logits for the three options: guess that the graphs are not isomorphic, guess that the graphs are isomorphic, or continue exchanging messages.
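Illustrative only: reading out the three logits in the order given above (the tensor here is a stand-in, not the decider's real output).

import torch

decider_logits = torch.randn(4, 3)  # stand-in for the decider module's output
# index 0: guess not isomorphic, 1: guess isomorphic, 2: continue exchanging messages
choice = decider_logits.argmax(dim=-1)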
- _build_graph_level_mlp(d_in: int, d_hidden: int, d_out: int, num_layers: int, include_round: bool = False, out_key: str = 'graph_level_mlp_output', squeeze: bool = False) TensorDictSequential [source]#
Build an MLP which acts on the graph-level representations (see the shape sketch after the return value below).
Shapes
Input:
“graph_level_repr”: (… 2 d_in)
Output:
“graph_level_mlp_output”: (… d_out)
- Parameters:
d_in (int) – The dimensionality of the graph-level representations. This will be multiplied by two, as the MLP takes as input the concatenation of the two graph-level representations.
d_hidden (int) – The dimensionality of the hidden layers.
d_out (int) – The dimensionality of the output.
num_layers (int) – The number of hidden layers in the MLP.
include_round (bool, default=False) – Whether to include the round number as a (one-hot encoded) input to the MLP.
out_key (str, default="graph_level_mlp_output") – The tensordict key to use for the output of the MLP.
squeeze (bool, default=False) – Whether to squeeze the output dimension. Only use this if the output dimension is 1.
- Returns:
graph_level_mlp (TensorDictSequential) – The graph-level MLP.
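A plain PyTorch sketch of the shape contract described above (the two graph-level representations concatenated, plus an optional one-hot round number). This is an illustration under assumed dimensions, not the package's implementation.

import torch
from torch import nn

d_in, d_hidden, d_out, max_rounds = 16, 32, 2, 8  # illustrative values
include_round = True

# Input is the concatenation of the two graph-level representations (2 * d_in),
# optionally followed by a one-hot encoding of the round number.
mlp = nn.Sequential(
    nn.Linear(2 * d_in + (max_rounds if include_round else 0), d_hidden),
    nn.ReLU(),
    nn.Linear(d_hidden, d_out),
)

graph_level_repr = torch.randn(4, 2, d_in)  # (... 2 d_in)
round_one_hot = nn.functional.one_hot(torch.zeros(4, dtype=torch.long), max_rounds).float()
mlp_input = torch.cat([graph_level_repr.flatten(-2), round_one_hot], dim=-1)
print(mlp(mlp_input).shape)  # (... d_out), here torch.Size([4, 2])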
- _build_node_level_mlp(d_in: int, d_hidden: int, d_out: int, num_layers: int, out_key: str = 'node_level_mlp_output') TensorDictModule [source]#
Build an MLP which acts on the node-level representations.
Shapes
Input:
“node_level_repr”: (… 2 max_nodes d_in)
Output:
“node_level_mlp_output”: (… 2 max_nodes d_out)
- Parameters:
d_in (int) – The dimensionality of the input.
d_hidden (int) – The dimensionality of the hidden layers.
d_out (int) – The dimensionality of the output.
num_layers (int) – The number of hidden layers in the MLP.
out_key (str, default="node_level_mlp_output") – The tensordict key to use for the output of the MLP.
- Returns:
node_level_mlp (TensorDictModule) – The node-level MLP.
- classmethod _run_masked_transformer(transformer: TransformerEncoder, transformer_input: Float[Tensor, '... 2+2*node d_transformer'], node_mask: Float[Tensor, '... pair node']) Float[Tensor, '... 2+2*node d_transformer'] [source]#
Run a transformer on graph and node representations, with masking.
The input is expected to be the concatenation of the two graph-level representations and the node-level representations.
Attention is masked so that nodes only attend to nodes in the other graph and to the pooled representations. We also ensure that the transformer only attends to nodes that actually exist (and to the pooled representations); a sketch of such a mask follows the return value below.
- Parameters:
transformer (torch.nn.TransformerEncoder) – The transformer module.
transformer_input (Float[Tensor, "... 2+2*node d_transformer"]) – The input to the transformer. This is expected to be the concatenation of the two graph-level representations and the node-level representations.
node_mask (Float[Tensor, "... pair node"]) – Which nodes actually exist.
- Returns:
transformer_output_flatter (Float[Tensor, “… 2+2*node d_transformer”]) – The output of the transformer.
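A sketch of how such a mask could be built, assuming the token layout is the two pooled (graph-level) tokens followed by the node tokens of each graph. The function name and layout are assumptions; this illustrates the described behaviour rather than the package's actual code.

import torch

def build_attention_mask(node_mask: torch.Tensor) -> torch.Tensor:
    """Build a boolean mask where True means attention is *disallowed*.

    node_mask: (batch, 2, max_nodes), truthy where a node actually exists.
    Assumed token layout: [pooled_0, pooled_1, nodes of graph 0, nodes of graph 1].
    """
    batch, _, max_nodes = node_mask.shape
    seq_len = 2 + 2 * max_nodes
    allowed = torch.zeros(batch, seq_len, seq_len, dtype=torch.bool)

    # Every position may attend to the two pooled representations.
    allowed[:, :, :2] = True

    # Nodes of one graph may attend to nodes of the other graph.
    allowed[:, 2 : 2 + max_nodes, 2 + max_nodes :] = True
    allowed[:, 2 + max_nodes :, 2 : 2 + max_nodes] = True

    # (Assumption) the pooled tokens may attend to all node tokens.
    allowed[:, :2, 2:] = True

    # Key positions corresponding to real tokens: pooled tokens plus existing nodes.
    real = torch.cat(
        [torch.ones(batch, 2, dtype=torch.bool), node_mask.bool().flatten(1)], dim=1
    )
    allowed &= real.unsqueeze(1)  # never attend to padding positions

    # torch.nn.TransformerEncoder expects True where attention is blocked; a
    # per-batch mask like this would still need expanding to
    # (batch * num_heads, seq_len, seq_len) before being passed as `mask`.
    return ~allowed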
- forward(body_output: TensorDict, hooks: GraphIsomorphismAgentHooks | None = None) TensorDict [source]#
Run the solo agent head on the given body output.
- Parameters:
body_output (TensorDict) –
The output of the body module. A tensor dict with keys:
“graph_level_repr” (… 2 d_representation): The output graph-level representations.
hooks (GraphIsomorphismAgentHooks, optional) – Hooks to run at various points in the agent forward pass.
- Returns:
out (TensorDict) – A tensor dict with keys:
“decision_logits” (… 2): A logit for each of the two options: guess that the graphs are isomorphic, or guess that the graphs are not isomorphic.
- get_state_dict() dict [source]#
Get the state of the agent part as a dict.
This method should be implemented by subclasses capable of saving their state.
- Returns:
state_dict (dict) – The state of the agent part.
- set_state(checkpoint: AgentState)[source]#
Set the state of the agent from a checkpoint.
This method should be overridden by subclasses to restore the state of the agent from a checkpoint.
- Parameters:
checkpoint (AgentState) – The checkpoint to restore the state from.
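A hedged sketch of persisting the head's state with get_state_dict(); how a saved dict is turned back into an AgentState for set_state() is package-specific, so that step is only indicated in a comment.

import torch

state = head.get_state_dict()            # plain dict describing the head's state
torch.save(state, "solo_head_state.pt")  # persist like any other state dict

restored = torch.load("solo_head_state.pt")
# head.set_state(checkpoint)  # `checkpoint` must be an AgentState built from `restored`;
#                             # the exact construction depends on the AgentState class.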