nip.graph_isomorphism.agents.GraphIsomorphismAgentHooks
- class nip.graph_isomorphism.agents.GraphIsomorphismAgentHooks(gnn_output: callable | None = None, gnn_output_rounded: callable | None = None, pooled_gnn_output: callable | None = None, gnn_output_flatter: callable | None = None, transformer_input_initial: callable | None = None, pooled_feature: callable | None = None, message_feature: callable | None = None, transformer_input_pre_encoder: callable | None = None, transformer_input: callable | None = None, transformer_output_flatter: callable | None = None, graph_level_repr_pre_encoder: callable | None = None, node_level_repr_pre_encoder: callable | None = None, graph_level_repr: callable | None = None, node_level_repr: callable | None = None)
- Holder for hooks to run at various points in the agent forward pass.

Methods Summary

- __eq__(other): Return self==value.
- __init__([gnn_output, gnn_output_rounded, ...])
- __repr__(): Return repr(self).
- create_recorder_hooks(storage[, per_agent]): Create hooks to record the agent's output.

Attributes

- gnn_output
- gnn_output_flatter
- gnn_output_rounded
- graph_level_repr
- graph_level_repr_pre_encoder
- message_feature
- node_level_repr
- node_level_repr_pre_encoder
- pooled_feature
- pooled_gnn_output
- transformer_input
- transformer_input_initial
- transformer_input_pre_encoder
- transformer_output_flatter

Methods

- __eq__(other)
  Return self==value.
- __init__(gnn_output: callable | None = None, gnn_output_rounded: callable | None = None, pooled_gnn_output: callable | None = None, gnn_output_flatter: callable | None = None, transformer_input_initial: callable | None = None, pooled_feature: callable | None = None, message_feature: callable | None = None, transformer_input_pre_encoder: callable | None = None, transformer_input: callable | None = None, transformer_output_flatter: callable | None = None, graph_level_repr_pre_encoder: callable | None = None, node_level_repr_pre_encoder: callable | None = None, graph_level_repr: callable | None = None, node_level_repr: callable | None = None) → None
- __repr__()
  Return repr(self).
- classmethod create_recorder_hooks(storage: dict | TensorDict, per_agent: bool = True) → AgentHooks
  Create hooks to record the agent's output.
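The sketch below illustrates how a hook holder of this kind might be used, both with hand-written hooks and with recorder hooks. It does not import nip. `ToyAgentHooks`, `toy_forward`, the choice of three fields, and the hook signature `hook(agent_name, value)` are all hypothetical. The storage layout used for `per_agent` is only a guess at what that flag means, not nip's documented behaviour. The "Return self==value." and "Return repr(self)." entries hint that the real class may be a dataclass, so the sketch uses one. That is an inference, not a confirmed detail.

```python
from dataclasses import dataclass, fields
from typing import Callable, List, Optional


@dataclass
class ToyAgentHooks:
    """Hypothetical stand-in for GraphIsomorphismAgentHooks with three of its fields.

    Every hook defaults to None, so callers set only the points they care about.
    Assumed hook signature (not taken from nip): hook(agent_name, value).
    """

    gnn_output: Optional[Callable] = None
    pooled_gnn_output: Optional[Callable] = None
    graph_level_repr: Optional[Callable] = None

    @classmethod
    def create_recorder_hooks(cls, storage: dict, per_agent: bool = True) -> "ToyAgentHooks":
        """Build hooks that store each intermediate value in ``storage``.

        per_agent=True  -> storage[agent_name][hook_name] = value  (assumed layout)
        per_agent=False -> storage[hook_name] = value
        """

        def make_hook(hook_name: str) -> Callable:
            def hook(agent_name: str, value) -> None:
                if per_agent:
                    storage.setdefault(agent_name, {})[hook_name] = value
                else:
                    storage[hook_name] = value

            return hook

        # Fill every hook field with a recorder bound to that field's name.
        return cls(**{f.name: make_hook(f.name) for f in fields(cls)})


def toy_forward(agent_name: str, x: List[float], hooks: Optional[ToyAgentHooks] = None) -> float:
    """Toy 'forward pass' that calls each hook, if one is set, on an intermediate value."""

    def run(hook: Optional[Callable], value) -> None:
        if hook is not None:
            hook(agent_name, value)

    gnn_out = [2.0 * v for v in x]  # stand-in for the GNN layers
    run(getattr(hooks, "gnn_output", None), gnn_out)
    pooled = sum(gnn_out) / len(gnn_out)  # stand-in for graph pooling
    run(getattr(hooks, "pooled_gnn_output", None), pooled)
    graph_repr = pooled + 1.0  # stand-in for the final encoder
    run(getattr(hooks, "graph_level_repr", None), graph_repr)
    return graph_repr


# Usage: record every intermediate value for one agent.
storage = {}
hooks = ToyAgentHooks.create_recorder_hooks(storage)
result = toy_forward("prover", [1.0, 2.0, 3.0], hooks)
# result == 5.0
# storage == {"prover": {"gnn_output": [2.0, 4.0, 6.0],
#                        "pooled_gnn_output": 4.0, "graph_level_repr": 5.0}}
```

Because every field defaults to None, the forward pass needs only one `is not None` check per hook point. A caller that cares about a single intermediate value can pass just that one keyword argument.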