nip.graph_isomorphism.agents.GraphIsomorphismAgentHooks

class nip.graph_isomorphism.agents.GraphIsomorphismAgentHooks(gnn_output: callable | None = None, gnn_output_rounded: callable | None = None, pooled_gnn_output: callable | None = None, gnn_output_flatter: callable | None = None, transformer_input_initial: callable | None = None, pooled_feature: callable | None = None, message_feature: callable | None = None, transformer_input_pre_encoder: callable | None = None, transformer_input: callable | None = None, transformer_output_flatter: callable | None = None, graph_level_repr_pre_encoder: callable | None = None, node_level_repr_pre_encoder: callable | None = None, graph_level_repr: callable | None = None, node_level_repr: callable | None = None)

Container for hooks that run at various points in the agent's forward pass.
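The general pattern can be illustrated with a minimal, self-contained sketch. This is not the nip implementation: it assumes that each hook is a callable that receives the intermediate value computed at its hook point and that hooks left as None are skipped. The class `ToyAgentHooks`, the function `toy_forward`, and the two-field subset of hook points are hypothetical, and the "GNN" and "pooling" steps are arithmetic stand-ins.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class ToyAgentHooks:
    # Hypothetical subset of hook points; each is an optional callable.
    gnn_output: Optional[Callable[[Any], None]] = None
    graph_level_repr: Optional[Callable[[Any], None]] = None


def _run_hook(hook: Optional[Callable[[Any], None]], value: Any) -> None:
    # Skip hook points that were left as None.
    if hook is not None:
        hook(value)


def toy_forward(node_features: list, hooks: Optional[ToyAgentHooks] = None) -> float:
    hooks = hooks if hooks is not None else ToyAgentHooks()
    # Stand-in for a GNN layer: double each node feature.
    gnn_out = [2.0 * x for x in node_features]
    _run_hook(hooks.gnn_output, gnn_out)
    # Stand-in for pooling to a graph-level representation: sum over nodes.
    graph_repr = sum(gnn_out)
    _run_hook(hooks.graph_level_repr, graph_repr)
    return graph_repr


# Usage: record the (toy) GNN output without changing the forward pass itself.
recorded = []
result = toy_forward([1.0, 2.0, 3.0], ToyAgentHooks(gnn_output=recorded.append))
# result == 12.0; recorded == [[2.0, 4.0, 6.0]]
```

Keeping every hook point optional means the forward pass runs unchanged when no hooks are supplied, and callers attach only the observation points they need.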

Methods Summary

__eq__(other)

Return self==value.

__init__([gnn_output, gnn_output_rounded, ...])

__repr__()

Return repr(self).

create_recorder_hooks(storage[, per_agent])

Create hooks to record the agent's output.

Attributes

gnn_output

gnn_output_flatter

gnn_output_rounded

graph_level_repr

graph_level_repr_pre_encoder

message_feature

node_level_repr

node_level_repr_pre_encoder

pooled_feature

pooled_gnn_output

transformer_input

transformer_input_initial

transformer_input_pre_encoder

transformer_output_flatter

Methods

__eq__(other)

Return self==value.

__init__(gnn_output: callable | None = None, gnn_output_rounded: callable | None = None, pooled_gnn_output: callable | None = None, gnn_output_flatter: callable | None = None, transformer_input_initial: callable | None = None, pooled_feature: callable | None = None, message_feature: callable | None = None, transformer_input_pre_encoder: callable | None = None, transformer_input: callable | None = None, transformer_output_flatter: callable | None = None, graph_level_repr_pre_encoder: callable | None = None, node_level_repr_pre_encoder: callable | None = None, graph_level_repr: callable | None = None, node_level_repr: callable | None = None) -> None

__repr__()

Return repr(self).

classmethod create_recorder_hooks(storage: dict | TensorDict, per_agent: bool = True) -> AgentHooks

Create hooks to record the agent’s output.

Parameters:
  • storage (dict | TensorDict) – The dictionary or TensorDict in which to store the agent’s output.

  • per_agent (bool, default=True) – Whether to store the output of each agent separately.

Returns:

hooks (AgentHooks) – The hooks to record the agent’s output.
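A recorder-hook factory of this kind can be sketched as follows. This is a hypothetical illustration, not the behaviour of `create_recorder_hooks`: it assumes each hook is called as `hook(agent_name, value)`, that `storage` is a plain dict, that `per_agent=True` nests entries under the agent name, and that `per_agent=False` uses a flat key per hook point. The names `make_recorder_hooks` and `HOOK_NAMES` and the key layout are assumptions; consult the nip source for the actual storage layout.

```python
from typing import Any, Callable, Dict

# Hypothetical subset of hook-point names.
HOOK_NAMES = ("gnn_output", "graph_level_repr")


def make_recorder_hooks(
    storage: dict, per_agent: bool = True
) -> Dict[str, Callable[[str, Any], None]]:
    """Return one recording callable per hook point.

    Assumed hook signature: hook(agent_name, value).
    """

    def make_hook(hook_name: str) -> Callable[[str, Any], None]:
        def hook(agent_name: str, value: Any) -> None:
            if per_agent:
                # Nest under the agent name so each agent's output is kept separately.
                storage.setdefault(agent_name, {})[hook_name] = value
            else:
                # Flat layout: a later agent's output overwrites an earlier one.
                storage[hook_name] = value

        return hook

    return {name: make_hook(name) for name in HOOK_NAMES}


# Usage: two agents writing to the same storage dict.
storage = {}
hooks = make_recorder_hooks(storage)
hooks["gnn_output"]("prover", [1, 2])
hooks["gnn_output"]("verifier", [3])
# storage == {"prover": {"gnn_output": [1, 2]}, "verifier": {"gnn_output": [3]}}
```

Building each hook through the inner `make_hook` factory binds `hook_name` per hook point, avoiding the late-binding bug that would occur if the closures were created directly inside a loop.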