nip.image_classification.agents.ImageClassificationSoloAgentHead
- class nip.image_classification.agents.ImageClassificationSoloAgentHead(*args, **kwargs)
Solo agent head for the image classification task.
Solo agents try to solve the task on their own, without interacting with other agents.
Shapes
Input:
“image_level_repr” (… d_representation): The image-level representations output by the body.
Output:
“decision_logits” (… 2): A logit for each of the two options, i.e. each of the two possible classifications of the image.
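The two decision logits can be turned into probabilities with a softmax. The sketch below does this for a single image in plain Python; which index corresponds to which classification is an assumption here, not something this page specifies.

```python
import math


def decision_probabilities(decision_logits):
    """Convert the two decision logits for one image into probabilities."""
    # Subtract the largest logit before exponentiating, for numerical stability
    m = max(decision_logits)
    exps = [math.exp(x - m) for x in decision_logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for a single image (index 0 and index 1 are the two guesses)
probs = decision_probabilities([2.0, 0.0])
print([round(p, 3) for p in probs])  # [0.881, 0.119]
```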
Methods Summary
__init__(hyper_params, settings, agent_name, ...): Initialize internal Module state, shared by both nn.Module and ScriptModule.
_build_decider([d_out, include_round]): Build the module which produces an image-level output.
_build_image_level_mlp(d_in, d_hidden, ...): Build an MLP which acts on the image-level representations.
_build_latent_pixel_mlp(d_in, d_hidden, ...): Build an MLP which acts on the latent-pixel-level representations.
Initialise the module weights.
_run_recorder_hook(hooks, hook_name, output)
forward(body_output): Run the solo agent head on the given body output.
get_state_dict(): Get the state of the agent part as a dict.
set_state(checkpoint): Set the state of the agent from a checkpoint.
to([device]): Move the agent head to the given device.
Attributes
T_destination
agent_id: The ID of the agent.
agent_level_in_keys
agent_level_out_keys
call_super_init
dump_patches
env_level_in_keys
env_level_out_keys
in_keys: The keys required by the module.
is_prover: Whether the agent is a prover.
is_verifier: Whether the agent is a verifier.
max_message_rounds: The maximum number of message rounds in the protocol.
num_visible_message_channels: The number of message channels visible to the agent.
out_keys: The keys produced by the module.
out_keys_source
required_pretrained_models: The pretrained models used by the agent.
visible_message_channel_indices: The indices of the message channels visible to the agent.
visible_message_channel_mask: The mask for the message channels visible to the agent.
visible_message_channel_names: The names of the message channels visible to the agent.
agent_params
training
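The three visible-channel attributes describe the same information in different forms. The sketch below assumes the mask holds one boolean per message channel of the protocol, so the indices are the positions where the mask is true and the names are looked up at those positions. The channel names used are hypothetical.

```python
def visible_channels(channel_names, mask):
    """Derive visible channel indices and names from a boolean mask.

    Assumes one mask entry per message channel, in the same order as channel_names.
    """
    if len(channel_names) != len(mask):
        raise ValueError("mask must have one entry per message channel")
    indices = [i for i, visible in enumerate(mask) if visible]
    names = [channel_names[i] for i in indices]
    return indices, names


# Hypothetical protocol with three message channels
channel_names = ["prover0_verifier", "prover1_verifier", "prover0_prover1"]
indices, visible = visible_channels(channel_names, [True, False, True])
print(indices, visible)  # [0, 2] ['prover0_verifier', 'prover0_prover1']
```

Under this assumption, num_visible_message_channels would simply be len(indices).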
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- _build_decider(d_out: int = 3, include_round: bool | None = None) → TensorDictModule
Build the module which produces an image-level output.
By default it is used to decide whether to continue exchanging messages. In this case it outputs a single triple of logits for the three options: guess either of the two classifications for the image, or continue exchanging messages.
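With the default d_out=3, each decider output is a triple of logits. The sketch below shows one way to read such a triple greedily (taking the largest logit); the option labels and their order are hypothetical, and an agent may well sample from the logits rather than take the argmax.

```python
# Hypothetical labels for the three decider outputs; the real ordering is an assumption.
OPTIONS = ("guess first class", "guess second class", "continue exchanging messages")


def interpret_decider_logits(logits):
    """Return the option whose logit is largest (a greedy decision)."""
    if len(logits) != len(OPTIONS):
        raise ValueError(f"expected {len(OPTIONS)} logits, got {len(logits)}")
    best = max(range(len(logits)), key=lambda i: logits[i])
    return OPTIONS[best]


print(interpret_decider_logits([0.1, -1.2, 0.7]))  # continue exchanging messages
```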
- _build_image_level_mlp(d_in: int, d_hidden: int, d_out: int, num_layers: int, include_round: bool = False, out_key: str = 'image_level_mlp_output', squeeze: bool = False) → TensorDictSequential
Build an MLP which acts on the image-level representations.
Shapes
Input:
image_level_repr : (… d_in)
Output:
image_level_mlp_output : (… d_out)
- Parameters:
d_in (int) – The dimensionality of the image-level representations.
d_hidden (int) – The dimensionality of the hidden layers.
d_out (int) – The dimensionality of the output.
num_layers (int) – The number of hidden layers in the MLP.
include_round (bool, default=False) – Whether to include the round number as a (one-hot encoded) input to the MLP.
out_key (str, default="image_level_mlp_output") – The tensordict key to use for the output of the MLP.
squeeze (bool, default=False) – Whether to squeeze the output dimension. Only use this if the output dimension is 1.
- Returns:
image_level_mlp (TensorDictSequential) – The image-level MLP.
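To illustrate how the parameters above fit together, here is a minimal plain-Python sketch rather than the library's TensorDict-based implementation. It assumes that the round is one-hot encoded over a hypothetical max_rounds and appended to the input, that each of the num_layers hidden layers uses a ReLU activation, and that a final linear layer maps to d_out; none of these details are confirmed by this page.

```python
import random


def make_linear(d_in, d_out, rng):
    """Random weight matrix (d_out x d_in) and a zero bias."""
    weights = [[rng.uniform(-0.1, 0.1) for _ in range(d_in)] for _ in range(d_out)]
    return weights, [0.0] * d_out


def apply_linear(layer, x):
    weights, bias = layer
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def build_image_level_mlp(d_in, d_hidden, d_out, num_layers, include_round=False,
                          max_rounds=None, squeeze=False, seed=0):
    """Return a function mapping one image-level vector (and round) to d_out values."""
    if squeeze and d_out != 1:
        raise ValueError("squeeze is only valid when d_out == 1")
    if include_round and max_rounds is None:
        raise ValueError("max_rounds is needed to one-hot encode the round")
    rng = random.Random(seed)
    first_in = d_in + (max_rounds if include_round else 0)
    # num_layers hidden layers of width d_hidden, followed by an output layer
    dims = [first_in] + [d_hidden] * num_layers
    hidden = [make_linear(a, b, rng) for a, b in zip(dims[:-1], dims[1:])]
    output = make_linear(dims[-1], d_out, rng)

    def mlp(x, round_idx=None):
        if include_round:
            if round_idx is None:
                raise ValueError("round_idx is required when include_round=True")
            one_hot = [0.0] * max_rounds
            one_hot[round_idx] = 1.0
            x = list(x) + one_hot  # append the one-hot round to the representation
        for layer in hidden:
            x = [max(0.0, v) for v in apply_linear(layer, x)]  # ReLU (assumed)
        y = apply_linear(output, x)
        return y[0] if squeeze else y  # squeeze drops the size-1 output dimension

    return mlp


mlp = build_image_level_mlp(d_in=4, d_hidden=8, d_out=3, num_layers=2,
                            include_round=True, max_rounds=5)
out = mlp([0.5, -0.2, 0.1, 0.9], round_idx=2)
print(len(out))  # 3
```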
- _build_latent_pixel_mlp(d_in: int, d_hidden: int, d_out: int, num_layers: int, flatten_output: bool = True, out_key: str = 'latent_pixel_mlp_output') → TensorDictModule
Build an MLP which acts on the latent-pixel-level representations.
Shapes
Input:
“latent_pixel_level_repr” : (… latent_height latent_width d_in)
Output:
latent_pixel_mlp_output : (… latent_height*latent_width d_out)
- Parameters:
d_in (int) – The dimensionality of the input.
d_hidden (int) – The dimensionality of the hidden layers.
d_out (int) – The dimensionality of the output.
num_layers (int) – The number of hidden layers in the MLP.
flatten_output (bool, default=True) – Whether to flatten the two latent spatial dimensions of the output into a single dimension of size latent_height * latent_width.
out_key (str, default="latent_pixel_mlp_output") – The tensordict key to use for the output of the MLP.
- Returns:
latent_pixel_mlp (TensorDictModule) – The latent-pixel-level MLP.
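The shape change described above can be sketched in plain Python: the same function is applied independently to every latent pixel, and the height and width dimensions are then optionally merged. Here a single fixed linear map with hypothetical weights stands in for the MLP, and row-major flattening is an assumption.

```python
def apply_per_pixel(grid, pixel_fn, flatten_output=True):
    """Apply pixel_fn to every latent pixel of a (H, W, d_in) nested list.

    Returns a (H*W, d_out) list if flatten_output, else a (H, W, d_out) list.
    """
    out = [[pixel_fn(vec) for vec in row] for row in grid]
    if flatten_output:
        return [vec for row in out for vec in row]  # row-major flattening (assumed)
    return out


# Stand-in for the MLP: a fixed linear map from d_in=2 to d_out=3 (hypothetical weights)
WEIGHTS = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]


def pixel_fn(v):
    return [sum(w * x for w, x in zip(row, v)) for row in WEIGHTS]


latent_height, latent_width = 2, 3
grid = [[[float(i), float(j)] for j in range(latent_width)] for i in range(latent_height)]
flat = apply_per_pixel(grid, pixel_fn)
print(len(flat), len(flat[0]))  # 6 3
```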
- forward(body_output: TensorDict) → TensorDict
Run the solo agent head on the given body output.
- Parameters:
body_output (TensorDict) –
The output of the body module. A tensor dict with keys:
“image_level_repr” (… d_representation): The output image-level representations.
- Returns:
out (TensorDict) – A tensor dict with keys:
“decision_logits” (… 2): A logit for each of the two options, i.e. each of the two possible classifications of the image.
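The input/output contract of forward can be sketched with a plain dict standing in for TensorDict and a single linear layer standing in for whatever network the head actually uses; the weights and the function name are hypothetical. The leading batch dimensions (…) are represented by an outer list.

```python
def solo_head_forward(body_output, weights, bias):
    """Map body_output["image_level_repr"] to {"decision_logits": ...}.

    weights: 2 x d_representation nested list (hypothetical); bias: length-2 list.
    """
    reprs = body_output["image_level_repr"]
    logits = [
        [sum(w * x for w, x in zip(row, r)) + b for row, b in zip(weights, bias)]
        for r in reprs
    ]
    return {"decision_logits": logits}


# Batch of two images with d_representation = 2
body_output = {"image_level_repr": [[1.0, 2.0], [0.0, -1.0]]}
out = solo_head_forward(body_output, weights=[[1.0, 0.0], [0.0, 1.0]], bias=[0.5, 0.0])
print(out["decision_logits"])  # [[1.5, 2.0], [0.5, -1.0]]
```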
- get_state_dict() → dict
Get the state of the agent part as a dict.
This method should be implemented by subclasses capable of saving their state.
- Returns:
state_dict (dict) – The state of the agent part.
- set_state(checkpoint: AgentState)
Set the state of the agent from a checkpoint.
This method should be overridden by subclasses to restore the state of the agent from a checkpoint.
- Parameters:
checkpoint (AgentCheckpoint) – The checkpoint to restore the state from.
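Subclasses that support checkpointing are expected to pair these two methods so that a saved state can be restored. The hypothetical class below sketches such a round trip in plain Python; a plain dict stands in for the library's checkpoint type, and the state is copied so that later training does not alter the snapshot.

```python
import copy


class ToyAgentPart:
    """Hypothetical agent part showing a get_state_dict / set_state round trip."""

    def __init__(self, weights):
        self.weights = list(weights)

    def get_state_dict(self) -> dict:
        # Deep-copy so later changes to the agent do not alter the snapshot
        return {"weights": copy.deepcopy(self.weights)}

    def set_state(self, checkpoint: dict) -> None:
        # A plain dict stands in for the library's checkpoint object here
        self.weights = copy.deepcopy(checkpoint["weights"])


agent = ToyAgentPart([0.1, 0.2])
snapshot = agent.get_state_dict()
agent.weights[0] = 99.0  # e.g. training updates the weights
agent.set_state(snapshot)  # restore from the checkpoint
print(agent.weights)  # [0.1, 0.2]
```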