nip.image_classification.agents.ImageClassificationCombinedBody
- class nip.image_classification.agents.ImageClassificationCombinedBody(*args, **kwargs)
A module which combines the agent bodies for the image classification task.
Shapes
Input:
“round” (…): The round number.
“x” (… round channel position latent_height latent_width): The message history.
“image” (… image_channel height width): The image.
“message” (… channel position latent_height latent_width), optional: The most recent message.
“linear_message_history” (… round channel position linear_message), optional: The linear message history, if used.
Output:
(“agents”, “latent_pixel_level_repr”) (… agents latent_height latent_width d_representation): The output latent-pixel-level representations.
(“agents”, “image_level_repr”) (… agents d_representation): The output image-level representations.
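In the output shapes, a new `agents` dimension sits just before the per-agent dimensions. The following is a minimal NumPy sketch of that combining step. It assumes that each agent body returns arrays without an `agents` dimension and that the image-level representation is a mean over the latent grid. The helper `combine_agent_outputs`, the agent names and the sizes are hypothetical; this is not the library's implementation.

```python
import numpy as np


def combine_agent_outputs(per_agent: dict[str, np.ndarray], n_trailing_dims: int) -> np.ndarray:
    """Stack per-agent outputs along a new ``agents`` axis (hypothetical helper).

    The new axis is inserted just before the last ``n_trailing_dims`` axes, so any
    leading batch dimensions (the ``...`` in the shape annotations) are preserved.
    """
    arrays = list(per_agent.values())  # dict insertion order fixes the agent order
    return np.stack(arrays, axis=-(n_trailing_dims + 1))


batch, latent_h, latent_w, d_repr = 4, 7, 7, 16

# Hypothetical per-agent outputs: (... latent_height latent_width d_representation)
pixel_level = {
    "prover": np.zeros((batch, latent_h, latent_w, d_repr)),
    "verifier": np.ones((batch, latent_h, latent_w, d_repr)),
}
# Assumed image-level pooling: mean over latent_height and latent_width
image_level = {name: arr.mean(axis=(-3, -2)) for name, arr in pixel_level.items()}

# (... agents latent_height latent_width d_representation)
latent_pixel_level_repr = combine_agent_outputs(pixel_level, n_trailing_dims=3)
# (... agents d_representation)
image_level_repr = combine_agent_outputs(image_level, n_trailing_dims=1)

print(latent_pixel_level_repr.shape)  # (4, 2, 7, 7, 16)
print(image_level_repr.shape)  # (4, 2, 16)
```

Counting the insertion axis from the end keeps the sketch agnostic to how many leading batch dimensions the `...` covers.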
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
settings (ExperimentSettings) – The settings of the experiment.
protocol_handler (ProtocolHandler) – The protocol handler for the experiment.
bodies (dict[str, ImageClassificationAgentBody]) – The agent bodies to combine.
Notes
In all dimension annotations, “channel” refers to the message channel dimension, which determines how different groups of agents can communicate with each other. This overlaps with the term “channel” as used for images and convolutional layers, so those channels are instead called “image_channel” or “latent_channel” to avoid confusion.
Methods Summary
__init__(hyper_params, settings, ...): Initialize internal Module state, shared by both nn.Module and ScriptModule.
_restrict_input_to_visible_channels(agent_name, ...): Restrict an agent's input to its visible message channels.
forward(data): Run the agent bodies and combine their outputs.
Attributes
T_destination
additional_in_keys
additional_out_keys
call_super_init
device: The device used by the agent part.
dump_patches
excluded_in_keys
excluded_out_keys
in_keys: The keys required by the module.
out_keys: The keys produced by the module.
out_keys_source
training
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, bodies: dict[str, AgentBody])
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- _restrict_input_to_visible_channels(agent_name: str, input_array: Tensor | ndarray[Any, dtype[_ScalarType_co]], shape_spec: str) → Tensor
Restrict an agent’s input to its visible message channels.
Agents only receive messages from the channels they can see. This method restricts the agent's input to those visible message channels.
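- Parameters:
agent_name (str) – The name of the agent whose input is restricted.
input_array (Tensor | NDArray) – The input to restrict.
shape_spec (str) – The shape specification of input_array.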
- Returns:
restricted_input (Tensor | NDArray) – The input restricted to the visible message channels.
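Restricting to visible channels amounts to indexing the `channel` axis named in the shape specification. The following is a minimal NumPy sketch. It assumes that `shape_spec` is a space-separated dimension string like those in the Shapes section (optionally with a leading `...`) and that visibility is given as a hypothetical mapping from agent name to channel indices. The real method works out visibility from the experiment's protocol handler, whose API is not shown here. The function name `restrict_to_visible_channels` and the `VISIBLE` mapping are illustrative only.

```python
import numpy as np


def restrict_to_visible_channels(
    input_array: np.ndarray,
    shape_spec: str,
    visible_channels: list[int],
) -> np.ndarray:
    """Keep only the given indices along the ``channel`` axis (hypothetical helper)."""
    dims = shape_spec.split()
    # Count the channel axis from the end, so leading batch dimensions
    # covered by "..." do not shift its position.
    channel_axis = dims.index("channel") - len(dims)
    return np.take(input_array, visible_channels, axis=channel_axis)


# Hypothetical visibility: agent name -> indices of the message channels it can see.
VISIBLE = {"prover": [0], "verifier": [0, 1]}

spec = "... round channel position latent_height latent_width"
# batch=3, round=2, channel=2, position=1, latent_height=4, latent_width=4
x = np.arange(3 * 2 * 2 * 1 * 4 * 4).reshape(3, 2, 2, 1, 4, 4)

restricted = restrict_to_visible_channels(x, spec, VISIBLE["prover"])
print(restricted.shape)  # (3, 2, 1, 1, 4, 4)
```

Passing a list of indices to `np.take` keeps the `channel` axis even when only one channel is visible, so the restricted array still matches `shape_spec`.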
- forward(data: TensorDictBase) → TensorDict
Run the agent bodies and combine their outputs.
- Parameters:
data (TensorDictBase) – The data to run the bodies on.
- Returns:
data (TensorDict) – The data updated in place with the output of the agent bodies.
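The overall flow of `forward` can be sketched with a plain dict in place of the TensorDict. The flow is to restrict each agent's view of the message history, run its body, add an `agents` dimension, and write the results back into `data`. Everything here is an illustrative assumption. The `toy_body` function, the `VISIBLE` mapping and the mean-pooling used for the image-level output are hypothetical, and real bodies also consume the image. The dotted string keys stand in for TensorDict's nested keys such as `("agents", "image_level_repr")`.

```python
import numpy as np

D_REPR = 8
CHANNEL_AXIS = -4  # "channel" in (... round channel position latent_height latent_width)


def toy_body(x: np.ndarray) -> np.ndarray:
    """Hypothetical agent body: (... round channel position h w) -> (... h w d_representation)."""
    pooled = x.sum(axis=(-5, -4, -3))  # sum over round, channel and position
    return np.repeat(pooled[..., None], D_REPR, axis=-1)


BODIES = {"prover": toy_body, "verifier": toy_body}
VISIBLE = {"prover": [0], "verifier": [0, 1]}  # agent -> visible channel indices


def forward(data: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    pixel_level, image_level = [], []
    for name, body in BODIES.items():
        # Each agent only sees its visible message channels
        x_visible = np.take(data["x"], VISIBLE[name], axis=CHANNEL_AXIS)
        latent = body(x_visible)
        pixel_level.append(latent)
        image_level.append(latent.mean(axis=(-3, -2)))  # pool over latent_height, latent_width
    # Add the "agents" dimension and update `data` in place
    data["agents.latent_pixel_level_repr"] = np.stack(pixel_level, axis=-4)
    data["agents.image_level_repr"] = np.stack(image_level, axis=-2)
    return data


data = {"x": np.ones((3, 2, 2, 1, 4, 4))}
out = forward(data)
print(out is data)  # True
print(out["agents.latent_pixel_level_repr"].shape)  # (3, 2, 4, 4, 8)
print(out["agents.image_level_repr"].shape)  # (3, 2, 8)
```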