nip.graph_isomorphism.agents.GraphIsomorphismCombinedBody
- class nip.graph_isomorphism.agents.GraphIsomorphismCombinedBody(*args, **kwargs)
A module which combines the agent bodies for the graph isomorphism task.
Shapes
Input:
“round” (…): The current round number.
“x” (… round channel position pair node): The graph node features (message history).
“adjacency” (… pair node node): The adjacency matrices.
“message” (… channel position pair node), optional: The most recent message.
“node_mask” (… pair node): Which nodes actually exist.
“linear_message_history” (… round channel position linear_message), optional: The linear message history, if one is used.
Output:
(“agents”, “node_level_repr”) (… agent pair max_nodes d_representation): The output node-level representations.
(“agents”, “graph_level_repr”) (… agent pair d_representation): The output graph-level representations.
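To make the shape conventions concrete, here is a minimal sketch that builds a dummy input with the documented keys and dimension orders. It is an illustration only: it uses a plain dict of NumPy arrays in place of the TensorDict the module actually consumes, all sizes are hypothetical, and the leading "…" dimensions are collapsed to a single batch dimension.

```python
import numpy as np

# Hypothetical sizes for illustration only; "..." is a single batch dimension here.
BATCH = 4
NUM_ROUNDS = 8      # "round"
NUM_CHANNELS = 2    # "channel"
NUM_POSITIONS = 1   # "position"
PAIR = 2            # "pair": the two graphs under comparison
MAX_NODES = 5       # "node"

# A plain dict standing in for the TensorDict the module receives.
data = {
    # (...): the current round number
    "round": np.zeros(BATCH, dtype=np.int64),
    # (... round channel position pair node): message history as node features
    "x": np.zeros(
        (BATCH, NUM_ROUNDS, NUM_CHANNELS, NUM_POSITIONS, PAIR, MAX_NODES),
        dtype=np.float32,
    ),
    # (... pair node node): one adjacency matrix per graph in the pair
    "adjacency": np.zeros((BATCH, PAIR, MAX_NODES, MAX_NODES), dtype=bool),
    # (... channel position pair node): the most recent message (optional key)
    "message": np.zeros(
        (BATCH, NUM_CHANNELS, NUM_POSITIONS, PAIR, MAX_NODES), dtype=np.float32
    ),
    # (... pair node): True where a node slot holds a real node
    "node_mask": np.ones((BATCH, PAIR, MAX_NODES), dtype=bool),
}

for key, value in data.items():
    print(f"{key}: {value.shape}")
```

Note that "message" has the same trailing dimensions as "x" minus the round axis, since it holds a single round's message.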
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
settings (ExperimentSettings) – The settings of the experiment.
protocol_handler (ProtocolHandler) – The protocol handler for the experiment.
bodies (dict[str, GraphIsomorphismAgentBody]) – The agent bodies to combine.
Methods Summary
__init__(hyper_params, settings, ...) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
_restrict_input_to_visible_channels(agent_name, input_array, shape_spec) – Restrict an agent's input to its visible message channels.
forward(data[, hooks]) – Run the agent bodies and combine their outputs.
Attributes
T_destination
additional_in_keys
additional_out_keys
call_super_init
device – The device used by the agent part.
dump_patches
excluded_in_keys
excluded_out_keys
in_keys – The keys required by the module.
out_keys – The keys produced by the module.
out_keys_source
training
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, bodies: dict[str, AgentBody])
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- _restrict_input_to_visible_channels(agent_name: str, input_array: Tensor | ndarray[Any, dtype[_ScalarType_co]], shape_spec: str) → Tensor
Restrict an agent’s input to its visible message channels.
Agents only receive messages on the channels they can see. This method restricts an agent’s input to those visible message channels.
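- Parameters:
agent_name (str) – The name of the agent whose input is being restricted.
input_array (Tensor | NDArray) – The input to restrict.
shape_spec (str) – The shape specification of input_array, naming its dimensions.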
- Returns:
restricted_input (Tensor | NDArray) – The input restricted to the visible message channels.
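The restriction can be sketched as indexing along the "channel" axis, which is located by its position in the shape specification. This is a hypothetical NumPy-only helper, not the module's implementation: the function name `restrict_to_visible_channels` and the boolean `visible` mask argument are assumptions (in the real class, which channels an agent can see is presumably derived from the protocol handler), and returning inputs with no channel dimension unchanged is a further assumption of this sketch.

```python
import numpy as np


def restrict_to_visible_channels(input_array, shape_spec, visible):
    """Hypothetical sketch: keep only the visible entries along "channel".

    input_array: NumPy array whose trailing axes are named by shape_spec.
    shape_spec: e.g. "... round channel position pair node".
    visible: boolean sequence, one entry per channel (assumed input; the
        real method works out visibility itself).
    """
    # Named dimensions after the leading "..." are the array's trailing axes.
    named = [dim for dim in shape_spec.split() if dim != "..."]
    if "channel" not in named:
        # Assumption: inputs without a channel axis pass through unchanged.
        return input_array
    # Negative axis index, so any number of leading batch dimensions works.
    channel_axis = named.index("channel") - len(named)
    visible_indices = np.flatnonzero(np.asarray(visible, dtype=bool))
    return np.take(input_array, visible_indices, axis=channel_axis)


# Usage: three channels, of which the agent can see the first and third.
x = np.arange(4 * 8 * 3 * 1 * 2 * 5).reshape(4, 8, 3, 1, 2, 5)
restricted = restrict_to_visible_channels(
    x, "... round channel position pair node", [True, False, True]
)
print(restricted.shape)  # (4, 8, 2, 1, 2, 5)
```

Counting the channel axis from the end means the same helper works whether "…" stands for zero, one, or several batch dimensions.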
- forward(data: TensorDictBase, hooks: GraphIsomorphismAgentHooks | None = None) → TensorDict
Run the agent bodies and combine their outputs.
- Parameters:
data (TensorDictBase) –
The input data. A tensor dict with keys:
“round” (…): The current round number.
“x” (… round channel position pair node): The graph node features (message history).
“adjacency” (… pair node node): The adjacency matrices.
“message” (… channel position pair node), optional: The most recent message.
“node_mask” (… pair node): Which nodes actually exist.
“linear_message_history” (… round channel position linear_message), optional: The linear message history, if one is used.
hooks (GraphIsomorphismAgentHooks, optional) – Hooks to run at various points in the agent forward pass.
- Returns:
data (TensorDict) – The tensordict updated in place with the output of the agent bodies.
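The combining step can be pictured as stacking each agent body's outputs along a new agent dimension, which yields the output shapes listed under Shapes. This sketch assumes each body returns NumPy arrays under the keys "node_level_repr" and "graph_level_repr"; the helper name `combine_agent_outputs` and the agent names in the usage example are hypothetical, and the real method also handles hooks and TensorDict bookkeeping that this sketch omits.

```python
import numpy as np


def combine_agent_outputs(outputs, agent_names):
    """Hypothetical sketch: stack per-agent outputs along a new "agent" axis.

    outputs maps each agent name to a dict with (assumed) keys:
        "node_level_repr":  (... pair max_nodes d_representation)
        "graph_level_repr": (... pair d_representation)
    """
    node_level = np.stack(
        [outputs[name]["node_level_repr"] for name in agent_names], axis=-4
    )
    graph_level = np.stack(
        [outputs[name]["graph_level_repr"] for name in agent_names], axis=-3
    )
    return {
        # (... agent pair max_nodes d_representation)
        ("agents", "node_level_repr"): node_level,
        # (... agent pair d_representation)
        ("agents", "graph_level_repr"): graph_level,
    }


# Usage with two hypothetical agents; each body's output is filled with its index.
agent_names = ["prover", "verifier"]
outputs = {
    name: {
        "node_level_repr": np.full((4, 2, 5, 16), i, dtype=np.float32),
        "graph_level_repr": np.full((4, 2, 16), i, dtype=np.float32),
    }
    for i, name in enumerate(agent_names)
}
combined = combine_agent_outputs(outputs, agent_names)
print(combined[("agents", "node_level_repr")].shape)   # (4, 2, 2, 5, 16)
print(combined[("agents", "graph_level_repr")].shape)  # (4, 2, 2, 16)
```

Stacking on a negative axis places the agent dimension directly before "pair" regardless of how many leading batch dimensions are present.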