nip.scenario_base.agents.TensorDictDummyAgentPartMixin
- class nip.scenario_base.agents.TensorDictDummyAgentPartMixin(*args, **kwargs)
- A tensordict mixin for dummy agent parts (e.g. random or constant).
- Adds a dummy parameter to the agent part so that PyTorch can calculate gradients and tensordict can determine the device.
- Methods Summary
  - __init__(hyper_params, settings, agent_name, ...) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
  - Initialise the module weights.
  - _run_recorder_hook(hooks, hook_name, output)
  - forward(data) – Forward pass through the agent part.
  - get_state_dict() – Get the state of the agent part as a dict.
  - set_state(checkpoint) – Set the state of the agent from a checkpoint.
  - to(device) – Move the agent to the given device.
- Attributes
  - T_destination
  - agent_id – The ID of the agent.
  - agent_level_in_keys
  - agent_level_out_keys
  - call_super_init
  - dump_patches
  - env_level_in_keys
  - env_level_out_keys
  - in_keys – The keys required by the module.
  - is_prover – Whether the agent is a prover.
  - is_verifier – Whether the agent is a verifier.
  - max_message_rounds – The maximum number of message rounds in the protocol.
  - num_visible_message_channels – The number of message channels visible to the agent.
  - out_keys – The keys produced by the module.
  - out_keys_source
  - required_pretrained_models – The pretrained models used by the agent.
  - visible_message_channel_indices – The indices of the message channels visible to the agent.
  - visible_message_channel_mask – The mask for the message channels visible to the agent.
  - visible_message_channel_names – The names of the message channels visible to the agent.
  - training
- Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, agent_name: str, protocol_handler: ProtocolHandler)
- Initialize internal Module state, shared by both nn.Module and ScriptModule. 
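The dummy-parameter idea in the class description can be illustrated with plain PyTorch. This is a minimal sketch, not the nip implementation: `DummyAgentPartSketch`, its `device` property, and the `forward(batch_size)` signature are hypothetical. It shows how a single unused scalar parameter lets gradients flow through a constant-output module and lets the device be read from the module.

```python
import torch
from torch import nn


class DummyAgentPartSketch(nn.Module):
    """Hypothetical dummy agent part that always outputs zero logits."""

    def __init__(self, num_actions: int = 3) -> None:
        super().__init__()
        self.num_actions = num_actions
        # Dummy scalar parameter: its value never affects the output, but it
        # makes parameters() non-empty and records which device we are on.
        self.dummy = nn.Parameter(torch.tensor(0.0))

    @property
    def device(self) -> torch.device:
        # Read the device off the only parameter the module owns.
        return self.dummy.device

    def forward(self, batch_size: int) -> torch.Tensor:
        logits = torch.zeros(batch_size, self.num_actions, device=self.device)
        # Adding 0 * dummy attaches the output to the autograd graph, so
        # backward() succeeds and the dummy receives a (zero) gradient.
        return logits + 0.0 * self.dummy


part = DummyAgentPartSketch()
out = part(4)
out.sum().backward()
print(out.shape, part.dummy.grad)
```

Without the dummy parameter, `part.parameters()` would be empty, so building an optimizer such as `torch.optim.SGD(part.parameters(), lr=0.1)` would raise a `ValueError`, and there would be no tensor from which to infer the device.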
 - abstract forward(data: Any) → Any
- Forward pass through the agent part.
- Parameters:
- data (Any) – The input to the agent part. 
- Returns:
- output (Any) – The output of the forward pass on the input. 
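Because `forward` is abstract, each concrete dummy part decides what to return. The sketch below is hypothetical: `DummyPartBase`, `ConstantPart` and `RandomPart` are illustrative names, not part of nip, and it assumes the input is a tensor whose first dimension is the batch.

```python
from abc import ABC, abstractmethod
from typing import Any

import torch
from torch import nn


class DummyPartBase(nn.Module, ABC):
    """Hypothetical base class: concrete dummy parts must implement forward()."""

    def __init__(self) -> None:
        super().__init__()
        # Scalar dummy parameter, as in the class description.
        self.dummy = nn.Parameter(torch.tensor(0.0))

    @abstractmethod
    def forward(self, data: Any) -> Any:
        """Map the input to an output; subclasses decide how."""


class ConstantPart(DummyPartBase):
    """Returns the same logits for every row of the input batch."""

    def __init__(self, logits: list[float]) -> None:
        super().__init__()
        # A buffer moves with .to(device) but is not a trainable parameter.
        self.register_buffer("logits", torch.tensor(logits))

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        batch = data.shape[0]
        # Broadcast the (num_actions,) logits to (batch, num_actions).
        return self.logits.expand(batch, -1) + 0.0 * self.dummy


class RandomPart(DummyPartBase):
    """Returns independent uniform random logits in [0, 1) for each row."""

    def __init__(self, num_actions: int) -> None:
        super().__init__()
        self.num_actions = num_actions

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        batch = data.shape[0]
        out = torch.rand(batch, self.num_actions, device=self.dummy.device)
        return out + 0.0 * self.dummy


batch = torch.zeros(5, 2)
print(ConstantPart([1.0, 2.0, 3.0])(batch))
print(RandomPart(num_actions=4)(batch).shape)
```

Calling a part instance goes through `nn.Module.__call__`, which dispatches to the subclass's `forward`; the input's contents are ignored and only its batch size is used.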
 
 - get_state_dict() → dict
- Get the state of the agent part as a dict.
- This method should be implemented by subclasses capable of saving their state.
- Returns:
- state_dict (dict) – The state of the agent part. 
 
 - set_state(checkpoint: AgentState)
- Set the state of the agent from a checkpoint.
- This method should be overridden by subclasses to restore the state of the agent from a checkpoint.
- Parameters:
- checkpoint (AgentState) – The checkpoint to restore the state from.
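A subclass that does carry state might implement `get_state_dict` and `set_state` as a round trip. This sketch is hypothetical: it uses a plain dict as the checkpoint in place of `AgentState`, whose structure is not documented here, and `StatefulDummyPart` and its `num_calls` counter are illustrative.

```python
import copy
from typing import Any

import torch
from torch import nn


class StatefulDummyPart(nn.Module):
    """Hypothetical dummy part with a call counter as extra state."""

    def __init__(self) -> None:
        super().__init__()
        self.dummy = nn.Parameter(torch.tensor(0.0))
        self.num_calls = 0  # plain-Python state, not seen by state_dict()

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        self.num_calls += 1
        return data + 0.0 * self.dummy

    def get_state_dict(self) -> dict[str, Any]:
        # Deep-copy the tensors so later updates to this module do not
        # change the saved snapshot.
        return {
            "module": copy.deepcopy(self.state_dict()),
            "num_calls": self.num_calls,
        }

    def set_state(self, checkpoint: dict[str, Any]) -> None:
        # Assumption: the checkpoint is the dict returned by get_state_dict();
        # the documented method receives an AgentState instead.
        self.load_state_dict(checkpoint["module"])
        self.num_calls = checkpoint["num_calls"]


source = StatefulDummyPart()
source(torch.ones(2))
source(torch.ones(2))
with torch.no_grad():
    source.dummy.fill_(1.5)
snapshot = source.get_state_dict()

restored = StatefulDummyPart()
restored.set_state(snapshot)
print(restored.num_calls, restored.dummy.item())
```

Bundling `state_dict()` with the non-tensor attributes matters because `nn.Module.state_dict()` only captures parameters and buffers, so a counter like `num_calls` would otherwise be lost.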