nip.image_classification.agents.ImageClassificationCombinedValueHead

class nip.image_classification.agents.ImageClassificationCombinedValueHead(*args, **kwargs)

A module which combines the agent value heads for the image classification task.

Shapes

Input:

  • “round” (…): The round number.

  • (“agents”, “latent_pixel_level_repr”) (… agents latent_height latent_width d_representation): The output latent-pixel-level representations.

  • (“agents”, “image_level_repr”) (… agents d_representation): The output image-level representations.

Output:

  • (“agents”, “value”) (… agents): The estimated value for each agent and batch item.

Parameters:

Methods Summary

__init__(hyper_params, settings, ...)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_restrict_input_to_visible_channels(...)

Restrict an agent's input to its visible message channels.

forward(head_output)

Run the agent value heads and combine their values.

Attributes

T_destination

additional_in_keys

additional_out_keys

call_super_init

device

The device used by the agent part.

dump_patches

excluded_in_keys

excluded_out_keys

in_keys

The keys required by the module.

out_keys

The keys produced by the module.

out_keys_source

training

Methods

__init__(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, value_heads: dict[str, AgentValueHead])

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_restrict_input_to_visible_channels(agent_name: str, input_array: Tensor | ndarray[Any, dtype[_ScalarType_co]], shape_spec: str) -> Tensor

Restrict an agent’s input to its visible message channels.

Agents only receive messages from the channels they can see. This method restricts an agent's input to those visible message channels.

Parameters:
  • agent_name (str) – The name of the agent.

  • input_array (Tensor | NDArray) – The input array to the agent.

  • shape_spec (str) – The shape of the input. This is a space-separated string of the dimensions of the input. One of these must be “channel”.

Returns:

restricted_input (Tensor | NDArray) – The input restricted to the visible message channels.
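The restriction described above can be illustrated with a minimal NumPy sketch. This is not the library's implementation: the function name `restrict_to_visible_channels` and the `visible_channels` argument are hypothetical, and the sketch assumes the visible channels are given as a list of indices along the “channel” dimension named in `shape_spec`.

```python
import numpy as np


def restrict_to_visible_channels(input_array, shape_spec, visible_channels):
    """Keep only the visible entries along the "channel" dimension.

    Hypothetical helper for illustration only; the real method looks up
    the visible channels from the agent name rather than taking indices.
    """
    dim_names = shape_spec.split()
    if "channel" not in dim_names:
        raise ValueError('shape_spec must contain a "channel" dimension')

    channel_pos = dim_names.index("channel")
    if "..." in dim_names[:channel_pos]:
        # Leading "..." stands for any number of dimensions, so count
        # the channel axis from the end instead of from the start.
        axis = channel_pos - len(dim_names)
    else:
        axis = channel_pos

    # Select the visible channel indices along the channel axis.
    return np.take(input_array, visible_channels, axis=axis)


# Assumed example: 3 message channels, of which the agent sees 0 and 2.
x = np.arange(24).reshape(2, 3, 4)
restricted = restrict_to_visible_channels(x, "batch channel position", [0, 2])
print(restricted.shape)  # (2, 2, 4)
```

The same call with `shape_spec="... channel position"` selects along the second-to-last axis, so it gives the same result for this input.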

forward(head_output: TensorDictBase) -> TensorDict

Run the agent value heads and combine their values.

Parameters:

head_output (TensorDictBase) –

The input to the value heads. Should contain the keys:

  • (“agents”, “image_level_repr”): The image-level representation from the body.

Returns:

tensordict (TensorDict) – The tensordict updated in place with the output of the value heads.
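As a rough illustration of the shapes involved, the sketch below runs one value head per agent on that agent's slice of the image-level representation and stacks the results along the agents dimension. It is a NumPy stand-in, not the actual module: the agent names, the linear value heads, and the function `combined_value_head` are all assumptions made for the example, and the real class operates on TensorDicts with `AgentValueHead` modules.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: two agents and a small representation size.
agent_names = ["prover", "verifier"]
d_representation = 8
batch_size = 4

# Stand-in value heads: one linear map d_representation -> scalar per agent.
value_heads = {
    name: (rng.normal(size=(d_representation,)), 0.0) for name in agent_names
}


def combined_value_head(image_level_repr, value_heads, agent_names):
    """Map (... agents d_representation) to (... agents)."""
    per_agent_values = []
    for i, name in enumerate(agent_names):
        weight, bias = value_heads[name]
        # Slice out this agent's representation: (... d_representation)
        agent_repr = image_level_repr[..., i, :]
        # Apply the agent's own value head: (...)
        per_agent_values.append(agent_repr @ weight + bias)
    # Recombine along a trailing agents dimension: (... agents)
    return np.stack(per_agent_values, axis=-1)


image_level_repr = rng.normal(
    size=(batch_size, len(agent_names), d_representation)
)
values = combined_value_head(image_level_repr, value_heads, agent_names)
print(values.shape)  # (4, 2)
```

Each agent's value depends only on its own slice of the input, which matches the documented output shape (… agents).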