nip.graph_isomorphism.agents.GraphIsomorphismRandomAgentPolicyHead
- class nip.graph_isomorphism.agents.GraphIsomorphismRandomAgentPolicyHead(*args, **kwargs)
Policy head for the graph isomorphism task yielding a uniform distribution.
Shapes
Input:
“graph_level_repr” (… 2 d_representation): The output graph-level representations.
“node_level_repr” (… 2 max_nodes d_representation): The output node-level representations.
Output:
“node_selected_logits” (… channel position 2*max_nodes): A logit for each node, indicating the probability that this node should be sent as a message to the verifier.
“decision_logits” (… 3): A logit for each of the three options: guess that the graphs are isomorphic, guess that the graphs are not isomorphic, or continue exchanging messages.
“linear_message_selected_logits” (… channel position d_linear_message_space) (optional): A logit for each linear message, indicating the probability that this linear message should be sent as a message to the verifier.
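The shape notation above can be read as a recipe for deriving output shapes from the leading batch dimensions. The following is a minimal sketch, not part of the library: the helper `policy_head_output_shapes` and its parameters `num_channels` and `num_positions` are hypothetical names, and it assumes that "channel" and "position" in the notation each denote a single dimension.

```python
def policy_head_output_shapes(
    batch_shape,
    num_channels,
    num_positions,
    max_nodes,
    d_linear_message_space=None,
):
    """Return the documented output shapes as plain tuples.

    Hypothetical helper for illustration only; it mirrors the shape
    notation in the docstring and does not call into the library.
    """
    shapes = {
        # One logit per node across both graphs, hence 2 * max_nodes.
        "node_selected_logits": (
            *batch_shape, num_channels, num_positions, 2 * max_nodes
        ),
        # One logit per decision option.
        "decision_logits": (*batch_shape, 3),
    }
    # The linear message logits are documented as optional.
    if d_linear_message_space is not None:
        shapes["linear_message_selected_logits"] = (
            *batch_shape, num_channels, num_positions, d_linear_message_space
        )
    return shapes


# Example values are made up for illustration.
example_shapes = policy_head_output_shapes(
    batch_shape=(32,), num_channels=1, num_positions=1, max_nodes=8
)
```

With these example values, `example_shapes["node_selected_logits"]` is `(32, 1, 1, 16)` and `example_shapes["decision_logits"]` is `(32, 3)`.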
Methods Summary
__init__(hyper_params, settings, agent_name, ...)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Initialise the module weights.
_run_recorder_hook(hooks, hook_name, output)
forward(body_output[, hooks])
Output a uniform distribution.
get_state_dict()
Get the state of the agent part as a dict.
set_state(checkpoint)
Set the state of the agent from a checkpoint.
to(device)
Move the agent to the given device.
Attributes
T_destination
agent_id
The ID of the agent.
agent_level_in_keys
agent_level_out_keys
The agent-level output keys.
call_super_init
dump_patches
env_level_in_keys
env_level_out_keys
has_decider
Whether the policy head has an output yielding a decision.
in_keys
The keys required by the module.
is_prover
Whether the agent is a prover.
is_verifier
Whether the agent is a verifier.
max_message_rounds
The maximum number of message rounds in the protocol.
num_visible_message_channels
The number of message channels visible to the agent.
out_keys
The keys produced by the module.
out_keys_source
required_pretrained_models
The pretrained models used by the agent.
visible_message_channel_indices
The indices of the message channels visible to the agent.
visible_message_channel_mask
The mask for the message channels visible to the agent.
visible_message_channel_names
The names of the message channels visible to the agent.
agent_params
training
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(body_output: TensorDict, hooks: GraphIsomorphismAgentHooks | None = None) → TensorDict
Output a uniform distribution.
- Parameters:
body_output (TensorDict) –
The output of the body module. A tensor dict with keys:
“graph_level_repr” (… 2 d_representation): The output graph-level representations.
“node_level_repr” (… 2 max_nodes d_representation): The output node-level representations.
hooks (GraphIsomorphismAgentHooks, optional) – Hooks to run at various points in the agent forward pass (ignored).
- Returns:
out (TensorDict) – A tensor dict with keys (all zeros):
“node_selected_logits” (… channel position 2*max_nodes): A logit for each node, indicating the probability that this node should be sent as a message to the verifier.
“decision_logits” (… 3): A logit for each of the three options: continue exchanging messages, guess that the graphs are isomorphic, or guess that the graphs are not isomorphic. Set to zeros when the decider is not present.
“linear_message_selected_logits” (… channel position d_linear_message_space) (optional): A logit for each linear message, indicating the probability that this linear message should be sent as a message to the verifier.
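To see why all-zero logits correspond to a uniform distribution, the following sketch applies a softmax to zero logits using only the Python standard library. It assumes the logits are converted to probabilities with a softmax downstream, which this page does not state explicitly. The `softmax` helper and the value of `max_nodes` are illustrative and not part of the library.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a flat list of logits."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


max_nodes = 4  # hypothetical value for illustration

# All-zero logits, as returned by the random policy head.
node_logits = [0.0] * (2 * max_nodes)
decision_logits = [0.0] * 3

node_probs = softmax(node_logits)
decision_probs = softmax(decision_logits)

# Every node is equally likely to be selected, and likewise every decision.
all_nodes_equal = all(
    math.isclose(p, 1 / (2 * max_nodes)) for p in node_probs
)
all_decisions_equal = all(math.isclose(p, 1 / 3) for p in decision_probs)
```

Because every logit is equal, the exponentials are equal and each option receives probability 1/n, where n is the size of the last dimension.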
- get_state_dict() → dict
Get the state of the agent part as a dict.
This method should be implemented by subclasses capable of saving their state.
- Returns:
state_dict (dict) – The state of the agent part.
- set_state(checkpoint: AgentState)
Set the state of the agent from a checkpoint.
This method should be overridden by subclasses to restore the state of the agent from a checkpoint.
- Parameters:
checkpoint (AgentState) – The checkpoint to restore the state from.