nip.graph_isomorphism.agents.GraphIsomorphismAgentHooks
- class nip.graph_isomorphism.agents.GraphIsomorphismAgentHooks(gnn_output: callable | None = None, gnn_output_rounded: callable | None = None, pooled_gnn_output: callable | None = None, gnn_output_flatter: callable | None = None, transformer_input_initial: callable | None = None, pooled_feature: callable | None = None, message_feature: callable | None = None, transformer_input_pre_encoder: callable | None = None, transformer_input: callable | None = None, transformer_output_flatter: callable | None = None, graph_level_repr_pre_encoder: callable | None = None, node_level_repr_pre_encoder: callable | None = None, graph_level_repr: callable | None = None, node_level_repr: callable | None = None)
Holder for hooks to run at various points in the agent forward pass.
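The holder pattern described above can be illustrated with a minimal, self-contained sketch. Everything in it is hypothetical: `SketchHooks` and `sketch_forward` are stand-ins written for this example, not part of `nip`, and the sketch assumes each hook is a callable that receives a single intermediate value. The signature that the real agent passes to its hooks is not stated on this page.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class SketchHooks:
    """Hypothetical stand-in for a hook holder; not the nip class."""

    gnn_output: Optional[Callable] = None
    graph_level_repr: Optional[Callable] = None


def sketch_forward(values, hooks=None):
    """Toy forward pass that calls each hook, if set, on an intermediate value."""
    hooks = hooks or SketchHooks()

    # Stand-in for the GNN stage: double every input value.
    gnn_output = [2 * v for v in values]
    if hooks.gnn_output is not None:
        hooks.gnn_output(gnn_output)

    # Stand-in for graph-level pooling: sum over nodes.
    graph_level_repr = sum(gnn_output)
    if hooks.graph_level_repr is not None:
        hooks.graph_level_repr(graph_level_repr)

    return graph_level_repr


# Only one hook is set; unset hooks (None) are skipped.
seen = []
hooks = SketchHooks(gnn_output=seen.append)
result = sketch_forward([1, 2, 3], hooks)
```

Defaulting every field to `None` means a caller only supplies the hooks it cares about, which matches the all-optional constructor signature above.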
Methods Summary
__eq__(other): Return self==value.
__init__([gnn_output, gnn_output_rounded, ...])
__repr__(): Return repr(self).
create_recorder_hooks(storage[, per_agent]): Create hooks to record the agent's output.
Attributes
gnn_output
gnn_output_flatter
gnn_output_rounded
graph_level_repr
graph_level_repr_pre_encoder
message_feature
node_level_repr
node_level_repr_pre_encoder
pooled_feature
pooled_gnn_output
transformer_input
transformer_input_initial
transformer_input_pre_encoder
transformer_output_flatter
Methods
- __eq__(other)
Return self==value.
- __init__(gnn_output: callable | None = None, gnn_output_rounded: callable | None = None, pooled_gnn_output: callable | None = None, gnn_output_flatter: callable | None = None, transformer_input_initial: callable | None = None, pooled_feature: callable | None = None, message_feature: callable | None = None, transformer_input_pre_encoder: callable | None = None, transformer_input: callable | None = None, transformer_output_flatter: callable | None = None, graph_level_repr_pre_encoder: callable | None = None, node_level_repr_pre_encoder: callable | None = None, graph_level_repr: callable | None = None, node_level_repr: callable | None = None) -> None
- __repr__()
Return repr(self).
- classmethod create_recorder_hooks(storage: dict | TensorDict, per_agent: bool = True) -> AgentHooks
Create hooks to record the agent’s output.
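One way such a recorder factory could work is sketched below. This is an illustration under stated assumptions, not the library's implementation: `SketchRecorderHooks` is a hypothetical class, the storage is a plain `dict` keyed by hook name, each hook is assumed to take a single value, and the `per_agent` parameter is left out because its behaviour is not described here.

```python
from dataclasses import dataclass, fields
from typing import Callable, Optional


@dataclass
class SketchRecorderHooks:
    """Hypothetical hook holder with a recorder factory; not the nip class."""

    gnn_output: Optional[Callable] = None
    node_level_repr: Optional[Callable] = None

    @classmethod
    def create_recorder_hooks(cls, storage: dict) -> "SketchRecorderHooks":
        """Build one hook per field that writes its input into `storage`."""

        def make_recorder(name):
            # A separate closure per field so each hook keeps its own name.
            def recorder(value):
                storage[name] = value

            return recorder

        return cls(**{f.name: make_recorder(f.name) for f in fields(cls)})


storage = {}
hooks = SketchRecorderHooks.create_recorder_hooks(storage)

# Calling a hook stores the value under the hook's field name.
hooks.gnn_output([0.5, 1.5])
hooks.node_level_repr([1.0])
```

After a forward pass, `storage` would hold the latest value seen by each hook, which is convenient for inspecting or comparing intermediate activations.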