nip.graph_isomorphism.agents.GraphIsomorphismConstantAgentValueHead
- class nip.graph_isomorphism.agents.GraphIsomorphismConstantAgentValueHead(*args, **kwargs)
- A constant value head for the graph isomorphism task.
- Shapes
  - Input:
    - “graph_level_repr” (… 2 d_representation): The output graph-level representations.
    - “node_level_repr” (… 2 max_nodes d_representation): The output node-level representations.
  - Output:
    - “value” (…): The ‘value’ for each batch item, which is a constant zero.
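The shape contract above can be checked with plain tuple arithmetic: the batch shape "…" of the output is whatever precedes the trailing (2, d_representation) dimensions of the graph-level input. The helper value_output_shape below is hypothetical, written only to illustrate the documented shapes, and it assumes both representations share the same batch dimensions and the same d_representation.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def value_output_shape(graph_level_shape: Shape, node_level_shape: Shape) -> Shape:
    """Return the batch shape "..." of the 'value' output (hypothetical helper).

    Assumes the documented layout:
      graph_level_repr: (..., 2, d_representation)
      node_level_repr:  (..., 2, max_nodes, d_representation)
    """
    if len(graph_level_shape) < 2 or graph_level_shape[-2] != 2:
        raise ValueError("graph_level_repr must have shape (..., 2, d_representation)")
    if len(node_level_shape) < 3 or node_level_shape[-3] != 2:
        raise ValueError("node_level_repr must have shape (..., 2, max_nodes, d_representation)")
    # Both inputs use the same symbol d_representation for their last dimension.
    if graph_level_shape[-1] != node_level_shape[-1]:
        raise ValueError("d_representation differs between the two representations")
    batch_shape = graph_level_shape[:-2]
    # Both representations must share the same leading batch dimensions.
    if node_level_shape[:-3] != batch_shape:
        raise ValueError("batch dimensions of the two representations differ")
    return batch_shape


# A batch of 8 graph pairs, each graph padded to 10 nodes, with 16-dim representations.
print(value_output_shape((8, 2, 16), (8, 2, 10, 16)))  # (8,)
```

The "value" output therefore carries one scalar per graph pair, with no trailing feature dimension.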
- Methods Summary
  - __init__(hyper_params, settings, agent_name, ...): Initialize internal Module state, shared by both nn.Module and ScriptModule.
  - Initialise the module weights.
  - _run_recorder_hook(hooks, hook_name, output)
  - forward(body_output[, hooks]): Return a constant value.
  - get_state_dict(): Get the state of the agent part as a dict.
  - set_state(checkpoint): Set the state of the agent from a checkpoint.
  - to(device): Move the agent to the given device.
- Attributes
  - T_destination
  - agent_id: The ID of the agent.
  - agent_level_in_keys
  - agent_level_out_keys
  - call_super_init
  - dump_patches
  - env_level_in_keys
  - env_level_out_keys
  - in_keys: The keys required by the module.
  - is_prover: Whether the agent is a prover.
  - is_verifier: Whether the agent is a verifier.
  - max_message_rounds: The maximum number of message rounds in the protocol.
  - num_visible_message_channels: The number of message channels visible to the agent.
  - out_keys: The keys produced by the module.
  - out_keys_source
  - required_pretrained_models: The pretrained models used by the agent.
  - visible_message_channel_indices: The indices of the message channels visible to the agent.
  - visible_message_channel_mask: The mask for the message channels visible to the agent.
  - visible_message_channel_names: The names of the message channels visible to the agent.
  - agent_params
  - training
- Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)
- Initialize internal Module state, shared by both nn.Module and ScriptModule. 
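The visible_message_channel_mask, visible_message_channel_indices, visible_message_channel_names and num_visible_message_channels attributes listed above describe the same set of channels in different forms. A minimal sketch of how they could relate, assuming the mask is a sequence of booleans aligned with the protocol's list of channel names; VisibleChannels, from_mask and the channel names are hypothetical and not taken from nip.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass(frozen=True)
class VisibleChannels:
    """Hypothetical bundle mirroring the visible_message_channel_* attributes."""

    mask: List[bool]
    indices: List[int]
    names: List[str]

    @property
    def num_visible(self) -> int:
        # Counterpart of num_visible_message_channels.
        return len(self.indices)

    @classmethod
    def from_mask(cls, channel_names: Sequence[str], mask: Sequence[bool]) -> "VisibleChannels":
        if len(channel_names) != len(mask):
            raise ValueError("mask must have one entry per message channel")
        # Indices of channels whose mask entry is True, in channel order.
        indices = [i for i, visible in enumerate(mask) if visible]
        names = [channel_names[i] for i in indices]
        return cls(mask=list(mask), indices=indices, names=names)


# Hypothetical protocol with three channels, of which this agent sees two.
channels = VisibleChannels.from_mask(["channel_a", "channel_b", "channel_c"], [True, True, False])
print(channels.indices, channels.names, channels.num_visible)
# [0, 1] ['channel_a', 'channel_b'] 2
```

Deriving the indices and names from a single mask keeps the three views consistent by construction.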
- forward(body_output: TensorDict, hooks: GraphIsomorphismAgentHooks | None = None) → TensorDict
- Return a constant value.
- Parameters:
  - body_output (TensorDict) – The output of the body module. A tensor dict with keys:
    - “graph_level_repr” (… 2 1): The output graph-level representations.
    - “node_level_repr” (… 2 max_nodes 1): The output node-level representations.
  - hooks (GraphIsomorphismAgentHooks, optional) – Hooks to run at various points in the agent forward pass.
- Returns:
  - value_out (TensorDict) – A tensor dict with keys:
    - “value” (…): The ‘value’ for each batch item, which is a constant zero.
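A minimal sketch of the forward contract, assuming NumPy arrays in a plain dict stand in for torch tensors in a TensorDict. ConstantValueHeadSketch is hypothetical; in this sketch the hooks argument is accepted only to mirror the documented signature and is never called.

```python
from typing import Dict, Optional

import numpy as np

BodyOutput = Dict[str, np.ndarray]


class ConstantValueHeadSketch:
    """Hypothetical stand-in for the constant value head's forward pass."""

    def forward(self, body_output: BodyOutput, hooks: Optional[object] = None) -> BodyOutput:
        graph_level_repr = body_output["graph_level_repr"]
        # Drop the trailing (2, d_representation) dimensions to get the batch shape "...".
        batch_shape = graph_level_repr.shape[:-2]
        # The value ignores the input contents: it is zero for every batch item.
        return {"value": np.zeros(batch_shape, dtype=graph_level_repr.dtype)}


head = ConstantValueHeadSketch()
body_output = {
    "graph_level_repr": np.random.rand(4, 2, 1),
    "node_level_repr": np.random.rand(4, 2, 7, 1),
}
value_out = head.forward(body_output)
print(value_out["value"])  # [0. 0. 0. 0.]
```

A constant head of this kind is useful as a baseline: any learning signal that depends on the value estimate then reduces to the raw return.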
 
 
- get_state_dict() → dict
- Get the state of the agent part as a dict.
- This method should be implemented by subclasses capable of saving their state.
- Returns:
- state_dict (dict) – The state of the agent part. 
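The description suggests a template-method pattern in which only subclasses with something to save provide get_state_dict. A minimal sketch under that assumption; AgentPartBase, ScaledHead and the NotImplementedError default are hypothetical and not taken from nip.

```python
from typing import Any, Dict


class AgentPartBase:
    """Hypothetical base class: parts that can save state override get_state_dict."""

    def get_state_dict(self) -> Dict[str, Any]:
        raise NotImplementedError(f"{type(self).__name__} cannot save its state")


class ScaledHead(AgentPartBase):
    """Hypothetical part with one piece of persistent state."""

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def get_state_dict(self) -> Dict[str, Any]:
        # Build a fresh dict so callers cannot mutate the part's internals through it.
        return {"scale": self.scale}


print(ScaledHead(0.5).get_state_dict())  # {'scale': 0.5}
```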
 
- set_state(checkpoint: AgentState)
- Set the state of the agent from a checkpoint.
- This method should be overridden by subclasses to restore the state of the agent from a checkpoint.
- Parameters:
  - checkpoint (AgentState) – The checkpoint to restore the state from.
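Saving and restoring are meant to round-trip: a state captured with get_state_dict and stored in a checkpoint should be restorable with set_state. A minimal sketch assuming a checkpoint is just a wrapper around a saved state dict; CheckpointSketch and ScaledHead are hypothetical stand-ins, and the real AgentState may hold more than this.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class CheckpointSketch:
    """Hypothetical stand-in for AgentState: only a saved state dict."""

    state_dict: Dict[str, Any] = field(default_factory=dict)


class ScaledHead:
    """Hypothetical agent part supporting both saving and restoring."""

    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale

    def get_state_dict(self) -> Dict[str, Any]:
        return {"scale": self.scale}

    def set_state(self, checkpoint: CheckpointSketch) -> None:
        # Restore every field that get_state_dict saved.
        self.scale = checkpoint.state_dict["scale"]


original = ScaledHead(scale=0.5)
checkpoint = CheckpointSketch(state_dict=original.get_state_dict())

restored = ScaledHead()
restored.set_state(checkpoint)
print(restored.scale)  # 0.5
```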