nip.image_classification.agents.ImageClassificationAgentBody#
- class nip.image_classification.agents.ImageClassificationAgentBody(*args, **kwargs)[source]#
The body of an image classification agent.
Takes as input the image, the message history and the most recent message, and outputs the image-level and latent-pixel-level representations.
Shapes
Input:
“x” (… round channel position latent_height latent_width): The message history
“image” (… image_channel height width): The image
“message” (… channel position latent_height latent_width), optional: The most recent message
“ignore_message” (…), optional: Whether to ignore the message
(“pretrained_embeddings”, model_name) (… embedding_width embedding_height), optional: The embeddings of a pretrained model, if using.
“linear_message_history” (… round channel position linear_message), optional: The linear message history, if using
Output:
“image_level_repr” (… d_representation): The output image-level representations.
“latent_pixel_level_repr” (… latent_height latent_width d_representation): The output latent-pixel-level representations.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
settings (ExperimentSettings) – The settings of the experiment.
agent_name (str) – The name of the agent.
protocol_handler (ProtocolHandler) – The protocol handler for the experiment.
device (TorchDevice, optional) – The device to use for this agent part. If not given, the CPU is used.
Notes
In all dimension annotations, “channel” refers to the message channel dimension, which is how different groups of agents can communicate with each other. There is a terminology overlap with the channel dimension in images and convolutional layers; such channels are called “image_channel” or “latent_channel” to avoid confusion.
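A minimal instantiation sketch follows, assuming that hyper_params, settings and protocol_handler objects have already been built by the surrounding experiment setup (their construction is outside this class), and that "verifier" is a valid agent name for the protocol:

    # Sketch only: `hyper_params`, `settings` and `protocol_handler` are assumed
    # to have been built elsewhere by the experiment setup; their construction is
    # not part of this class. The agent name "verifier" is only an example and
    # must be a name known to the protocol handler.
    from nip.image_classification.agents import ImageClassificationAgentBody

    body = ImageClassificationAgentBody(
        hyper_params=hyper_params,
        settings=settings,
        agent_name="verifier",
        protocol_handler=protocol_handler,
    )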
Methods Summary
__init__(hyper_params, settings, agent_name, ...): Initialize internal Module state, shared by both nn.Module and ScriptModule.
Initialise the module weights.
_run_recorder_hook(hooks, hook_name, output)
build_cnn_encoder(): Build the sequence of groups of building blocks.
build_final_encoder(): Build the final encoder.
build_global_pooling(): Build the global pooling layer.
build_initial_encoder(): Build the initial encoding layer.
build_message_history_upsampler(): Build the module which upsamples the history and message.
build_pretrained_embedding_scaler(): Build the module which scales the pretrained embeddings.
forward(data): Run the image classification body.
get_state_dict(): Get the state of the agent part as a dict.
set_state(checkpoint): Set the state of the agent from a checkpoint.
to([device]): Move the agent body to a new device.
Attributes
T_destination
agent_id: The ID of the agent.
agent_level_in_keys
agent_level_out_keys
call_super_init
dump_patches
env_level_in_keys: The environment-level input keys for the agent body.
env_level_out_keys
in_keys: The keys required by the module.
include_pretrained_embeddings: Whether to include pretrained embeddings.
is_prover: Whether the agent is a prover.
is_verifier: Whether the agent is a verifier.
max_message_rounds: The maximum number of message rounds in the protocol.
num_visible_message_channels: The number of message channels visible to the agent.
out_keys: The keys produced by the module.
out_keys_source
pretrained_model_class: The pretrained model class to use, if any.
pretrained_model_name: The full name of the pretrained model to use, if any.
required_pretrained_models: The pretrained models required by the agent body.
visible_message_channel_indices: The indices of the message channels visible to the agent.
visible_message_channel_mask: The mask for the message channels visible to the agent.
visible_message_channel_names: The names of the message channels visible to the agent.
agent_params
training
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- build_cnn_encoder() TensorDictSequential [source]#
Build the sequence of groups of building blocks.
Shapes
Input:
“latent_pixel_level_repr” : (… initial_channels height width)
Output:
“latent_pixel_level_repr” : (… latent_channels latent_height latent_width)
where
latent_channels = initial_channels * 2**num_block_groups
- Returns:
cnn_encoder (TensorDictSequential) – The sequence of groups of building blocks.
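As a rough, stand-alone illustration of the shapes above (not the actual nip building blocks, whose configuration comes from hyper_params), a sequence of strided convolutional block groups that doubles the channel count and halves the spatial size at each group could look like this:

    # Illustrative stand-in, not the actual nip building blocks: each block group
    # halves the spatial size and doubles the channel count, so the final channel
    # count is initial_channels * 2**num_block_groups.
    import torch
    from torch import nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule, TensorDictSequential

    initial_channels, num_block_groups = 16, 3

    groups = []
    channels = initial_channels
    for _ in range(num_block_groups):
        block = nn.Sequential(
            nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        groups.append(
            TensorDictModule(
                block,
                in_keys=["latent_pixel_level_repr"],
                out_keys=["latent_pixel_level_repr"],
            )
        )
        channels *= 2

    cnn_encoder = TensorDictSequential(*groups)

    data = TensorDict(
        {"latent_pixel_level_repr": torch.randn(4, initial_channels, 32, 32)},
        batch_size=[4],
    )
    print(cnn_encoder(data)["latent_pixel_level_repr"].shape)  # torch.Size([4, 128, 4, 4])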
- build_final_encoder() TensorDictSequential [source]#
Build the final encoder.
This rearranges the latent pixel-level representations to put the channel dimension last, then applies a linear layer to obtain the final latent pixel-level representations. It also concatenates the image-level representations with the linear message history if using, then applies a linear layer to obtain the final image-level representations.
Shapes
Input:
“image_level_repr” : (… latent_channels+num_message_channels*message_size)
“latent_pixel_level_repr” : (… latent_channels+num_message_channels*message_size latent_height latent_width)
“linear_message_history” : (… round channel position linear_message), optional
Output:
“image_level_repr” : (… d_representation)
“latent_pixel_level_repr” : (… latent_height latent_width d_representation)
- Returns:
final_encoder (TensorDictSequential) – The final encoder.
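An illustrative stand-in for these two steps (not the nip implementation; all sizes below are made up) might use a channels-last rearrangement plus linear projection for the latent-pixel-level stream, and a concatenate-then-project step for the image-level stream:

    # Illustrative stand-in for the final encoder described above, not the nip
    # implementation. All sizes are made up.
    import math

    import torch
    from torch import nn
    from tensordict.nn import TensorDictModule, TensorDictSequential

    latent_channels, d_representation = 32, 64
    history_shape = (2, 1, 4, 8)  # (round, channel, position, linear_message)

    class ChannelsLastProjection(nn.Module):
        """Move the channel dim last, then project to d_representation."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(latent_channels, d_representation)

        def forward(self, latent_pixel_repr):
            # (... channels height width) -> (... height width channels)
            return self.linear(latent_pixel_repr.movedim(-3, -1))

    class ImageLevelProjection(nn.Module):
        """Concatenate the flattened linear message history, then project."""

        def __init__(self):
            super().__init__()
            in_features = latent_channels + math.prod(history_shape)
            self.linear = nn.Linear(in_features, d_representation)

        def forward(self, image_repr, linear_message_history):
            flat_history = linear_message_history.flatten(start_dim=-4)
            return self.linear(torch.cat([image_repr, flat_history], dim=-1))

    final_encoder = TensorDictSequential(
        TensorDictModule(
            ChannelsLastProjection(),
            in_keys=["latent_pixel_level_repr"],
            out_keys=["latent_pixel_level_repr"],
        ),
        TensorDictModule(
            ImageLevelProjection(),
            in_keys=["image_level_repr", "linear_message_history"],
            out_keys=["image_level_repr"],
        ),
    )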
- build_global_pooling() TensorDictModule [source]#
Build the global pooling layer.
Shapes
Input:
“latent_pixel_level_repr” : (… latent_channels+1 latent_height latent_width)
Output:
“image_level_repr” : (… latent_channels+1)
- Returns:
global_pooling (TensorDictModule) – The global pooling layer.
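A minimal sketch of such a layer, using global average pooling over the two spatial dimensions (illustrative only; the actual pooling operation is determined by the experiment configuration):

    # Sketch only: global average pooling over the spatial dims, matching the
    # shapes above.
    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule

    def pool_spatial(latent_pixel_repr):
        # (... channels latent_height latent_width) -> (... channels)
        return latent_pixel_repr.mean(dim=(-2, -1))

    global_pooling = TensorDictModule(
        pool_spatial,
        in_keys=["latent_pixel_level_repr"],
        out_keys=["image_level_repr"],
    )

    data = TensorDict({"latent_pixel_level_repr": torch.randn(4, 25, 7, 7)}, batch_size=[4])
    print(global_pooling(data)["image_level_repr"].shape)  # torch.Size([4, 25])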
- build_initial_encoder() TensorDictSequential [source]#
Build the initial encoding layer.
Concatenates the upsampled message history with the image and pretrained embeddings if using, then applies a two-layer MLP to obtain the initial pixel-level representations.
Shapes
Input:
“x_upsampled” : (… round channel position height width)
“image” : (… image_channel height width)
“pretrained_embeddings_scaled” : (… embedding_channels height width), optional
Output:
“latent_pixel_level_repr” : (… initial_num_channels height width)
- Returns:
TensorDictSequential – The initial encoding layer.
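An illustrative sketch of this step (not the nip implementation; sizes are made up and pretrained embeddings are omitted), flattening the upsampled history into a single channel dimension and applying a pixel-wise two-layer MLP via 1x1 convolutions:

    # Illustrative sketch, not the nip implementation: flatten the upsampled
    # message history into a single channel dimension, concatenate it with the
    # image, and apply a pixel-wise two-layer MLP via 1x1 convolutions.
    import torch
    from torch import nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule

    rounds, channels, positions = 4, 1, 2
    image_channels, hidden, initial_channels = 3, 32, 16

    class InitialEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            in_channels = rounds * channels * positions + image_channels
            self.mlp = nn.Sequential(
                nn.Conv2d(in_channels, hidden, kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(hidden, initial_channels, kernel_size=1),
            )

        def forward(self, x_upsampled, image):
            # (... round channel position height width) -> (... round*channel*position height width)
            history = x_upsampled.flatten(start_dim=-5, end_dim=-3)
            return self.mlp(torch.cat([history, image], dim=-3))

    initial_encoder = TensorDictModule(
        InitialEncoder(),
        in_keys=["x_upsampled", "image"],
        out_keys=["latent_pixel_level_repr"],
    )

    data = TensorDict(
        {
            "x_upsampled": torch.randn(4, rounds, channels, positions, 28, 28),
            "image": torch.randn(4, image_channels, 28, 28),
        },
        batch_size=[4],
    )
    print(initial_encoder(data)["latent_pixel_level_repr"].shape)  # torch.Size([4, 16, 28, 28])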
- build_message_history_upsampler() TensorDictModule [source]#
Build the module which upsamples the history and message.
The message history is upsampled to the size of the image.
Shapes
Input:
“x” : (… round channel position latent_height latent_width)
Output:
“x_upsampled” : (… round channel position height width)
- Returns:
message_history_upsampler (TensorDictModule) – The module which upsamples the message history to the size of the image.
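A sketch of how such an upsampler could be written with bilinear interpolation over the last two dimensions while preserving the leading round/channel/position dimensions (illustrative only; the real module's interpolation mode and target size come from the experiment configuration):

    # Sketch only: bilinear upsampling over the last two dims, keeping the
    # leading round/channel/position dims intact.
    import torch
    import torch.nn.functional as F
    from tensordict.nn import TensorDictModule

    image_size = (28, 28)  # made-up target (height, width)

    def upsample_history(x):
        leading_shape = x.shape[:-2]
        flat = x.reshape(-1, 1, *x.shape[-2:])  # (N, 1, latent_height, latent_width)
        flat = F.interpolate(flat, size=image_size, mode="bilinear", align_corners=False)
        return flat.reshape(*leading_shape, *image_size)

    message_history_upsampler = TensorDictModule(
        upsample_history, in_keys=["x"], out_keys=["x_upsampled"]
    )

    x = torch.randn(4, 3, 1, 2, 7, 7)  # (batch round channel position latent_height latent_width)
    print(upsample_history(x).shape)  # torch.Size([4, 3, 1, 2, 28, 28])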
- build_pretrained_embedding_scaler() TensorDictModule [source]#
Build the module which scales the pretrained embeddings.
The pretrained embeddings are scaled to the image size. This is done either by upsampling or by mean pooling, depending on whether the image size is larger or smaller than the embedding size.
Shapes
Input:
“pretrained_embeddings” : (… embedding_channel embedding_height embedding_width) : The embeddings of the pretrained model
Output:
“pretrained_embeddings_scaled” : (… embedding_channel height width) : The scaled embeddings
- Returns:
pretrained_embedding_scaler (TensorDictModule) – The module which scales the pretrained embeddings.
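A sketch of this scaling logic (illustrative only; sizes are made up and the actual interpolation and pooling choices are configured elsewhere):

    # Sketch only: upsample when the image grid is at least as large as the
    # embedding grid, otherwise mean-pool down to the image size.
    import torch
    import torch.nn.functional as F
    from tensordict.nn import TensorDictModule

    image_size = (28, 28)  # made-up (height, width)

    def scale_embeddings(embeddings):
        # embeddings: (... embedding_channel embedding_height embedding_width)
        leading_shape = embeddings.shape[:-3]
        flat = embeddings.reshape(-1, *embeddings.shape[-3:])
        if embeddings.shape[-1] <= image_size[-1]:
            flat = F.interpolate(flat, size=image_size, mode="bilinear", align_corners=False)
        else:
            flat = F.adaptive_avg_pool2d(flat, output_size=image_size)
        return flat.reshape(*leading_shape, embeddings.shape[-3], *image_size)

    pretrained_embedding_scaler = TensorDictModule(
        scale_embeddings,
        in_keys=["pretrained_embeddings"],
        out_keys=["pretrained_embeddings_scaled"],
    )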
- forward(data: TensorDictBase) TensorDict [source]#
Run the image classification body.
- Parameters:
data (TensorDictBase) –
The data to run the body on. A TensorDictBase with keys:
“x” (… round channel position latent_height latent_width): The message history
“image” (… image_channel height width): The image
“message” (… channel position latent_height latent_width), optional: The most recent message
“ignore_message” (…), optional: Whether to ignore the message. For example, in the first round there is no message, and the “message” field is set to a dummy value.
(“pretrained_embeddings”, model_name) (… embedding_width embedding_height), optional: The embeddings of a pretrained model, if using.
“linear_message_history” (… round channel position linear_message), optional: The linear message history, if using.
- Returns:
out (TensorDict) – A tensor dict with keys:
“image_level_repr” (… d_representation): The output image-level representations.
“latent_pixel_level_repr” (… latent_height latent_width d_representation): The output latent-pixel-level representations.
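A sketch of a forward call, assuming body is an ImageClassificationAgentBody built as in the instantiation sketch near the top of this page; all sizes are made up and must in practice match the experiment's hyper-parameters:

    # Sketch only: all sizes are made up; `body` is assumed to be an
    # ImageClassificationAgentBody built as in the instantiation sketch above.
    import torch
    from tensordict import TensorDict

    batch, rounds, channels, positions = 8, 4, 1, 2
    latent_height = latent_width = 7

    data = TensorDict(
        {
            "x": torch.randn(batch, rounds, channels, positions, latent_height, latent_width),
            "image": torch.randn(batch, 3, 28, 28),
            "message": torch.randn(batch, channels, positions, latent_height, latent_width),
            "ignore_message": torch.zeros(batch, dtype=torch.bool),
        },
        batch_size=[batch],
    )

    out = body(data)
    print(out["image_level_repr"].shape)         # (batch, d_representation)
    print(out["latent_pixel_level_repr"].shape)  # (batch, latent_height, latent_width, d_representation)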
- get_state_dict() dict [source]#
Get the state of the agent part as a dict.
This method should be implemented by subclasses capable of saving their state.
- Returns:
state_dict (dict) – The state of the agent part.
- set_state(checkpoint: AgentState)[source]#
Set the state of the agent from a checkpoint.
This method should be overridden by subclasses to restore the state of the agent from a checkpoint.
- Parameters:
checkpoint (AgentState) – The checkpoint to restore the state from.
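A sketch of persisting the state returned by get_state_dict; restoring goes through set_state, which expects a checkpoint object whose construction depends on the rest of the nip package and is not shown here:

    # Sketch only: save the agent body's state dict to disk.
    import torch

    torch.save(body.get_state_dict(), "agent_body_state.pt")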
- to(device: device | str | int | None = None) ImageClassificationAgentBody [source]#
Move the agent body to a new device.
- Parameters:
device (TorchDevice, optional) – The device to move the agent body to. If not given, the CPU is used.
- Returns:
self (ImageClassificationAgentBody) – The agent body on the new device.
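For example (sketch only), moving the body to a GPU when one is available:

    # Sketch only: move the body to a GPU when one is available.
    import torch

    if torch.cuda.is_available():
        body = body.to("cuda")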