nip.image_classification.agents.ImageClassificationAgentBody#

class nip.image_classification.agents.ImageClassificationAgentBody(*args, **kwargs)[source]#

The body of an image classification agent.

Takes as input the image, the message history, and the most recent message, and outputs the image-level and latent-pixel-level representations.

Shapes

Input:

  • “x” (… round channel position latent_height latent_width): The message history

  • “image” (… image_channel height width): The image

  • “message” (… channel position latent_height latent_width), optional: The most recent message

  • “ignore_message” (…), optional: Whether to ignore the message

  • (“pretrained_embeddings”, model_name) (… embedding_width embedding_height), optional: The embeddings of a pretrained model, if using.

  • “linear_message_history” (… round channel position linear_message), optional: The linear message history, if using

Output:

  • “image_level_repr” (… d_representation): The output image-level representations.

  • “latent_pixel_level_repr” (… latent_height latent_width d_representation): The output latent-pixel-level representations.

Parameters:
  • hyper_params (HyperParameters) – The parameters of the experiment.

  • settings (ExperimentSettings) – The settings of the experiment.

  • agent_name (str) – The name of the agent.

  • protocol_handler (ProtocolHandler) – The protocol handler for the experiment.

  • device (TorchDevice, optional) – The device to use for this agent part. If not given, the CPU is used.

Notes

In all dimension annotations, “channel” refers to the message channel dimension, which is how different groups of agents communicate with each other. This overlaps terminologically with the channel dimension of images and convolutional layers, so those channels are called “image_channel” or “latent_channel” to avoid confusion.

Methods Summary

__init__(hyper_params, settings, agent_name, ...)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_init_weights()

Initialise the module weights.

_run_recorder_hook(hooks, hook_name, output)

build_cnn_encoder()

Build the sequence of groups of building blocks.

build_final_encoder()

Build the final encoder.

build_global_pooling()

Build the global pooling layer.

build_initial_encoder()

Build the initial encoding layer.

build_message_history_upsampler()

Build the module which upsamples the history and message.

build_pretrained_embedding_scaler()

Build the module which scales the pretrained embeddings.

forward(data)

Run the image classification body.

get_state_dict()

Get the state of the agent part as a dict.

set_state(checkpoint)

Set the state of the agent from a checkpoint.

to([device])

Move the agent body to a new device.

Attributes

T_destination

agent_id

The ID of the agent.

agent_level_in_keys

agent_level_out_keys

call_super_init

dump_patches

env_level_in_keys

The environment-level input keys for the agent body.

env_level_out_keys

in_keys

The keys required by the module.

include_pretrained_embeddings

Whether to include pretrained embeddings.

is_prover

Whether the agent is a prover.

is_verifier

Whether the agent is a verifier.

max_message_rounds

The maximum number of message rounds in the protocol.

num_visible_message_channels

The number of message channels visible to the agent.

out_keys

The keys produced by the module.

out_keys_source

pretrained_model_class

The pretrained model class to use, if any.

pretrained_model_name

The full name of the pretrained model to use, if any.

required_pretrained_models

The pretrained models required by the agent body.

visible_message_channel_indices

The indices of the message channels visible to the agent.

visible_message_channel_mask

The mask for the message channels visible to the agent.

visible_message_channel_names

The names of the message channels visible to the agent.

agent_params

training

Methods

__init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)[source]#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_init_weights()[source]#

Initialise the module weights.

Should be called at the end of __init__.

_run_recorder_hook(hooks: AgentHooks | None, hook_name: str, output: Tensor | None)[source]#
build_cnn_encoder() → TensorDictSequential[source]#

Build the sequence of groups of building blocks.

Shapes

Input:

  • “latent_pixel_level_repr” : (… initial_channels height width)

Output:

  • “latent_pixel_level_repr” : (… latent_channels latent_height latent_width)

where latent_channels = initial_channels * 2**num_block_groups

Returns:

cnn_encoder (TensorDictSequential) – The sequence of groups of building blocks.
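As a rough illustration of the channel bookkeeping above only (the actual building blocks come from the `nip` hyper-parameters and are wrapped in a TensorDictSequential, so the layers below are placeholders), each block group doubles the channel count while halving the spatial resolution:

    import torch
    from torch import nn

    # Hypothetical illustration: after num_block_groups groups we get
    # latent_channels = initial_channels * 2 ** num_block_groups.
    initial_channels, num_block_groups = 16, 3

    groups = []
    channels = initial_channels
    for _ in range(num_block_groups):
        groups.append(
            nn.Sequential(
                nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1),
                nn.ReLU(),
            )
        )
        channels *= 2
    cnn = nn.Sequential(*groups)

    x = torch.randn(1, initial_channels, 32, 32)
    print(cnn(x).shape)  # torch.Size([1, 128, 4, 4]), since 16 * 2**3 == 128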

build_final_encoder() → TensorDictSequential[source]#

Build the final encoder.

This rearranges the latent pixel-level representations to put the channel dimension last, then applies a linear layer to obtain the final latent pixel-level representations. It also concatenates the image-level representations with the linear message history, if one is used, then applies a linear layer to obtain the final image-level representations.

Shapes

Input:

  • “image_level_repr” : (… latent_channels+num_message_channels*message_size)

  • “latent_pixel_level_repr” : (… latent_channels+num_message_channels*message_size latent_height latent_width)

  • “linear_message_history” : (… round channel position linear_message), optional

Output:

  • “image_level_repr” : (… d_representation)

  • “latent_pixel_level_repr” : (… latent_height latent_width d_representation)

Returns:

final_encoder (TensorDictSequential) – The final encoder.
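As an illustrative sketch of the latent-pixel branch only, using tensordict and einops (the real module also handles the image-level branch and the optional linear message history; the sizes below are placeholders):

    import torch
    from torch import nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule, TensorDictSequential
    from einops.layers.torch import Rearrange

    latent_channels, d_representation = 64, 16

    # Move the channel dimension last, then project each latent pixel.
    final_encoder = TensorDictSequential(
        TensorDictModule(
            Rearrange("... c h w -> ... h w c"),
            in_keys=["latent_pixel_level_repr"],
            out_keys=["latent_pixel_level_repr"],
        ),
        TensorDictModule(
            nn.Linear(latent_channels, d_representation),
            in_keys=["latent_pixel_level_repr"],
            out_keys=["latent_pixel_level_repr"],
        ),
    )

    data = TensorDict(
        {"latent_pixel_level_repr": torch.randn(2, latent_channels, 4, 4)},
        batch_size=[2],
    )
    print(final_encoder(data)["latent_pixel_level_repr"].shape)  # (2, 4, 4, 16)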

build_global_pooling() → TensorDictModule[source]#

Build the global pooling layer.

Shapes

Input:

  • “latent_pixel_level_repr” : (… latent_channels+1 latent_height latent_width)

Output:

  • “image_level_repr” : (… latent_channels+1)

Returns:

global_pooling (TensorDictModule) – The global pooling layer.
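For example, a global average pool over the two spatial dimensions could be wrapped as a TensorDictModule roughly as follows (a sketch only; the pooling actually used may differ, and 65 stands in for latent_channels+1 from the shapes above):

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from einops.layers.torch import Reduce

    # Average over the spatial dimensions, keeping the channel dimension.
    global_pooling = TensorDictModule(
        Reduce("... c h w -> ... c", "mean"),
        in_keys=["latent_pixel_level_repr"],
        out_keys=["image_level_repr"],
    )

    data = TensorDict(
        {"latent_pixel_level_repr": torch.randn(2, 65, 4, 4)}, batch_size=[2]
    )
    print(global_pooling(data)["image_level_repr"].shape)  # (2, 65)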

build_initial_encoder() → TensorDictSequential[source]#

Build the initial encoding layer.

Concatenates the upsampled message history with the image and, if used, the pretrained embeddings, then applies a two-layer MLP to obtain the initial pixel-level representations.

Shapes

Input:

  • “x_upsampled” : (… round channel position height width)

  • “image” : (… image_channel height width)

  • “pretrained_embeddings_scaled” : (… embedding_channels height width), optional

Output:

  • “latent_pixel_level_repr” : (… initial_num_channels height width)

Returns:

TensorDictSequential – The initial encoding layer.
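The sketch below illustrates the tensor bookkeeping under simplified assumptions: a single fixed batch dimension, placeholder sizes, and the per-pixel MLP written as 1×1 convolutions. The real module operates on TensorDicts with arbitrary batch dimensions.

    import torch
    from torch import nn

    # Hypothetical sizes; the real values come from the hyper-parameters.
    rounds, channels, position = 8, 1, 1
    image_channels, height, width = 3, 32, 32
    initial_num_channels = 16

    x_upsampled = torch.randn(2, rounds, channels, position, height, width)
    image = torch.randn(2, image_channels, height, width)

    # Flatten (round, channel, position) into a single channel axis and
    # concatenate with the image along the channel dimension.
    history_flat = x_upsampled.flatten(1, 3)
    combined = torch.cat([history_flat, image], dim=1)

    # A per-pixel two-layer MLP, written as two 1x1 convolutions.
    mlp = nn.Sequential(
        nn.Conv2d(combined.shape[1], initial_num_channels, kernel_size=1),
        nn.ReLU(),
        nn.Conv2d(initial_num_channels, initial_num_channels, kernel_size=1),
    )
    print(mlp(combined).shape)  # (2, 16, 32, 32)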

build_message_history_upsampler() → TensorDictModule[source]#

Build the module which upsamples the history and message.

The message history is upsampled to the size of the image.

Shapes

Input:

  • “x” : (… round channel position latent_height latent_width)

Output:

  • “x_upsampled” : (… round channel position height width)

Returns:

message_history_upsampler (TensorDictModule) – The module which upsamples the message history to the size of the image.
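A minimal sketch of such an upsampling step, assuming nearest-neighbour interpolation and folding the extra leading dimensions into the batch (both are assumptions, not the documented implementation):

    import torch
    import torch.nn.functional as F

    def upsample_history(x: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
        """Upsample (... round channel position latent_h latent_w) to the image size."""
        lead_shape = x.shape[:-2]
        # F.interpolate expects (N, C, H, W), so fold all leading dims into the batch.
        flat = x.reshape(-1, 1, *x.shape[-2:])
        flat = F.interpolate(flat, size=image_size, mode="nearest")
        return flat.reshape(*lead_shape, *image_size)

    x = torch.randn(2, 8, 1, 1, 4, 4)
    print(upsample_history(x, (32, 32)).shape)  # (2, 8, 1, 1, 32, 32)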

build_pretrained_embedding_scaler() → TensorDictModule[source]#

Build the module which scales the pretrained embeddings.

The pretrained embeddings are scaled to the image size. This is done either by upsampling or by mean pooling, depending on whether the image is larger or smaller than the embeddings.

Shapes

Input:

  • “pretrained_embeddings” : (… embedding_channel embedding_height embedding_width) : The embeddings of the pretrained model

Output:

  • “pretrained_embeddings_scaled” : (… embedding_channel height width) : The scaled embeddings

Returns:

pretrained_embedding_scaler (TensorDictModule) – The module which scales the pretrained embeddings.
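A sketch of the size-dependent choice described above, using F.interpolate to upsample and F.adaptive_avg_pool2d to mean-pool (the interpolation mode and exact pooling are assumptions; the placeholder sizes below mimic a ViT-style 14×14 embedding grid):

    import torch
    import torch.nn.functional as F

    def scale_embeddings(emb: torch.Tensor, image_size: tuple[int, int]) -> torch.Tensor:
        """Scale (... embedding_channel eh ew) embeddings to (... embedding_channel height width)."""
        lead_shape, (c, eh, ew) = emb.shape[:-3], emb.shape[-3:]
        flat = emb.reshape(-1, c, eh, ew)
        if image_size[0] >= eh:
            # Image is at least as large as the embedding grid: upsample.
            flat = F.interpolate(flat, size=image_size, mode="nearest")
        else:
            # Image is smaller: mean-pool the embeddings down.
            flat = F.adaptive_avg_pool2d(flat, image_size)
        return flat.reshape(*lead_shape, c, *image_size)

    emb = torch.randn(2, 768, 14, 14)
    print(scale_embeddings(emb, (32, 32)).shape)  # (2, 768, 32, 32)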

forward(data: TensorDictBase) → TensorDict[source]#

Run the image classification body.

Parameters:

data (TensorDictBase) –

The data to run the body on. A TensorDictBase with keys:

  • “x” (… round channel position latent_height latent_width): The message history

  • “image” (… image_channel height width): The image

  • “message” (… channel position latent_height latent_width), optional: The most recent message

  • “ignore_message” (…), optional: Whether to ignore the message. For example, in the first round there is no message, and the “message” field is set to a dummy value.

  • (“pretrained_embeddings”, model_name) (… embedding_width embedding_height), optional: The embeddings of a pretrained model, if using.

  • “linear_message_history” (… round channel position linear_message), optional: The linear message history, if using.

Returns:

out (TensorDict) – A tensor dict with keys:

  • “image_level_repr” (… d_representation): The output image-level representations.

  • “latent_pixel_level_repr” (… latent_height latent_width d_representation): The output latent-pixel-level representations.
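
A minimal usage sketch, assuming `body` is an already-constructed ImageClassificationAgentBody and using placeholder sizes (a real call would normally go through the experiment's data pipeline):

    import torch
    from tensordict import TensorDict

    batch, rounds, channels, position = 4, 8, 1, 1
    latent_height, latent_width = 4, 4

    data = TensorDict(
        {
            "x": torch.randn(batch, rounds, channels, position, latent_height, latent_width),
            "image": torch.randn(batch, 3, 32, 32),
            "message": torch.randn(batch, channels, position, latent_height, latent_width),
            "ignore_message": torch.ones(batch, dtype=torch.bool),
        },
        batch_size=[batch],
    )

    out = body(data)  # `body` is an assumed, already-constructed ImageClassificationAgentBody
    print(out["image_level_repr"].shape)         # (batch, d_representation)
    print(out["latent_pixel_level_repr"].shape)  # (batch, latent_height, latent_width, d_representation)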

get_state_dict() → dict[source]#

Get the state of the agent part as a dict.

This method should be implemented by subclasses capable of saving their state.

Returns:

state_dict (dict) – The state of the agent part.

set_state(checkpoint: AgentState)[source]#

Set the state of the agent from a checkpoint.

This method should be overridden by subclasses to restore the state of the agent from a checkpoint.

Parameters:

checkpoint (AgentState) – The checkpoint to restore the state from.

to(device: device | str | int | None = None) → ImageClassificationAgentBody[source]#

Move the agent body to a new device.

Parameters:

device (TorchDevice, optional) – The device to move the agent body to. If not given, the CPU is used.

Returns:

self (ImageClassificationAgentBody) – The agent body on the new device.