nip.graph_isomorphism.agents.GraphIsomorphismAgentPart
- class nip.graph_isomorphism.agents.GraphIsomorphismAgentPart(*args, **kwargs)
Base class for all graph isomorphism agent parts.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
settings (ExperimentSettings) – The settings of the experiment.
agent_name (str) – The name of the agent.
protocol_handler (ProtocolHandler) – The protocol handler for the experiment.
device (TorchDevice, optional) – The device to use for this agent part. If not given, the CPU is used.
Methods Summary
__init__(hyper_params, settings, agent_name, ...) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
Initialise the module weights.
_run_masked_transformer(transformer, ...) – Run a transformer on graph and node representations, with masking.
_run_recorder_hook(hooks, hook_name, output)
forward(data) – Forward pass through the agent part.
get_state_dict() – Get the state of the agent part as a dict.
set_state(checkpoint) – Set the state of the agent from a checkpoint.
to(device) – Move the agent to the given device.
Attributes
T_destination
agent_id – The ID of the agent.
agent_level_in_keys
agent_level_out_keys
call_super_init
dump_patches
env_level_in_keys
env_level_out_keys
in_keys – The keys required by the module.
is_prover – Whether the agent is a prover.
is_verifier – Whether the agent is a verifier.
max_message_rounds – The maximum number of message rounds in the protocol.
num_visible_message_channels – The number of message channels visible to the agent.
out_keys – The keys produced by the module.
out_keys_source
required_pretrained_models – The pretrained models used by the agent.
visible_message_channel_indices – The indices of the message channels visible to the agent.
visible_message_channel_mask – The mask for the message channels visible to the agent.
visible_message_channel_names – The names of the message channels visible to the agent.
agent_params
training
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- classmethod _run_masked_transformer(transformer: TransformerEncoder, transformer_input: Float[Tensor, '... 2+2*node d_transformer'], node_mask: Float[Tensor, '... pair node']) → Float[Tensor, '... 2+2*node d_transformer']
Run a transformer on graph and node representations, with masking.
The input is expected to be the concatenation of the two graph-level representations and the node-level representations.
Attention is masked so that nodes only attend to nodes in the other graph and to the pooled representations. Attention to non-existent (padding) nodes, as indicated by node_mask, is also masked out, so the transformer only attends to actual nodes and the pooled representations.
- Parameters:
transformer (torch.nn.TransformerEncoder) – The transformer module.
transformer_input (Float[Tensor, "... 2+2*node d_transformer"]) – The input to the transformer. This is expected to be the concatenation of the two graph-level representations and the node-level representations.
node_mask (Float[Tensor, "... pair node"]) – A mask indicating which nodes actually exist in each of the two graphs.
- Returns:
transformer_output_flatter (Float[Tensor, "... 2+2*node d_transformer"]) – The output of the transformer.
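The masking rule of _run_masked_transformer can be sketched in plain Python. This is a minimal illustration, not the nip implementation, and it rests on several assumptions: the token layout is [pooled graph 0, pooled graph 1, nodes of graph 0, nodes of graph 1]; a single (unbatched) pair of graphs is used; the node mask is boolean; and the pooled tokens may attend to every existing token (the documentation does not say what the pooled tokens attend to). The helper name build_cross_graph_attention_mask is hypothetical. In the result, True marks a blocked query/key pair, which matches the convention of the boolean mask argument of torch.nn.TransformerEncoder.

```python
from typing import List, Optional


def build_cross_graph_attention_mask(
    num_nodes: int, node_mask: List[List[bool]]
) -> List[List[bool]]:
    """Build a (2 + 2*num_nodes) x (2 + 2*num_nodes) attention mask.

    Hypothetical helper. Assumed token layout:
      positions 0 and 1      -> pooled representations of graph 0 and graph 1
      next num_nodes tokens  -> nodes of graph 0
      last num_nodes tokens  -> nodes of graph 1
    node_mask[g][n] is True if node n of graph g actually exists.
    In the result, True means "this query may NOT attend to this key".
    """
    size = 2 + 2 * num_nodes

    def graph_of(pos: int) -> Optional[int]:
        # None for the two pooled tokens, otherwise the index of the node's graph.
        if pos < 2:
            return None
        return (pos - 2) // num_nodes

    def exists(pos: int) -> bool:
        # Pooled tokens always exist; node tokens are looked up in node_mask.
        graph = graph_of(pos)
        if graph is None:
            return True
        return node_mask[graph][(pos - 2) % num_nodes]

    blocked = [[False] * size for _ in range(size)]
    for query in range(size):
        query_graph = graph_of(query)
        for key in range(size):
            key_graph = graph_of(key)
            if not exists(key):
                # Never attend to padding nodes.
                blocked[query][key] = True
            elif query_graph is not None and key_graph == query_graph:
                # A node may not attend to nodes of its own graph (itself included).
                blocked[query][key] = True
            # Otherwise the key is a pooled token or a node of the other graph.
            # Pooled queries attend to every existing token (an assumption).
    return blocked


if __name__ == "__main__":
    # Two node slots per graph; node 1 of graph 0 is padding.
    mask = build_cross_graph_attention_mask(2, [[True, False], [True, True]])
    for row in mask:
        print(["x" if is_blocked else "." for is_blocked in row])
```

Because every node row leaves the two pooled tokens unblocked, no row is fully masked, which avoids NaN attention weights from a softmax over an empty set. In a batched PyTorch setting, the per-example padding part of the rule would more naturally be passed as src_key_padding_mask, leaving only the cross-graph structure in mask.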
- abstract forward(data: Any) → Any
Forward pass through the agent part.
- Parameters:
data (Any) – The input to the agent part.
- Returns:
output (Any) – The output of the forward pass on the input.
- get_state_dict() → dict
Get the state of the agent part as a dict.
This method should be implemented by subclasses capable of saving their state.
- Returns:
state_dict (dict) – The state of the agent part.
- set_state(checkpoint: AgentState)
Set the state of the agent from a checkpoint.
This method should be overridden by subclasses to restore the state of the agent from a checkpoint.
- Parameters:
checkpoint (AgentCheckpoint) – The checkpoint to restore the state from.
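The three documented hooks (forward, get_state_dict and set_state) can be illustrated with a self-contained stand-in. Everything in this sketch is hypothetical: AgentPartContract only mirrors the documented method contract so the pattern runs without nip installed, ScalingPart is an invented toy part, a plain dict stands in for the AgentState/AgentCheckpoint type, and the NotImplementedError defaults are an assumption rather than documented library behaviour.

```python
from abc import ABC, abstractmethod
from typing import Any


class AgentPartContract(ABC):
    """Hypothetical stand-in mirroring the documented hooks (not the nip class)."""

    @abstractmethod
    def forward(self, data: Any) -> Any:
        """Forward pass through the agent part; every subclass must override it."""

    def get_state_dict(self) -> dict:
        # Assumed default: subclasses capable of saving their state override this.
        raise NotImplementedError

    def set_state(self, checkpoint: dict) -> None:
        # Assumed default: subclasses override this to restore their state.
        raise NotImplementedError


class ScalingPart(AgentPartContract):
    """Hypothetical toy part that multiplies its input by a stored scale."""

    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale

    def forward(self, data: float) -> float:
        return self.scale * data

    def get_state_dict(self) -> dict:
        # Everything needed to reconstruct this part goes into a plain dict.
        return {"scale": self.scale}

    def set_state(self, checkpoint: dict) -> None:
        # Restore from a dict of the shape produced by get_state_dict().
        self.scale = checkpoint["scale"]


if __name__ == "__main__":
    source = ScalingPart(scale=3.0)
    target = ScalingPart()
    target.set_state(source.get_state_dict())  # copy state across parts
    print(target.forward(2.0))  # prints 6.0
```

Keeping forward abstract means a part that forgets to implement it fails at instantiation time with a TypeError, while parts without saveable state can simply leave the state hooks unimplemented.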