nip.image_classification.agents.ImageClassificationDummyAgentBody
- class nip.image_classification.agents.ImageClassificationDummyAgentBody(*args, **kwargs)
Dummy agent body for the image classification task.
Shapes
Input:
“x” (… max_message_rounds latent_height latent_width): The message history
“image” (… image_channel height width): The image
“message” (… channel position latent_height latent_width), optional: The most recent message
“ignore_message” (…), optional: Whether to ignore the message
Output:
“image_level_repr” (… d_representation): The output image-level representations.
“latent_pixel_level_repr” (… latent_height latent_width d_representation): The output latent-pixel-level representations.
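The shape conventions above tie the inputs and outputs together through the shared leading dimensions `…` and the latent grid of `"x"`. The sketch below checks this using plain shape tuples. The helper `expected_output_shapes` and its parameters are hypothetical and not part of nip, and it assumes that every input and output shares the same leading batch dimensions.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]


def expected_output_shapes(
    x_shape: Shape,
    image_shape: Shape,
    d_representation: int,
    message_shape: Optional[Shape] = None,
) -> Dict[str, Shape]:
    """Derive the documented output shapes from the input shapes (hypothetical helper).

    x_shape:       (*batch, max_message_rounds, latent_height, latent_width)
    image_shape:   (*batch, image_channel, height, width)
    message_shape: (*batch, channel, position, latent_height, latent_width)
    """
    if len(x_shape) < 3 or len(image_shape) < 3:
        raise ValueError("'x' and 'image' need at least 3 dimensions")
    batch = x_shape[:-3]
    latent_hw = x_shape[-2:]
    # Assumption: all inputs share the same leading batch dimensions ("..." in the docs).
    if image_shape[:-3] != batch:
        raise ValueError("'image' batch dimensions differ from 'x'")
    # The optional message must match the batch dimensions and latent grid of 'x'.
    if message_shape is not None:
        if len(message_shape) < 4:
            raise ValueError("'message' needs at least 4 dimensions")
        if message_shape[:-4] != batch or message_shape[-2:] != latent_hw:
            raise ValueError("'message' is inconsistent with 'x'")
    return {
        "image_level_repr": batch + (d_representation,),
        "latent_pixel_level_repr": batch + latent_hw + (d_representation,),
    }


shapes = expected_output_shapes((8, 5, 7, 7), (8, 3, 28, 28), 16, (8, 2, 1, 7, 7))
print(shapes)
# {'image_level_repr': (8, 16), 'latent_pixel_level_repr': (8, 7, 7, 16)}
```

In this example, a batch of 8 images with a 7×7 latent grid and `d_representation = 16` yields one 16-dimensional vector per image and one per latent pixel.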
Methods Summary
__init__(hyper_params, settings, agent_name, ...): Initialize internal Module state, shared by both nn.Module and ScriptModule.
Initialise the module weights.
_run_recorder_hook(hooks, hook_name, output)
forward(data): Return dummy outputs.
get_state_dict(): Get the state of the agent part as a dict.
set_state(checkpoint): Set the state of the agent from a checkpoint.
to(device): Move the agent to the given device.
Attributes
T_destination
agent_id: The ID of the agent.
agent_level_in_keys
agent_level_out_keys
call_super_init
dump_patches
env_level_in_keys
env_level_out_keys
in_keys: The keys required by the module.
is_prover: Whether the agent is a prover.
is_verifier: Whether the agent is a verifier.
max_message_rounds: The maximum number of message rounds in the protocol.
num_visible_message_channels: The number of message channels visible to the agent.
out_keys: The keys produced by the module.
out_keys_source
required_pretrained_models: The pretrained models used by the agent.
visible_message_channel_indices: The indices of the message channels visible to the agent.
visible_message_channel_mask: The mask for the message channels visible to the agent.
visible_message_channel_names: The names of the message channels visible to the agent.
agent_params
training
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(data: TensorDictBase) → TensorDict
Return dummy outputs.
- Parameters:
data (TensorDictBase) – The data to run the body on.
- Returns:
out (TensorDict) – The dummy outputs.
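One plausible reading of "dummy outputs" is a set of constant, zero-filled representations that respect the output shapes documented above, without any learned computation. The sketch below shows that behaviour under this assumption; it is not taken from the nip source. To stay dependency-free it uses nested lists in place of torch tensors and a plain dict in place of a TensorDict. The class `DummyBodySketch`, its constructor parameters, and the `zeros` helper are all hypothetical. The batch shape is passed explicitly here; in nip it would presumably come from the input TensorDict's `batch_size`.

```python
from typing import Any, Dict, Sequence


def zeros(shape: Sequence[int]) -> Any:
    """Nested-list stand-in for a zero tensor of the given shape."""
    if not shape:
        return 0.0
    return [zeros(shape[1:]) for _ in range(shape[0])]


class DummyBodySketch:
    """Hypothetical stand-in for a dummy agent body (not the nip class)."""

    required_keys = ("x", "image")
    optional_keys = ("message", "ignore_message")

    def __init__(self, d_representation: int, latent_height: int, latent_width: int) -> None:
        self.d_representation = d_representation
        self.latent_height = latent_height
        self.latent_width = latent_width

    def forward(self, data: Dict[str, Any], batch_size: Sequence[int]) -> Dict[str, Any]:
        # Only the required keys are checked; the optional keys may be absent.
        missing = [key for key in self.required_keys if key not in data]
        if missing:
            raise KeyError(f"missing required input keys: {missing}")
        # The input values are never read: the outputs are constant zeros
        # that only follow the documented output shapes.
        batch = tuple(batch_size)
        return {
            "image_level_repr": zeros(batch + (self.d_representation,)),
            "latent_pixel_level_repr": zeros(
                batch + (self.latent_height, self.latent_width, self.d_representation)
            ),
        }


body = DummyBodySketch(d_representation=4, latent_height=2, latent_width=3)
out = body.forward({"x": None, "image": None}, batch_size=(2,))
print(out["latent_pixel_level_repr"][1][1][2])
# [0.0, 0.0, 0.0, 0.0]
```

Because such a body ignores its inputs, it can serve as a baseline or a placeholder when testing the surrounding training loop, independently of any trained feature extractor.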
- get_state_dict() → dict
Get the state of the agent part as a dict.
This method should be implemented by subclasses capable of saving their state.
- Returns:
state_dict (dict) – The state of the agent part.
- set_state(checkpoint: AgentState)
Set the state of the agent from a checkpoint.
This method should be overridden by subclasses to restore the state of the agent from a checkpoint.
- Parameters:
checkpoint (AgentState) – The checkpoint to restore the state from.
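get_state_dict and set_state form a save/restore pair that subclasses are expected to override. The sketch below illustrates this round trip with a hypothetical `ToyAgentPart` and a hypothetical `ToyCheckpoint` dataclass standing in for `AgentState`. The `body_state` field and the `step_count` attribute are invented for illustration and do not reflect nip's actual checkpoint layout.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ToyCheckpoint:
    """Hypothetical stand-in for AgentState; the field name is invented."""

    body_state: Dict[str, Any] = field(default_factory=dict)


class ToyAgentPart:
    """Hypothetical agent part with one piece of saveable state."""

    def __init__(self) -> None:
        self.step_count = 0

    def get_state_dict(self) -> Dict[str, Any]:
        # Return a plain dict so the state can be stored inside a checkpoint.
        return {"step_count": self.step_count}

    def set_state(self, checkpoint: ToyCheckpoint) -> None:
        # Restore from the checkpoint, falling back to the default if absent.
        self.step_count = checkpoint.body_state.get("step_count", 0)


part = ToyAgentPart()
part.step_count = 7
checkpoint = ToyCheckpoint(body_state=part.get_state_dict())

restored = ToyAgentPart()
restored.set_state(checkpoint)
print(restored.step_count)
# 7
```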