nip.image_classification.agents.ImageClassificationSoloAgentHead
- class nip.image_classification.agents.ImageClassificationSoloAgentHead(*args, **kwargs)
- Solo agent head for the image classification task.
- Solo agents try to solve the task on their own, without interacting with other agents.
- Shapes
 - Input:
  - “image_level_repr” (… d_representation): The output image-level representations.
 - Output:
  - “decision_logits” (… 2): A logit for each of the two classes the agent can guess for the image.
- Methods Summary
 - __init__(hyper_params, settings, agent_name, ...) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
 - _build_decider([d_out, include_round]) – Build the module which produces an image-level output.
 - _build_image_level_mlp(d_in, d_hidden, ...) – Build an MLP which acts on the image-level representations.
 - _build_latent_pixel_mlp(d_in, d_hidden, ...) – Build an MLP which acts on the latent-pixel-level representations.
 - Initialise the module weights.
 - _run_recorder_hook(hooks, hook_name, output)
 - forward(body_output) – Run the solo agent head on the given body output.
 - get_state_dict() – Get the state of the agent part as a dict.
 - set_state(checkpoint) – Set the state of the agent from a checkpoint.
 - to([device]) – Move the agent head to the given device.
- Attributes
 - T_destination
 - agent_id – The ID of the agent.
 - agent_level_in_keys
 - agent_level_out_keys
 - call_super_init
 - dump_patches
 - env_level_in_keys
 - env_level_out_keys
 - in_keys – The keys required by the module.
 - is_prover – Whether the agent is a prover.
 - is_verifier – Whether the agent is a verifier.
 - max_message_rounds – The maximum number of message rounds in the protocol.
 - num_visible_message_channels – The number of message channels visible to the agent.
 - out_keys – The keys produced by the module.
 - out_keys_source
 - required_pretrained_models – The pretrained models used by the agent.
 - visible_message_channel_indices – The indices of the message channels visible to the agent.
 - visible_message_channel_mask – The mask for the message channels visible to the agent.
 - visible_message_channel_names – The names of the message channels visible to the agent.
 - agent_params
 - training
- Methods
 - __init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)
- Initialize internal Module state, shared by both nn.Module and ScriptModule. 
 - _build_decider(d_out: int = 3, include_round: bool | None = None) → TensorDictModule
- Build the module which produces an image-level output.
- By default it is used to decide whether to continue exchanging messages. In this case it outputs a single triple of logits, one for each of three options: guess either of the two possible classifications for the image, or continue exchanging messages.
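A triple of decider logits could be consumed as in the sketch below. It converts the logits into probabilities and picks the most likely option. The option ordering (class 0, class 1, continue) and the helper name `interpret_decider_logits` are hypothetical and not taken from the library.

```python
import math
from typing import List, Tuple

# Hypothetical option ordering; the real decider may order its logits differently.
OPTIONS = ("guess class 0", "guess class 1", "continue exchanging messages")


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a flat list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def interpret_decider_logits(logits: List[float]) -> Tuple[str, List[float]]:
    """Hypothetical helper: return the most likely option and all option probabilities."""
    if len(logits) != len(OPTIONS):
        raise ValueError(f"expected {len(OPTIONS)} logits, got {len(logits)}")
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return OPTIONS[best], probs


choice, probs = interpret_decider_logits([0.2, 1.5, -0.3])
print(choice)  # guess class 1
```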
 - _build_image_level_mlp(d_in: int, d_hidden: int, d_out: int, num_layers: int, include_round: bool = False, out_key: str = 'image_level_mlp_output', squeeze: bool = False) → TensorDictSequential
- Build an MLP which acts on the image-level representations.
- Shapes
 - Input:
  - image_level_repr : (… d_in)
 - Output:
  - image_level_mlp_output : (… d_out)
 - Parameters:
- d_in (int) – The dimensionality of the image-level representations. 
- d_hidden (int) – The dimensionality of the hidden layers. 
- d_out (int) – The dimensionality of the output. 
- num_layers (int) – The number of hidden layers in the MLP. 
- include_round (bool, default=False) – Whether to include the round number as a (one-hot encoded) input to the MLP. 
- out_key (str, default="image_level_mlp_output") – The tensordict key to use for the output of the MLP. 
- squeeze (bool, default=False) – Whether to squeeze the output dimension. Only use this if the output dimension is 1. 
 
- Returns:
- image_level_mlp (TensorDictSequential) – The image-level MLP. 
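The parameters above can be made concrete with the stdlib-only sketch below. It computes the layer dimensions such an MLP could use and the resulting output shape. It rests on two assumptions: `num_layers` hidden layers of width `d_hidden` sit between the input and output layers, and `include_round` appends a one-hot round vector of length `max_message_rounds` to the input. Both assumptions and all function names are hypothetical and may not match the actual implementation.

```python
from typing import List, Tuple


def one_hot_round(round_idx: int, max_message_rounds: int) -> List[float]:
    """Hypothetical helper: one-hot encode the current round (assumed length)."""
    if not 0 <= round_idx < max_message_rounds:
        raise ValueError("round index out of range")
    vec = [0.0] * max_message_rounds
    vec[round_idx] = 1.0
    return vec


def image_level_mlp_dims(
    d_in: int,
    d_hidden: int,
    d_out: int,
    num_layers: int,
    include_round: bool = False,
    max_message_rounds: int = 0,
) -> List[Tuple[int, int]]:
    """Hypothetical helper: (in_features, out_features) of each linear layer."""
    # Assumption: the one-hot round vector is concatenated onto the input.
    first_in = d_in + (max_message_rounds if include_round else 0)
    widths = [first_in] + [d_hidden] * num_layers + [d_out]
    return list(zip(widths[:-1], widths[1:]))


def output_shape(
    batch_shape: Tuple[int, ...], d_out: int, squeeze: bool = False
) -> Tuple[int, ...]:
    """Shape of the MLP output; squeezing is only valid when d_out == 1."""
    if squeeze:
        if d_out != 1:
            raise ValueError("squeeze=True requires d_out == 1")
        return batch_shape
    return batch_shape + (d_out,)


print(image_level_mlp_dims(16, 32, 3, num_layers=2))
# [(16, 32), (32, 32), (32, 3)]
print(output_shape((8,), 1, squeeze=True))  # (8,)
```

The `squeeze` check mirrors the documented restriction that squeezing should only be used when the output dimension is 1.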
 
 - _build_latent_pixel_mlp(d_in: int, d_hidden: int, d_out: int, num_layers: int, flatten_output: bool = True, out_key: str = 'latent_pixel_mlp_output') → TensorDictModule
- Build an MLP which acts on the latent-pixel-level representations.
- Shapes
 - Input:
  - “latent_pixel_level_repr” : (… latent_height latent_width d_in)
 - Output:
  - latent_pixel_mlp_output : (… latent_height*latent_width d_out)
 - Parameters:
- d_in (int) – The dimensionality of the input. 
- d_hidden (int) – The dimensionality of the hidden layers. 
- d_out (int) – The dimensionality of the output. 
- num_layers (int) – The number of hidden layers in the MLP. 
- flatten_output (bool, default=True) – Whether to flatten the latent height and width dimensions of the output into a single dimension of size latent_height * latent_width.
- out_key (str, default="latent_pixel_mlp_output") – The tensordict key to use for the output of the MLP. 
 
- Returns:
- latent_pixel_mlp (TensorDictModule) – The latent-pixel-level MLP. 
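The effect of `flatten_output` on shapes can be sketched with plain tuples. The sketch assumes, as the shapes listed above suggest, that the MLP acts independently on the last (feature) dimension of each latent pixel. It also assumes that `flatten_output=False` keeps the two spatial dimensions separate. The function name is hypothetical.

```python
from typing import Tuple


def latent_pixel_mlp_output_shape(
    input_shape: Tuple[int, ...], d_out: int, flatten_output: bool = True
) -> Tuple[int, ...]:
    """Hypothetical helper: map (..., latent_height, latent_width, d_in) to the output shape."""
    if len(input_shape) < 3:
        raise ValueError("expected (..., latent_height, latent_width, d_in)")
    *batch, height, width, _d_in = input_shape
    if flatten_output:
        # The two spatial dimensions are merged into one of size height * width.
        return (*batch, height * width, d_out)
    # Assumption: without flattening, the spatial layout is preserved.
    return (*batch, height, width, d_out)


print(latent_pixel_mlp_output_shape((8, 4, 4, 64), d_out=16))
# (8, 16, 16)
print(latent_pixel_mlp_output_shape((8, 4, 4, 64), d_out=16, flatten_output=False))
# (8, 4, 4, 16)
```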
 
 - forward(body_output: TensorDict) → TensorDict
- Run the solo agent head on the given body output.
- Parameters:
- body_output (TensorDict) – The output of the body module. A tensor dict with keys:
 - “image_level_repr” (… d_representation): The output image-level representations.
 
- Returns:
- out (TensorDict) – A tensor dict with keys:
 - “decision_logits” (… 2): A logit for each of the two classes the agent can guess for the image.
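The key-in/key-out contract of `forward` can be mimicked with plain dictionaries. The sketch uses a hypothetical stand-in, `ToySoloHead`, which applies a single linear layer to each image-level representation to produce two decision logits. It handles a single batch dimension only. The real head operates on TensorDict objects and its internal architecture may differ, so only the keys and shapes are meant to correspond.

```python
from typing import Dict, List

Vector = List[float]


class ToySoloHead:
    """Hypothetical stand-in illustrating the forward() key contract only."""

    in_keys = ("image_level_repr",)
    out_keys = ("decision_logits",)

    def __init__(self, weights: List[Vector], bias: Vector) -> None:
        # weights has shape (2, d_representation); bias has shape (2,).
        if len(weights) != 2 or len(bias) != 2:
            raise ValueError("expected exactly two output logits")
        self.weights = weights
        self.bias = bias

    def forward(self, body_output: Dict[str, List[Vector]]) -> Dict[str, List[Vector]]:
        reprs = body_output["image_level_repr"]  # (batch, d_representation)
        d_representation = len(self.weights[0])
        logits = []
        for r in reprs:
            if len(r) != d_representation:
                raise ValueError("representation has the wrong dimensionality")
            # One logit per output row: dot product with the weights plus bias.
            logits.append(
                [sum(w * x for w, x in zip(row, r)) + b for row, b in zip(self.weights, self.bias)]
            )
        return {"decision_logits": logits}  # (batch, 2)


head = ToySoloHead(weights=[[1.0, 0.0], [0.0, 1.0]], bias=[0.0, 0.5])
out = head.forward({"image_level_repr": [[2.0, 1.0], [0.0, 3.0]]})
print(out["decision_logits"])  # [[2.0, 1.5], [0.0, 3.5]]
```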
 
 
 - get_state_dict() → dict
- Get the state of the agent part as a dict.
- This method should be implemented by subclasses capable of saving their state.
- Returns:
- state_dict (dict) – The state of the agent part. 
 
 - set_state(checkpoint: AgentState)
- Set the state of the agent from a checkpoint.
- This method should be overridden by subclasses to restore the state of the agent from a checkpoint.
- Parameters:
- checkpoint (AgentCheckpoint) – The checkpoint to restore the state from.
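`get_state_dict` and `set_state` form a save/restore pair. The sketch illustrates this with a hypothetical `ToyAgentPart` whose whole state is a dictionary of weights. The library's `AgentState`/`AgentCheckpoint` types are not reproduced, and a plain dict stands in for the checkpoint.

```python
import copy
from typing import Any, Dict


class ToyAgentPart:
    """Hypothetical agent part showing a save/restore round trip."""

    def __init__(self) -> None:
        self.weights: Dict[str, Any] = {"linear.weight": [[0.0, 0.0]], "linear.bias": [0.0]}

    def get_state_dict(self) -> dict:
        # Deep copy so that later training steps cannot mutate the saved state.
        return {"weights": copy.deepcopy(self.weights)}

    def set_state(self, checkpoint: dict) -> None:
        # Restore from a saved state, again copying to avoid aliasing the checkpoint.
        self.weights = copy.deepcopy(checkpoint["weights"])


agent = ToyAgentPart()
agent.weights["linear.bias"] = [1.0]
saved = agent.get_state_dict()

agent.weights["linear.bias"] = [42.0]  # simulate further training
agent.set_state(saved)
print(agent.weights["linear.bias"])  # [1.0]
```

Copying in both directions keeps the checkpoint and the live agent independent, so a checkpoint can be restored more than once.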