nip.image_classification.agents.ImageClassificationRandomAgentPolicyHead

class nip.image_classification.agents.ImageClassificationRandomAgentPolicyHead(*args, **kwargs)

Policy head for the image classification task that yields a uniform distribution: every output logit is zero.

Shapes

Input:

  • “image_level_repr” (… d_representation): The output image-level representations.

  • “latent_pixel_level_repr” (… latent_height latent_width d_representation): The output latent-pixel-level representations.

Output:

  • “latent_pixel_selected_logits” (… channel position latent_height*latent_width): A logit for each latent pixel, indicating the probability that this latent pixel should be sent as a message to the verifier.

  • “decision_logits” (… 3): A logit for each of the three options: guess one of the two classifications, or continue exchanging messages. Set to zeros when the decider is not present.

  • “linear_message_selected_logits” (… channel position d_linear_message_space) (optional): A logit for each linear message, indicating the probability that this linear message should be sent as a message to the verifier.
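The output shapes above can be made concrete with a small shape calculator. This is a minimal sketch, not part of the nip API. It assumes `channel` counts the message channels, `position` counts the positions within a message, and the leading `…` is an arbitrary batch shape. The function name `expected_output_shapes` and its parameters are hypothetical.

```python
def expected_output_shapes(
    batch_shape,
    num_channels,
    num_positions,
    latent_height,
    latent_width,
    d_linear_message_space=None,
):
    """Return the shape of each output key (hypothetical helper, not nip API)."""
    batch = tuple(batch_shape)
    shapes = {
        # The latent pixel grid is flattened into a single last dimension.
        "latent_pixel_selected_logits": (
            *batch,
            num_channels,
            num_positions,
            latent_height * latent_width,
        ),
        # Two classification guesses plus "continue exchanging messages".
        "decision_logits": (*batch, 3),
    }
    # The linear-message output is optional, so only include it when requested.
    if d_linear_message_space is not None:
        shapes["linear_message_selected_logits"] = (
            *batch,
            num_channels,
            num_positions,
            d_linear_message_space,
        )
    return shapes


shapes = expected_output_shapes((8,), 2, 1, 7, 7)
```

For a batch of 8, two channels, one position and a 7×7 latent grid, the pixel logits have shape `(8, 2, 1, 49)`, because the 49 latent pixels share one flattened dimension.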

Methods Summary

__init__(hyper_params, settings, agent_name, ...)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_init_weights()

Initialise the module weights.

_run_recorder_hook(hooks, hook_name, output)

forward(body_output)

Output a uniform distribution.

get_state_dict()

Get the state of the agent part as a dict.

set_state(checkpoint)

Set the state of the agent from a checkpoint.

to(device)

Move the agent to the given device.

Attributes

T_destination

agent_id

The ID of the agent.

agent_level_in_keys

agent_level_out_keys

The agent-level output keys.

call_super_init

dump_patches

env_level_in_keys

env_level_out_keys

has_decider

Whether the policy head has an output yielding a decision.

in_keys

The keys required by the module.

is_prover

Whether the agent is a prover.

is_verifier

Whether the agent is a verifier.

max_message_rounds

The maximum number of message rounds in the protocol.

num_visible_message_channels

The number of message channels visible to the agent.

out_keys

The keys produced by the module.

out_keys_source

required_pretrained_models

The pretrained models used by the agent.

visible_message_channel_indices

The indices of the message channels visible to the agent.

visible_message_channel_mask

The mask for the message channels visible to the agent.

visible_message_channel_names

The names of the message channels visible to the agent.

agent_params

training
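The visible_message_channel_mask, visible_message_channel_indices, visible_message_channel_names and num_visible_message_channels attributes describe the same information in different forms. The following sketch shows one way they could relate. It assumes the mask holds one boolean per message channel, in the same order as a list of channel names. The helper `visible_channel_info` and the channel names are hypothetical, and the real attributes may be tensors or other containers rather than tuples.

```python
def visible_channel_info(channel_names, mask):
    """Derive visible channel indices, names and count from a boolean mask.

    Hypothetical helper: assumes ``mask`` has one boolean per message channel,
    aligned with ``channel_names``.
    """
    if len(channel_names) != len(mask):
        raise ValueError("mask must have exactly one entry per channel")
    # Indices of the channels whose mask entry is True.
    indices = tuple(i for i, visible in enumerate(mask) if visible)
    # Names of those same channels, in index order.
    names = tuple(channel_names[i] for i in indices)
    return indices, names, len(indices)


# Hypothetical channel layout: the agent sees two of three channels.
channel_names = ("prover0_verifier", "prover1_verifier", "prover0_prover1")
mask = (True, True, False)
indices, names, count = visible_channel_info(channel_names, mask)
```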

Methods

__init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_init_weights()

Initialise the module weights.

Should be called at the end of __init__.
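The requirement to call _init_weights at the end of __init__ matters because weight initialisation usually reads configuration set earlier in the constructor. The plain-Python stand-in below illustrates this ordering. It is a hypothetical class, not the nip implementation. A torch module would typically use the `torch.nn.init` functions instead of building lists.

```python
class HypotheticalLinearHead:
    """Stand-in (not the nip API) showing why _init_weights runs last."""

    def __init__(self, d_in, d_out):
        # Configuration must be stored before the weights can be sized.
        self.d_in = d_in
        self.d_out = d_out
        self.weight = None
        # Called as the final step of __init__, once all attributes exist.
        self._init_weights()

    def _init_weights(self):
        # Zero-initialise a d_out x d_in weight matrix using the stored sizes.
        self.weight = [[0.0] * self.d_in for _ in range(self.d_out)]


head = HypotheticalLinearHead(d_in=3, d_out=2)
```

Calling `_init_weights()` before `self.d_in` and `self.d_out` were assigned would raise an `AttributeError`. Placing the call last avoids this.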

_run_recorder_hook(hooks: AgentHooks | None, hook_name: str, output: Tensor | None)

forward(body_output: TensorDict) → TensorDict

Output a uniform distribution.

Parameters:

body_output (TensorDict) – The output of the body module.

Returns:

out (TensorDict) – A tensor dict with all zero outputs.
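Returning all-zero logits gives a uniform distribution because a softmax over a constant vector assigns equal mass to every entry. With n zero logits, each entry has exp(0) = 1, so each probability is 1/n. The sketch below assumes the logits are normalised with a softmax over the last dimension. The distribution class actually used downstream is not stated here, and the 7×7 latent grid is illustrative.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a flat list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# All-zero logits, as returned by the random policy head (sizes illustrative).
decision_logits = [0.0] * 3
pixel_logits = [0.0] * (7 * 7)  # hypothetical 7x7 latent grid, flattened

decision_probs = softmax(decision_logits)  # each option gets 1/3
pixel_probs = softmax(pixel_logits)  # each latent pixel gets 1/49
```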

get_state_dict() → dict

Get the state of the agent part as a dict.

This method should be implemented by subclasses capable of saving their state.

Returns:

state_dict (dict) – The state of the agent part.

set_state(checkpoint: AgentState)

Set the state of the agent from a checkpoint.

This method should be overridden by subclasses to restore the state of the agent from a checkpoint.

Parameters:

checkpoint (AgentState) – The checkpoint to restore the state from.

to(device: device | str | int)

Move the agent to the given device.