nip.graph_isomorphism.agents.GraphIsomorphismAgentHooks
- class nip.graph_isomorphism.agents.GraphIsomorphismAgentHooks(gnn_output: callable | None = None, gnn_output_rounded: callable | None = None, pooled_gnn_output: callable | None = None, gnn_output_flatter: callable | None = None, transformer_input_initial: callable | None = None, pooled_feature: callable | None = None, message_feature: callable | None = None, transformer_input_pre_encoder: callable | None = None, transformer_input: callable | None = None, transformer_output_flatter: callable | None = None, graph_level_repr_pre_encoder: callable | None = None, node_level_repr_pre_encoder: callable | None = None, graph_level_repr: callable | None = None, node_level_repr: callable | None = None)
Holder for hooks to run at various points in the agent's forward pass.
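The hooks-holder pattern can be illustrated with a minimal, self-contained sketch. Everything here is hypothetical: `ToyHooks` and `toy_forward` are stand-ins, not part of nip, and only three of the fourteen documented field names are reused. The sketch assumes each hook is an optional callable that the forward pass invokes on an intermediate value when it is set and skips when it is `None`.

```python
from dataclasses import dataclass
from typing import Callable, Optional


# Hypothetical stand-in for a hooks holder: a dataclass whose fields are
# optional callables, reusing a few of the field names documented above.
@dataclass
class ToyHooks:
    gnn_output: Optional[Callable] = None
    pooled_gnn_output: Optional[Callable] = None
    graph_level_repr: Optional[Callable] = None


def toy_forward(x: list[float], hooks: Optional[ToyHooks] = None) -> float:
    """Hypothetical forward pass that calls each hook, if set, on the
    intermediate value produced at that point."""
    node_repr = [2.0 * v for v in x]  # stands in for the GNN output
    if hooks is not None and hooks.gnn_output is not None:
        hooks.gnn_output(node_repr)
    pooled = sum(node_repr)  # stands in for pooling over nodes
    if hooks is not None and hooks.pooled_gnn_output is not None:
        hooks.pooled_gnn_output(pooled)
    graph_repr = pooled / len(node_repr)  # stands in for the graph-level representation
    if hooks is not None and hooks.graph_level_repr is not None:
        hooks.graph_level_repr(graph_repr)
    return graph_repr


# Only the gnn_output hook is set; the other two are skipped.
seen = {}
hooks = ToyHooks(gnn_output=lambda v: seen.setdefault("gnn_output", v))
result = toy_forward([1.0, 2.0, 3.0], hooks)
```

Here `result` is `4.0` and `seen` holds the doubled node values. The generated `__eq__` and `__repr__` listed below suggest the real class may be a dataclass as well; in this sketch `ToyHooks() == ToyHooks()` is `True` because dataclasses compare field by field.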
Methods Summary
__eq__(other): Return self==value.
__init__([gnn_output, gnn_output_rounded, ...])
__repr__(): Return repr(self).
create_recorder_hooks(storage[, per_agent]): Create hooks to record the agent's output.
Attributes
gnn_output
gnn_output_flatter
gnn_output_rounded
graph_level_repr
graph_level_repr_pre_encoder
message_feature
node_level_repr
node_level_repr_pre_encoder
pooled_feature
pooled_gnn_output
transformer_input
transformer_input_initial
transformer_input_pre_encoder
transformer_output_flatter
Methods
- __eq__(other)
Return self==value.
- __init__(gnn_output: callable | None = None, gnn_output_rounded: callable | None = None, pooled_gnn_output: callable | None = None, gnn_output_flatter: callable | None = None, transformer_input_initial: callable | None = None, pooled_feature: callable | None = None, message_feature: callable | None = None, transformer_input_pre_encoder: callable | None = None, transformer_input: callable | None = None, transformer_output_flatter: callable | None = None, graph_level_repr_pre_encoder: callable | None = None, node_level_repr_pre_encoder: callable | None = None, graph_level_repr: callable | None = None, node_level_repr: callable | None = None) -> None
- __repr__()
Return repr(self).
- classmethod create_recorder_hooks(storage: dict | TensorDict, per_agent: bool = True) -> AgentHooks
Create hooks to record the agent’s output.
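One plausible shape for a recorder factory is sketched below. It is not nip's implementation: `ToyHooks` is a hypothetical stand-in, and the hook call signature `hook(agent_name, value)` as well as the meaning given to `per_agent` (nest records under each agent's name versus a flat layout) are assumptions made for illustration. Only a plain `dict` is used for `storage`, although the documented signature also accepts a `TensorDict`.

```python
from dataclasses import dataclass, fields
from typing import Any, Callable, Optional


@dataclass
class ToyHooks:
    # Hypothetical subset of hook fields
    gnn_output: Optional[Callable] = None
    graph_level_repr: Optional[Callable] = None

    @classmethod
    def create_recorder_hooks(cls, storage: dict, per_agent: bool = True) -> "ToyHooks":
        """Return an instance where every hook writes its input into `storage`.

        Assumes (hypothetically) that each hook is called as hook(agent_name, value).
        """

        def make_recorder(hook_name: str) -> Callable[[str, Any], None]:
            # A factory function binds hook_name per field, avoiding the
            # late-binding pitfall of closures created in a loop.
            def record(agent_name: str, value: Any) -> None:
                if per_agent:
                    # Nested layout: storage[agent_name][hook_name]
                    storage.setdefault(agent_name, {})[hook_name] = value
                else:
                    # Flat layout: later calls overwrite earlier ones
                    storage[hook_name] = value

            return record

        # One recorder per dataclass field
        return cls(**{f.name: make_recorder(f.name) for f in fields(cls)})


storage = {}
recorder = ToyHooks.create_recorder_hooks(storage)
recorder.gnn_output("verifier", [1, 2, 3])
recorder.graph_level_repr("prover", 0.5)
```

After these calls, `storage` is `{"verifier": {"gnn_output": [1, 2, 3]}, "prover": {"graph_level_repr": 0.5}}`. Building the hooks from `dataclasses.fields` keeps the recorder in sync with the field list, so adding a new hook point does not require editing the factory.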