nip.graph_isomorphism.agents.GraphIsomorphismCombinedValueHead
- class nip.graph_isomorphism.agents.GraphIsomorphismCombinedValueHead(*args, **kwargs)
A module which combines the agent value heads for the graph isomorphism task.
Shapes
Input:
(“agents”, “graph_level_repr”) (… agent d_representation): The output graph-level representations.
“round” (…): The current round number.
Output:
(“agents”, “value”) (… agent): The estimated value of each agent for each batch item.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
settings (ExperimentSettings) – The settings of the experiment.
protocol_handler (ProtocolHandler) – The protocol handler for the experiment.
value_heads (dict[str, GraphIsomorphismAgentValueHead]) – The agent value heads to combine.
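The shape contract above can be illustrated with a minimal sketch. It does not use the library: plain nested lists stand in for tensors, the agent names and the toy linear heads are hypothetical, and it assumes each head only sees its own agent's slice of the representation, which this page does not confirm.

```python
# Hypothetical stand-ins: nested lists replace tensors, and each "head" is a
# simple function from a representation vector to a scalar value.

def make_linear_head(weights):
    """Return a toy value head: dot product of weights and the representation."""
    def head(representation):
        return sum(w * x for w, x in zip(weights, representation))
    return head

# One toy head per agent; the agent names are illustrative only.
value_heads = {
    "prover": make_linear_head([1.0, 0.0, 0.0]),
    "verifier": make_linear_head([0.0, 1.0, 1.0]),
}
agent_names = list(value_heads)

# graph_level_repr has shape (batch, agent, d_representation) = (2, 2, 3).
graph_level_repr = [
    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    [[0.5, 0.5, 0.5], [1.0, 1.0, 1.0]],
]

def combine_values(graph_level_repr, value_heads, agent_names):
    """Apply each agent's head to its own slice; the result is (batch, agent)."""
    return [
        [value_heads[name](item[i]) for i, name in enumerate(agent_names)]
        for item in graph_level_repr
    ]

values = combine_values(graph_level_repr, value_heads, agent_names)
print(values)  # [[1.0, 11.0], [0.5, 2.0]]
```

The d_representation axis is consumed by the heads, so the output keeps only the batch and agent axes.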
Methods Summary
__init__(hyper_params, settings, ...)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
_restrict_input_to_visible_channels(agent_name, ...)
Restrict an agent's input to its visible message channels.
forward(head_output[, hooks])
Run the agent value heads and combine their values.
Attributes
T_destination
additional_in_keys
additional_out_keys
call_super_init
device
The device used by the agent part.
dump_patches
excluded_in_keys
excluded_out_keys
in_keys
The keys required by the module.
out_keys
The keys produced by the module.
out_keys_source
training
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, value_heads: dict[str, AgentValueHead])
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- _restrict_input_to_visible_channels(agent_name: str, input_array: Tensor | ndarray[Any, dtype[_ScalarType_co]], shape_spec: str) → Tensor
Restrict an agent’s input to its visible message channels.
Agents receive messages only from the channels they can see, so this method drops every message channel that is not visible to the given agent.
- Parameters:
agent_name (str)
input_array (Tensor | ndarray)
shape_spec (str)
- Returns:
restricted_input (Tensor | NDArray) – The input restricted to the visible message channels.
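A minimal sketch of the idea, not the library's implementation: the channel names, the visibility table, and the fixed (batch, channel, feature) layout are all assumptions made for illustration. The real method receives a shape_spec string whose format is not described on this page, so it is left out here.

```python
# Hypothetical channel layout; in the library this information would come
# from the protocol handler rather than module-level constants.
ALL_CHANNELS = ["main", "private_a", "private_b"]
VISIBLE_CHANNELS = {
    "agent_a": ["main", "private_a"],
    "agent_b": ["main", "private_b"],
}

def restrict_to_visible(agent_name, batch):
    """Keep only the channel slices the agent can see.

    `batch` is a nested list of shape (batch, channel, feature), with the
    channel axis ordered as in ALL_CHANNELS.
    """
    keep = [ALL_CHANNELS.index(name) for name in VISIBLE_CHANNELS[agent_name]]
    return [[item[c] for c in keep] for item in batch]

messages = [
    [[1, 1], [2, 2], [3, 3]],
    [[4, 4], [5, 5], [6, 6]],
]
restricted = restrict_to_visible("agent_b", messages)
print(restricted)  # [[[1, 1], [3, 3]], [[4, 4], [6, 6]]]
```

Only the channel axis shrinks; the batch and feature axes are left unchanged.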
- forward(head_output: TensorDictBase, hooks: GraphIsomorphismAgentHooks | None = None) → TensorDict
Run the agent value heads and combine their values.
- Parameters:
head_output (TensorDictBase) –
The input to the value heads. Should contain the keys:
(“agents”, “graph_level_repr”): The graph-level representation from the body.
hooks (GraphIsomorphismAgentHooks, optional) – Hooks to run at various points in the agent forward pass.
- Returns:
tensordict (TensorDict) – The tensordict, updated in place with the output of the value heads.
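The key contract and the in-place update can be sketched as follows. This is an illustration under stated assumptions, not the library code: an ordinary dict with tuple keys stands in for a TensorDict, `forward_sketch` is a hypothetical name, and the two toy heads are placeholders.

```python
def forward_sketch(data, heads):
    """Read the representation key, write the value key, return the same dict."""
    reprs = data[("agents", "graph_level_repr")]  # (batch, agent, d_representation)
    data[("agents", "value")] = [
        [head(agent_repr) for head, agent_repr in zip(heads, item)]
        for item in reprs
    ]
    return data

heads = [sum, max]  # toy per-agent heads, illustrative only
data = {
    ("agents", "graph_level_repr"): [[[1.0, 2.0], [3.0, 4.0]]],  # shape (1, 2, 2)
    "round": [0],
}
out = forward_sketch(data, heads)
print(out is data)                # True: the input object itself is updated
print(out[("agents", "value")])   # [[3.0, 4.0]]
```

Because the update is in place, callers that hold a reference to the input also see the new (“agents”, “value”) entry.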