nip.image_classification.agents.ImageClassificationAgentValueHead

class nip.image_classification.agents.ImageClassificationAgentValueHead(*args, **kwargs)

Value head for the image classification task.

Takes the output of the agent body as input and outputs an estimate of the value function for each batch item.

Shapes

Input:

  • “image_level_repr” (… d_representation): The output image-level representations.

  • “round” (optional) (…): The round number.

Output:

  • “value” (…): The estimated value for each batch item.

Parameters:
  • hyper_params (HyperParameters) – The parameters of the experiment.

  • settings (ExperimentSettings) – The settings of the experiment.

  • agent_name (str) – The name of the agent.

  • protocol_handler (ProtocolHandler) – The protocol handler for the experiment.

  • device (TorchDevice, optional) – The device to use for this agent part. If not given, the CPU is used.

Methods Summary

__init__(hyper_params, settings, agent_name, ...)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_build_decider([d_out, include_round])

Build the module which produces an image-level output.

_build_image_level_mlp(d_in, d_hidden, ...)

Build an MLP which acts on the image-level representations.

_build_latent_pixel_mlp(d_in, d_hidden, ...)

Build an MLP which acts on the latent-pixel-level representations.

_build_mlp()

Build the module which computes the value function.

_init_weights()

Initialise the module weights.

_run_recorder_hook(hooks, hook_name, output)

forward(body_output)

Run the value head on the given body output.

get_state_dict()

Get the state of the agent part as a dict.

set_state(checkpoint)

Set the state of the agent from a checkpoint.

to([device])

Move the agent head to the given device.

Attributes

T_destination

agent_id

The ID of the agent.

agent_level_in_keys

agent_level_out_keys

call_super_init

dump_patches

env_level_in_keys

The environment-level input keys.

env_level_out_keys

in_keys

The keys required by the module.

is_prover

Whether the agent is a prover.

is_verifier

Whether the agent is a verifier.

max_message_rounds

The maximum number of message rounds in the protocol.

num_visible_message_channels

The number of message channels visible to the agent.

out_keys

The keys produced by the module.

out_keys_source

required_pretrained_models

The pretrained models used by the agent.

visible_message_channel_indices

The indices of the message channels visible to the agent.

visible_message_channel_mask

The mask for the message channels visible to the agent.

visible_message_channel_names

The names of the message channels visible to the agent.

agent_params

training

Methods

__init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_build_decider(d_out: int = 3, include_round: bool | None = None) → TensorDictModule

Build the module which produces an image-level output.

By default it is used to decide whether to continue exchanging messages. In this case it outputs a single triple of logits, one for each of the three options: the possible classification guesses for the image, and continuing to exchange messages.

Parameters:
  • d_out (int, default=3) – The dimensionality of the output.

  • include_round (bool, optional) – Whether to include the round number as a (one-hot encoded) input to the MLP. If not given, the value from the agent parameters is used.

Returns:

decider (TensorDictModule) – The decider module.
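A minimal sketch of how a triple of decider logits could be turned into a probability distribution over the three options, assuming a softmax is applied downstream. This page does not say how the logits are consumed, nor which index corresponds to which option; the helper name decider_probabilities is hypothetical.

```python
import math


def decider_probabilities(logits: list[float]) -> list[float]:
    """Hypothetical helper: numerically stable softmax over the decider logits."""
    # Subtract the largest logit before exponentiating to avoid overflow
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Example triple of logits (one per option); index meanings are not specified here
probs = decider_probabilities([2.0, 0.5, -1.0])
# Greedy choice: the index of the most probable option
choice = max(range(len(probs)), key=lambda i: probs[i])
```

Subtracting the maximum logit leaves the softmax unchanged mathematically but keeps math.exp from overflowing on large logits.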

_build_image_level_mlp(d_in: int, d_hidden: int, d_out: int, num_layers: int, include_round: bool = False, out_key: str = 'image_level_mlp_output', squeeze: bool = False) → TensorDictSequential

Build an MLP which acts on the image-level representations.

Shapes

Input:

  • image_level_repr : (… d_in)

Output:

  • image_level_mlp_output : (… d_out)

Parameters:
  • d_in (int) – The dimensionality of the image-level representations.

  • d_hidden (int) – The dimensionality of the hidden layers.

  • d_out (int) – The dimensionality of the output.

  • num_layers (int) – The number of hidden layers in the MLP.

  • include_round (bool, default=False) – Whether to include the round number as a (one-hot encoded) input to the MLP.

  • out_key (str, default="image_level_mlp_output") – The tensordict key to use for the output of the MLP.

  • squeeze (bool, default=False) – Whether to squeeze the output dimension. Only use this if the output dimension is 1.

Returns:

image_level_mlp (TensorDictSequential) – The image-level MLP.
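A minimal pure-Python sketch of what such an image-level MLP computes for a single representation vector. It rests on assumptions not stated on this page: ReLU activations between layers, the round number one-hot encoded over a hypothetical num_rounds slots and concatenated onto the input, and random weights in place of trained ones. The actual method returns a TensorDictSequential built from PyTorch modules; the names build_image_level_mlp_sketch, num_rounds and seed are illustrative only.

```python
import random


def one_hot(index: int, size: int) -> list[float]:
    """Encode a round number as a one-hot vector of length `size`."""
    vec = [0.0] * size
    vec[index] = 1.0
    return vec


def linear(x: list[float], weight: list[list[float]], bias: list[float]) -> list[float]:
    """Compute W x + b, with `weight` stored as (d_out, d_in)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def build_image_level_mlp_sketch(d_in, d_hidden, d_out, num_layers,
                                 include_round=False, num_rounds=0,
                                 squeeze=False, seed=0):
    """Hypothetical pure-Python stand-in for an image-level MLP on one vector."""
    if squeeze and d_out != 1:
        raise ValueError("squeeze is only valid when d_out == 1")
    rng = random.Random(seed)
    # The one-hot round encoding is concatenated onto the input features
    first_in = d_in + (num_rounds if include_round else 0)
    # num_layers hidden layers of width d_hidden, then an output layer of width d_out
    dims = [first_in] + [d_hidden] * num_layers + [d_out]
    layers = [
        ([[rng.uniform(-1.0, 1.0) for _ in range(a)] for _ in range(b)], [0.0] * b)
        for a, b in zip(dims[:-1], dims[1:])
    ]

    def mlp(image_level_repr, round_number=None):
        x = list(image_level_repr)
        if include_round:
            x = x + one_hot(round_number, num_rounds)
        for i, (weight, bias) in enumerate(layers):
            x = linear(x, weight, bias)
            if i < len(layers) - 1:
                x = [max(0.0, v) for v in x]  # ReLU between layers (assumed activation)
        # Squeezing drops the trailing size-1 dimension, giving a scalar
        return x[0] if squeeze else x

    return mlp


value_mlp = build_image_level_mlp_sketch(d_in=4, d_hidden=8, d_out=1, num_layers=2,
                                         include_round=True, num_rounds=3, squeeze=True)
value = value_mlp([0.1, 0.2, 0.3, 0.4], round_number=1)
logits_mlp = build_image_level_mlp_sketch(d_in=4, d_hidden=8, d_out=3, num_layers=1)
logits = logits_mlp([0.1, 0.2, 0.3, 0.4])
```

With squeeze=True and d_out=1 the result is a scalar per item, matching the (…) shape of a value output; without it the trailing d_out dimension is kept.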

_build_latent_pixel_mlp(d_in: int, d_hidden: int, d_out: int, num_layers: int, flatten_output: bool = True, out_key: str = 'latent_pixel_mlp_output') → TensorDictModule

Build an MLP which acts on the latent-pixel-level representations.

Shapes

Input:

  • “latent_pixel_level_repr” : (… latent_height latent_width d_in)

Output:

  • latent_pixel_mlp_output : (… latent_height*latent_width d_out)

Parameters:
  • d_in (int) – The dimensionality of the input.

  • d_hidden (int) – The dimensionality of the hidden layers.

  • d_out (int) – The dimensionality of the output.

  • num_layers (int) – The number of hidden layers in the MLP.

  • flatten_output (bool, default=True) – Whether to flatten the latent_height and latent_width dimensions of the output into a single dimension of size latent_height * latent_width.

  • out_key (str, default="latent_pixel_mlp_output") – The tensordict key to use for the output of the MLP.

Returns:

latent_pixel_mlp (TensorDictModule) – The latent-pixel-level MLP.
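A minimal sketch of the shape transformation, assuming the MLP is applied to each latent pixel independently and that flattening is row-major (the ordering produced by, for example, torch.flatten on a contiguous tensor). Nested lists stand in for tensors, a per-pixel function stands in for the MLP, and latent_pixel_mlp_sketch is a hypothetical name.

```python
from typing import Callable

Vector = list[float]


def latent_pixel_mlp_sketch(grid: list[list[Vector]],
                            pixel_fn: Callable[[Vector], Vector],
                            flatten_output: bool = True):
    """Apply `pixel_fn` to every latent pixel, optionally flattening the grid."""
    # Map each (d_in,) pixel vector to a (d_out,) vector independently
    out = [[pixel_fn(pixel) for pixel in row] for row in grid]
    if not flatten_output:
        return out  # (latent_height, latent_width, d_out)
    # Row-major flatten: pixel (h, w) ends up at index h * latent_width + w
    return [pixel for row in out for pixel in row]  # (latent_height*latent_width, d_out)


# A 2 x 3 latent grid whose pixel (h, w) holds the features [h, w]
grid = [[[float(h), float(w)] for w in range(3)] for h in range(2)]
flat = latent_pixel_mlp_sketch(grid, lambda p: [p[0] + p[1]])
unflat = latent_pixel_mlp_sketch(grid, lambda p: [p[0] + p[1]], flatten_output=False)
```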

_build_mlp() → TensorDictModule

Build the module which computes the value function.

Returns:

value_mlp (TensorDictModule) – The value module.

_init_weights()

Initialise the module weights.

Should be called at the end of __init__.

_run_recorder_hook(hooks: AgentHooks | None, hook_name: str, output: Tensor | None)

forward(body_output: TensorDict) → TensorDict

Run the value head on the given body output.

Parameters:

body_output (TensorDict) –

The output of the body module. A tensor dict with keys:

  • “image_level_repr” (… d_representation): The output image-level representations.

Returns:

value_out (TensorDict) – A tensor dict with keys:

  • “value” (…): The estimated value for each batch item.
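The input/output contract of forward can be sketched with plain dicts standing in for TensorDict and a single linear map standing in for the value MLP. Both substitutions, and the use of a single batch dimension in place of the general (…) batch shape, are assumptions made for illustration; value_head_forward_sketch is a hypothetical name.

```python
def value_head_forward_sketch(body_output: dict, weight: list[float],
                              bias: float = 0.0) -> dict:
    """Map each image-level representation (d_representation,) to a scalar value."""
    reprs = body_output["image_level_repr"]  # shape (batch, d_representation)
    # One scalar per batch item: the trailing representation dimension is consumed
    values = [sum(w * x for w, x in zip(weight, r)) + bias for r in reprs]
    return {"value": values}  # shape (batch,)


out = value_head_forward_sketch(
    {"image_level_repr": [[1.0, 2.0], [0.0, -1.0]]},
    weight=[0.5, 0.25],
    bias=1.0,
)
```

The key point is the shape change: “image_level_repr” has a trailing d_representation dimension, while “value” keeps only the batch dimensions.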

get_state_dict() → dict

Get the state of the agent part as a dict.

This method should be implemented by subclasses capable of saving their state.

Returns:

state_dict (dict) – The state of the agent part.

set_state(checkpoint: AgentState)

Set the state of the agent from a checkpoint.

This method should be overridden by subclasses to restore the state of the agent from a checkpoint.

Parameters:

checkpoint (AgentCheckpoint) – The checkpoint to restore the state from.
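A minimal sketch of the save/restore pattern that subclasses are expected to implement with get_state_dict and set_state. ToyAgentPart is a hypothetical class, and a plain dict stands in for the AgentState checkpoint type, whose structure is not described on this page.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class ToyAgentPart:
    """Hypothetical agent part whose only state is a list of parameters."""

    params: list[float] = field(default_factory=lambda: [0.0, 0.0])

    def get_state_dict(self) -> dict:
        # Return a deep copy so later updates to the part do not alter the saved state
        return {"params": copy.deepcopy(self.params)}

    def set_state(self, checkpoint: dict) -> None:
        # Restore from a previously saved state (a dict stands in for AgentState)
        self.params = copy.deepcopy(checkpoint["params"])


part = ToyAgentPart([1.0, 2.0])
saved = part.get_state_dict()
part.params[0] = 99.0  # simulate further training
part.set_state(saved)
```

Copying on both save and restore keeps the checkpoint independent of the live object, so restoring the same checkpoint twice gives the same result.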

to(device: device | str | int | None = None)

Move the agent head to the given device.

Parameters:

device (TorchDevice, optional) – The device to move the agent head to. If not given, the CPU is used.