nip.image_classification.agents.ImageClassificationAgentBody#
- class nip.image_classification.agents.ImageClassificationAgentBody(*args, **kwargs)[source]#
The body of an image classification agent.
Takes as input the image, message history and the most recent message and outputs the image-level and latent pixel-level representations.
Shapes
Input:
“x” (… round channel position latent_height latent_width): The message history
“image” (… image_channel height width): The image
“message” (… channel position latent_height latent_width), optional: The most recent message
“ignore_message” (…), optional: Whether to ignore the message
(“pretrained_embeddings”, model_name) (… embedding_width embedding_height), optional: The embeddings of a pretrained model, if using.
“linear_message_history” : (… round channel position linear_message), optional: The linear message history, if using
Output:
“image_level_repr” (… d_representation): The output image-level representations.
“latent_pixel_level_repr” (… latent_height latent_width d_representation): The output latent-pixel-level representations.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
settings (ExperimentSettings) – The settings of the experiment.
agent_name (str) – The name of the agent.
protocol_handler (ProtocolHandler) – The protocol handler for the experiment.
device (TorchDevice, optional) – The device to use for this agent part. If not given, the CPU is used.
Notes
In all dimension annotations, “channel” refers to the message channel dimension, which is how different groups of agents can communicate with each other. There is a terminology overlap with the channel dimension in images and convolutional layers. Such channels are called “image_channel” or “latent_channel” to avoid confusion.
Methods Summary
__init__(hyper_params, settings, agent_name, ...) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
Initialise the module weights.
_run_recorder_hook(hooks, hook_name, output)
build_cnn_encoder() – Build the sequence of groups of building blocks.
build_final_encoder() – Build the final encoder.
build_global_pooling() – Build the global pooling layer.
build_initial_encoder() – Build the initial encoding layer.
build_message_history_upsampler() – Build the module which upsamples the history and message.
build_pretrained_embedding_scaler() – Build the module which scales the pretrained embeddings.
forward(data) – Run the image classification body.
get_state_dict() – Get the state of the agent part as a dict.
set_state(checkpoint) – Set the state of the agent from a checkpoint.
to([device]) – Move the agent body to a new device.
Attributes
T_destination
agent_id – The ID of the agent.
agent_level_in_keys
agent_level_out_keys
call_super_init
dump_patches
env_level_in_keys – The environment-level input keys for the agent body.
env_level_out_keys
in_keys – The keys required by the module.
include_pretrained_embeddings – Whether to include pretrained embeddings.
is_prover – Whether the agent is a prover.
is_verifier – Whether the agent is a verifier.
max_message_rounds – The maximum number of message rounds in the protocol.
num_visible_message_channels – The number of message channels visible to the agent.
out_keys – The keys produced by the module.
out_keys_source
pretrained_model_class – The pretrained model class to use, if any.
pretrained_model_name – The full name of the pretrained model to use, if any.
required_pretrained_models – The pretrained models required by the agent body.
visible_message_channel_indices – The indices of the message channels visible to the agent.
visible_message_channel_mask – The mask for the message channels visible to the agent.
visible_message_channel_names – The names of the message channels visible to the agent.
agent_params
training
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- build_cnn_encoder() TensorDictSequential[source]#
Build the sequence of groups of building blocks.
Shapes
Input:
“latent_pixel_level_repr” : (… initial_channels height width)
Output:
“latent_pixel_level_repr” : (… latent_channels latent_height latent_width)
where
latent_channels = initial_channels * 2**num_block_groups
- Returns:
cnn_encoder (TensorDictSequential) – The sequence of groups of building blocks.
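As a rough illustration of the shape relationship above, an encoder of this kind can be assembled as a TensorDictSequential of block groups, each of which doubles the channel count. This is a minimal sketch, not the block structure used by the class; the function name and layer choices are assumptions.

```python
import torch.nn as nn
from tensordict.nn import TensorDictModule, TensorDictSequential


def build_toy_cnn_encoder(initial_channels: int = 16, num_block_groups: int = 3) -> TensorDictSequential:
    """A minimal stand-in for the CNN encoder: each block group doubles the
    channel count (and here halves the spatial size), so the output has
    initial_channels * 2**num_block_groups channels.

    Assumes a single leading batch dimension, since Conv2d expects (N, C, H, W).
    """
    groups = []
    channels = initial_channels
    for _ in range(num_block_groups):
        block = nn.Sequential(
            nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        groups.append(
            TensorDictModule(
                block,
                in_keys=["latent_pixel_level_repr"],
                out_keys=["latent_pixel_level_repr"],
            )
        )
        channels *= 2
    return TensorDictSequential(*groups)
```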
- build_final_encoder() TensorDictSequential[source]#
Build the final encoder.
This rearranges the latent pixel-level representations to put the channel dimension last, then applies a linear layer to obtain the final latent pixel-level representations. It also concatenates the image-level representations with the linear message history if using, then applies a linear layer to obtain the final image-level representations.
Shapes
Input:
“image_level_repr” : (… latent_channels+num_message_channels*message_size)
“latent_pixel_level_repr” : (… latent_channels+num_message_channels*message_size latent_height latent_width)
“linear_message_history” : (… round channel position linear_message), optional
Output:
“image_level_repr” : (… d_representation)
“latent_pixel_level_repr” : (… latent_height latent_width d_representation)
- Returns:
final_encoder (TensorDictSequential) – The final encoder.
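To picture the two projections described above, here is a minimal sketch that ignores the optional linear message history; the layer sizes and the use of einops for the channel-last rearrangement are assumptions, not the actual implementation.

```python
import torch.nn as nn
from einops.layers.torch import Rearrange
from tensordict.nn import TensorDictModule, TensorDictSequential

latent_channels, d_representation = 128, 64  # illustrative sizes

# Put the channel dimension last and project each latent pixel to d_representation.
pixel_head = TensorDictModule(
    nn.Sequential(
        Rearrange("... c h w -> ... h w c"),
        nn.Linear(latent_channels, d_representation),
    ),
    in_keys=["latent_pixel_level_repr"],
    out_keys=["latent_pixel_level_repr"],
)

# Project the pooled image-level features to d_representation.
image_head = TensorDictModule(
    nn.Linear(latent_channels, d_representation),
    in_keys=["image_level_repr"],
    out_keys=["image_level_repr"],
)

toy_final_encoder = TensorDictSequential(pixel_head, image_head)
```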
- build_global_pooling() TensorDictModule[source]#
Build the global pooling layer.
Shapes
Input:
“latent_pixel_level_repr” : (… latent_channels+1 latent_height latent_width)
Output:
“image_level_repr” : (… latent_channels+1)
- Returns:
global_pooling (TensorDictModule) – The global pooling layer.
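A minimal sketch of such a pooling layer, assuming global average pooling over the two latent spatial dimensions (the actual pooling operation may differ):

```python
from tensordict.nn import TensorDictModule

# Average over the latent_height and latent_width dimensions, keeping channels.
toy_global_pooling = TensorDictModule(
    lambda latent: latent.mean(dim=(-2, -1)),
    in_keys=["latent_pixel_level_repr"],
    out_keys=["image_level_repr"],
)
```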
- build_initial_encoder() TensorDictSequential[source]#
Build the initial encoding layer.
Concatenates the upsampled message history with the image and pretrained embeddings if using, then applies a two-layer MLP to obtain the initial pixel-level representations.
Shapes
Input:
“x_upsampled” : (… round channel position height width)
“image” : (… image_channel height width)
“pretrained_embeddings_scaled” : (… embedding_channels height width), optional
Output:
“latent_pixel_level_repr” : (… initial_num_channels height width)
- Returns:
TensorDictSequential – The initial encoding layer.
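The sketch below illustrates the concatenate-then-MLP structure described above, ignoring the optional pretrained embeddings. The per-pixel MLP is written with 1x1 convolutions, and all sizes and the flattening of the history dimensions are illustrative assumptions.

```python
import torch
import torch.nn as nn
from tensordict.nn import TensorDictModule, TensorDictSequential

# in_channels must equal round * channel * position + image_channel after
# flattening the history; e.g. 4 * 1 * 1 + 3 = 7.
in_channels, initial_num_channels = 7, 32

# Flatten the (round, channel, position) dims of the upsampled history and
# concatenate with the image along the channel dimension.
concat = TensorDictModule(
    lambda x_up, image: torch.cat([x_up.flatten(-5, -3), image], dim=-3),
    in_keys=["x_upsampled", "image"],
    out_keys=["latent_pixel_level_repr"],
)

# A per-pixel two-layer MLP implemented with 1x1 convolutions
# (assumes a single leading batch dimension).
mlp = TensorDictModule(
    nn.Sequential(
        nn.Conv2d(in_channels, initial_num_channels, kernel_size=1),
        nn.ReLU(),
        nn.Conv2d(initial_num_channels, initial_num_channels, kernel_size=1),
    ),
    in_keys=["latent_pixel_level_repr"],
    out_keys=["latent_pixel_level_repr"],
)

toy_initial_encoder = TensorDictSequential(concat, mlp)
```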
- build_message_history_upsampler() TensorDictModule[source]#
Build the module which upsamples the history and message.
The message history is upsampled to the size of the image.
Shapes
Input:
“x” : (… round channel position latent_height latent_width)
Output:
“x_upsampled” : (… round channel position height width)
- Returns:
message_history_upsampler (TensorDictModule) – The module which upsamples the message history to the size of the image.
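A minimal sketch of the upsampling, assuming nearest-neighbour interpolation to an illustrative image size (the interpolation mode used by the class is not specified here):

```python
import torch
import torch.nn.functional as F
from tensordict.nn import TensorDictModule

height, width = 32, 32  # illustrative target image size


def upsample_history(x: torch.Tensor) -> torch.Tensor:
    """Upsample the trailing two (latent) dimensions of the message history to
    the image size, leaving all leading dimensions untouched."""
    lead_shape = x.shape[:-2]
    x_flat = x.reshape(-1, 1, *x.shape[-2:])  # fold leading dims into a batch
    x_up = F.interpolate(x_flat, size=(height, width), mode="nearest")
    return x_up.reshape(*lead_shape, height, width)


toy_history_upsampler = TensorDictModule(
    upsample_history, in_keys=["x"], out_keys=["x_upsampled"]
)
```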
- build_pretrained_embedding_scaler() TensorDictModule[source]#
Build the module which scales the pretrained embeddings.
The pretrained embeddings are scaled to the image size. This can be done by upsampling or mean pooling, depending on whether the image size is larger or smaller than the embedding size.
Shapes
Input:
“pretrained_embeddings” : (… embedding_channel embedding_height embedding_width) : The embeddings of the pretrained model
Output:
“pretrained_embeddings_scaled” : (… embedding_channel height width) : The scaled embeddings
- Returns:
pretrained_embedding_scaler (TensorDictModule) – The module which scales the pretrained embeddings.
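A minimal sketch of the scaling logic described above: upsample when the embedding grid is smaller than the image, mean-pool when it is larger. The target size and the key name "model_name" are illustrative assumptions.

```python
import torch
import torch.nn.functional as F
from tensordict.nn import TensorDictModule

height, width = 32, 32  # illustrative target image size


def scale_embeddings(embeddings: torch.Tensor) -> torch.Tensor:
    """Scale pretrained embeddings to the image size: upsample if the embedding
    grid is smaller than the image, mean-pool if it is larger."""
    emb_height, emb_width = embeddings.shape[-2:]
    flat = embeddings.reshape(-1, *embeddings.shape[-3:])  # fold leading dims into a batch
    if emb_height < height or emb_width < width:
        scaled = F.interpolate(flat, size=(height, width), mode="nearest")
    else:
        scaled = F.adaptive_avg_pool2d(flat, output_size=(height, width))
    return scaled.reshape(*embeddings.shape[:-2], height, width)


toy_embedding_scaler = TensorDictModule(
    scale_embeddings,
    # The actual key is ("pretrained_embeddings", <model name>); "model_name"
    # here is a placeholder.
    in_keys=[("pretrained_embeddings", "model_name")],
    out_keys=["pretrained_embeddings_scaled"],
)
```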
- forward(data: TensorDictBase) TensorDict[source]#
Run the image classification body.
- Parameters:
data (TensorDictBase) –
The data to run the body on. A TensorDictBase with keys:
”x” (… round channel position latent_height latent_width): The message history
”image” (… image_channel height width): The image
”message” (… channel position latent_height latent_width), optional: The most recent message
”ignore_message” (…), optional: Whether to ignore the message. For example, in the first round there is no message, and the “message” field is set to a dummy value.
(“pretrained_embeddings”, model_name) (… embedding_width embedding_height), optional: The embeddings of a pretrained model, if using.
”linear_message_history” : (… round channel position linear_message), optional: The linear message history, if using.
- Returns:
out (TensorDict) – A tensor dict with keys:
”image_level_repr” (… d_representation): The output image-level representations.
”latent_pixel_level_repr” (… latent_height latent_width d_representation): The output latent-pixel-level representations.
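A minimal invocation sketch, assuming `body` is an already-constructed ImageClassificationAgentBody (built with the parameters listed at the top of this page) and using illustrative batch and latent sizes:

```python
import torch
from tensordict import TensorDict

batch, rounds, channels, positions = 8, 4, 1, 1  # illustrative sizes
latent_h, latent_w, img_h, img_w = 8, 8, 32, 32

data = TensorDict(
    {
        "x": torch.zeros(batch, rounds, channels, positions, latent_h, latent_w),
        "image": torch.rand(batch, 3, img_h, img_w),
        # No message has been received yet, so it is flagged as ignored.
        "ignore_message": torch.ones(batch, dtype=torch.bool),
    },
    batch_size=[batch],
)

out = body(data)  # `body` is an ImageClassificationAgentBody built elsewhere
print(out["image_level_repr"].shape)         # (batch, d_representation)
print(out["latent_pixel_level_repr"].shape)  # (batch, latent_h, latent_w, d_representation)
```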
- get_state_dict() dict[source]#
Get the state of the agent part as a dict.
This method should be implemented by subclasses capable of saving their state.
- Returns:
state_dict (dict) – The state of the agent part.
- set_state(checkpoint: AgentState)[source]#
Set the state of the agent from a checkpoint.
This method should be overridden by subclasses to restore the state of the agent from a checkpoint.
- Parameters:
checkpoint (AgentState) – The checkpoint to restore the state from.
- to(device: device | str | int | None = None) ImageClassificationAgentBody[source]#
Move the agent body to a new device.
- Parameters:
device (TorchDevice, optional) – The device to move the agent body to. If not given, the CPU is used.
- Returns:
self (ImageClassificationAgentBody) – The agent body on the new device.