nip.graph_isomorphism.agents.GraphIsomorphismCombinedPolicyHead

class nip.graph_isomorphism.agents.GraphIsomorphismCombinedPolicyHead(*args, **kwargs)

A module which combines the agent policy heads for the graph isomorphism task.

Shapes

Input:

  • (“agents”, “node_level_repr”) (… agent pair node feature): The output node-level representations.

  • (“agents”, “graph_level_repr”) (… agent pair feature): The output graph-level representations.

  • “round” (…): The current round number.

  • “node_mask” (… pair node): Which nodes actually exist.

  • “message” (… channel position pair node): The most recent message.

  • “ignore_message” (…): Whether to ignore the message.

  • “decision_restriction” (…): The restriction on what decisions are allowed.

Output:

  • (“agents”, “node_selected_logits”) (… agent channel position 2*max_nodes): A logit for each node of the two graphs (hence 2*max_nodes), indicating how likely it is that this node should be sent as a message to the verifier.

  • (“agents”, “main_message_logits”) (… agent channel position 2*max_nodes): The same as “node_selected_logits”.

  • (“agents”, “decision_logits”) (… agent 3): A logit for each of the three options: guess that the graphs are isomorphic, guess that the graphs are not isomorphic, or continue exchanging messages.

  • (“agents”, “linear_message_selected_logits”) (… agent channel position d_linear_message_space) (optional): A logit for each linear message, indicating the probability that this linear message should be sent as a message to the verifier.
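As an illustration of how a consumer might read these outputs, here is a minimal NumPy sketch. It assumes that the 2*max_nodes axis is a row-major flattening of a (pair, node) layout, with the first graph's nodes first, and that the three decision options appear in the order listed above; neither is confirmed by this page, and the helper names are hypothetical. Options masked with a large negative logit such as -1e9 receive essentially zero probability after the softmax.

```python
import numpy as np


def flat_index_to_graph_node(index: int, max_nodes: int) -> tuple[int, int]:
    """Map an index on the 2*max_nodes axis to (graph, node).

    Hypothetical helper: assumes the axis is the row-major flattening of a
    (pair, node) layout, so graph 0's nodes come first.
    """
    graph, node = divmod(index, max_nodes)
    return graph, node


def decision_probabilities(decision_logits: np.ndarray) -> np.ndarray:
    """Softmax over the last axis (size 3) of the decision logits."""
    # Subtract the per-row maximum for numerical stability
    shifted = decision_logits - decision_logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


print(flat_index_to_graph_node(5, max_nodes=4))  # (1, 1): node 1 of the second graph

# Two agents; the second has two options masked with -1e9
logits = np.array([[0.0, 0.0, 0.0], [2.0, -1e9, -1e9]])
print(decision_probabilities(logits))
```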

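Parameters:
  • hyper_params (HyperParameters) – The experiment hyper-parameters.

  • settings (ExperimentSettings) – The experiment settings.

  • protocol_handler (ProtocolHandler) – The interaction protocol handler.

  • policy_heads (dict[str, AgentPolicyHead]) – The policy heads of the individual agents, keyed by agent name.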

Methods Summary

__init__(hyper_params, settings, ...)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_expand_logits_to_all_channels(agent_name, ...)

Expand an agent's logits from its visible message channels to all.

_restrict_decisions(decision_restriction, ...)

Make sure the agent's decisions comply with the restrictions.

_restrict_input_to_visible_channels(...)

Restrict an agent's input to its visible message channels.

forward(body_output[, hooks])

Run the agent policy heads and combine their outputs.

Attributes

T_destination

additional_in_keys

additional_out_keys

call_super_init

device

The device used by the agent part.

dump_patches

excluded_in_keys

excluded_out_keys

in_keys

The keys required by the module.

out_keys

The keys produced by the module.

out_keys_source

training

Methods

__init__(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, policy_heads: dict[str, AgentPolicyHead])

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_expand_logits_to_all_channels(agent_name: str, logits: Tensor, shape_spec: str, fill_value: float = -1000000000.0) → Tensor

Expand an agent’s logits from its visible message channels to all.

Agents only output messages for the channels they can see. This method expands the output to all channels by filling in fill_value for the logits of the channels the agent cannot see.

Parameters:
  • agent_name (str) – The name of the agent.

  • logits (Tensor) – A tensor of output logits. This is a single key in the output of the agent’s forward pass.

  • shape_spec (str) – The shape of the output. This is a space-separated string of the dimensions of the output. One of these must be “channel”.

  • fill_value (float, default=-1e9) – The value to fill in for the channels the agent cannot see.

Returns:

expanded_logits (Tensor) – The output expanded to all channels. This has the same shape as logits, except that the channel dimension is the full set of message channels.
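The expansion can be pictured as a scatter along the channel axis. The following is a minimal NumPy sketch, not the library's implementation: the visible_channels and num_channels arguments are hypothetical stand-ins for what the protocol handler knows about the agent, and it assumes that any "..." in shape_spec appears before "channel", so the channel axis can be located from the right.

```python
import numpy as np


def expand_logits_to_all_channels(
    logits: np.ndarray,
    shape_spec: str,
    visible_channels: list[int],
    num_channels: int,
    fill_value: float = -1e9,
) -> np.ndarray:
    """Scatter logits for the visible channels into an array over all channels.

    `visible_channels` and `num_channels` are hypothetical inputs standing in for
    what the protocol handler knows about the agent. A leading "..." in
    `shape_spec` is allowed; the channel axis is located from the right.
    """
    dims = shape_spec.split()
    # Negative axis, so the same code works with or without a leading "..."
    channel_axis = dims.index("channel") - len(dims)
    if logits.shape[channel_axis] != len(visible_channels):
        raise ValueError("channel dimension does not match the visible channels")

    out_shape = list(logits.shape)
    out_shape[channel_axis] = num_channels
    expanded = np.full(out_shape, fill_value, dtype=logits.dtype)

    # Write the agent's logits into its visible channels; the rest keep fill_value
    index = [slice(None)] * logits.ndim
    index[channel_axis] = visible_channels
    expanded[tuple(index)] = logits
    return expanded


agent_logits = np.zeros((1, 2, 3))  # batch=1, 2 visible channels, 3 logits each
full = expand_logits_to_all_channels(
    agent_logits, "... channel logit", visible_channels=[0, 2], num_channels=3
)
print(full.shape)  # (1, 3, 3)
```

Filling with a large negative value rather than, say, zero means the hidden channels contribute essentially nothing once the logits are passed through a softmax.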

_restrict_decisions(decision_restriction: Int[Tensor, '...'], decision_logits: Float[Tensor, '... agents 3']) → TensorDictBase

Make sure the agent’s decisions comply with the restrictions.

Parameters:
  • decision_restriction (Int[Tensor, "..."]) –

    The restrictions on the agents’ decisions. The possible values are:

    • 0: The verifier can decide anything.

    • 1: The verifier can only decide to continue interacting.

    • 2: The verifier can only make a guess.

  • decision_logits (Float[Tensor, "... agents 3"]) – The logits for the agents’ decisions.

Returns:

decision_logits (Float[Tensor, “… agents 3”]) – The logits for the agents’ decisions, with the restricted decisions set to -1e9.
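A minimal NumPy sketch of this masking follows. It assumes the last axis is ordered as in the Output description (guess isomorphic, guess not isomorphic, continue), so “continue” is index 2. The restriction values above are phrased in terms of the verifier, but the logits carry an agents axis; for simplicity this sketch applies the restriction to every agent, which may differ from what the real method does.

```python
import numpy as np

# Assumed order of the last axis, following the Output description:
# 0 = guess isomorphic, 1 = guess not isomorphic, 2 = continue.
CONTINUE = 2


def restrict_decisions(
    decision_restriction: np.ndarray, decision_logits: np.ndarray, fill_value: float = -1e9
) -> np.ndarray:
    """Mask disallowed decisions; restriction has shape (...), logits (..., agents, 3)."""
    # Add axes so the restriction broadcasts over the agents and options axes
    restriction = np.asarray(decision_restriction)[..., None, None]
    option = np.arange(3)
    # 1: only "continue" is allowed, so mask both guesses
    mask = (restriction == 1) & (option != CONTINUE)
    # 2: only a guess is allowed, so mask "continue"
    mask |= (restriction == 2) & (option == CONTINUE)
    # Restriction 0 leaves the mask empty, so every option stays available
    restricted = decision_logits.copy()
    restricted[np.broadcast_to(mask, restricted.shape)] = fill_value
    return restricted


restriction = np.array([0, 1, 2])  # one entry per batch element
logits = np.zeros((3, 2, 3))  # batch=3, agents=2, options=3
print(restrict_decisions(restriction, logits)[1])
```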

_restrict_input_to_visible_channels(agent_name: str, input_array: Tensor | ndarray[Any, dtype[_ScalarType_co]], shape_spec: str) → Tensor

Restrict an agent’s input to its visible message channels.

Agents only receive messages from the channels they can see. This method restricts the agent’s input to those visible message channels.

Parameters:
  • agent_name (str) – The name of the agent.

  • input_array (Tensor | NDArray) – The input array to the agent.

  • shape_spec (str) – The shape of the input. This is a space-separated string of the dimensions of the input. One of these must be “channel”.

Returns:

restricted_input (Tensor | NDArray) – The input restricted to the visible message channels.
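Restricting the input amounts to selecting a subset of indices along the channel axis. The sketch below uses NumPy only; the visible_channels argument is a hypothetical stand-in for what the protocol handler reports for the agent, and it assumes any "..." in shape_spec precedes "channel". For torch tensors, torch.index_select(input, dim, index) performs the same selection.

```python
import numpy as np


def restrict_input_to_visible_channels(
    input_array: np.ndarray, shape_spec: str, visible_channels: list[int]
) -> np.ndarray:
    """Keep only the visible channels along the axis named "channel"."""
    dims = shape_spec.split()
    # Locate the channel axis from the right so a leading "..." is allowed
    channel_axis = dims.index("channel") - len(dims)
    return np.take(input_array, visible_channels, axis=channel_axis)


message = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # batch=2, channel=3, position=4
visible = restrict_input_to_visible_channels(message, "... channel position", [0, 2])
print(visible.shape)  # (2, 2, 4)
```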

forward(body_output: TensorDictBase, hooks: GraphIsomorphismAgentHooks | None = None) → TensorDict

Run the agent policy heads and combine their outputs.

Parameters:
  • body_output (TensorDictBase) – The combined output of the agent bodies.

  • hooks (GraphIsomorphismAgentHooks, optional) – Hooks to run at various points in the agent forward pass.

Returns:

body_output (TensorDict) – The tensordict updated in place with the output of the policy heads.
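The “combine” step can be pictured as stacking each agent’s head outputs along a new agent axis and writing them back under (“agents”, …) keys. This is a simplified sketch under stated assumptions, not the library’s forward pass: plain dicts of NumPy arrays stand in for TensorDicts, each head is a hypothetical callable returning per-agent arrays without an agent axis, batch_ndim and the agent names are illustrative, and the hooks argument and any channel or decision handling the real method performs are omitted.

```python
from typing import Callable

import numpy as np

# Hypothetical stand-in for a TensorDict: keys are str or ("agents", name) tuples
ArrayDict = dict
PolicyHead = Callable[[ArrayDict], dict[str, np.ndarray]]


def combine_policy_heads(
    body_output: ArrayDict, policy_heads: dict[str, PolicyHead], batch_ndim: int = 1
) -> ArrayDict:
    """Run each agent's head and stack its outputs along a new agent axis."""
    # Dict insertion order fixes the order of the agent axis
    per_agent = [head(body_output) for head in policy_heads.values()]
    for key in per_agent[0]:
        stacked = np.stack([out[key] for out in per_agent], axis=batch_ndim)
        body_output[("agents", key)] = stacked  # update in place, mirroring forward()
    return body_output


heads = {
    "prover": lambda body: {"decision_logits": np.zeros((2, 3))},
    "verifier": lambda body: {"decision_logits": np.ones((2, 3))},
}
body = {"round": np.zeros(2)}
combine_policy_heads(body, heads)
print(body[("agents", "decision_logits")].shape)  # (2, 2, 3)
```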