nip.scenario_base.agents.CombinedPolicyHead
- class nip.scenario_base.agents.CombinedPolicyHead(*args, **kwargs)
A module which combines all the agent policy heads together.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
settings (ExperimentSettings) – The settings of the experiment.
protocol_handler (ProtocolHandler) – The protocol handler for the experiment.
policy_heads (dict[str, AgentPolicyHead]) – The agent policy heads to combine.
Methods Summary
__init__(hyper_params, settings, ...) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
_expand_logits_to_all_channels(agent_name, ...) – Expand an agent's logits from its visible message channels to all.
_restrict_decisions(decision_restriction, ...) – Make sure the agent's decisions comply with the restrictions.
_restrict_input_to_visible_channels(agent_name, ...) – Restrict an agent's input to its visible message channels.
forward(data) – Forward pass through the combined policy head.
Attributes
T_destination
additional_in_keys
additional_out_keys
call_super_init
device – The device used by the agent part.
dump_patches
excluded_in_keys
excluded_out_keys
in_keys – The keys required by the module.
out_keys – The keys produced by the module.
out_keys_source
training
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, policy_heads: dict[str, AgentPolicyHead])
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- _expand_logits_to_all_channels(agent_name: str, logits: Tensor, shape_spec: str, fill_value: float = -1000000000.0) -> Tensor
Expand an agent’s logits from its visible message channels to all.
Agents only output messages for the channels they can see. This function expands the output to all channels by filling in fill_value for the logits in the channels the agent cannot see.
- Parameters:
agent_name (str) – The name of the agent.
logits (Tensor) – A tensor of output logits. This is a single key in the output of the agent’s forward pass.
shape_spec (str) – The shape of the output. This is a space-separated string of the dimensions of the output. One of these must be “channel”.
fill_value (float, default=-1e9) – The value to fill in for the channels the agent cannot see.
- Returns:
expanded_logits (Tensor) – The output expanded to all channels. This has the same shape as logits, except that the channel dimension is the full set of message channels.
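The expansion described above can be illustrated with a minimal sketch. It uses NumPy arrays as a stand-in for torch tensors, and the function name `expand_to_all_channels` and the `visible_mask` argument are hypothetical: the real method takes an `agent_name` and presumably looks up the visible channels itself. The sketch also assumes that any leading batch dimensions come before "channel" in `shape_spec`.

```python
import numpy as np


def expand_to_all_channels(logits, shape_spec, visible_mask, fill_value=-1e9):
    """Hypothetical stand-in for _expand_logits_to_all_channels.

    visible_mask is an assumed boolean sequence over all message channels,
    True where the agent can see the channel.
    """
    dims = shape_spec.split()
    # Use a negative index so extra leading batch dimensions are tolerated.
    channel_dim = dims.index("channel") - len(dims)
    visible_mask = np.asarray(visible_mask, dtype=bool)

    # Work with the channel axis last, then move it back afterwards.
    moved = np.moveaxis(logits, channel_dim, -1)
    full_shape = moved.shape[:-1] + (visible_mask.size,)
    expanded = np.full(full_shape, fill_value, dtype=moved.dtype)
    expanded[..., visible_mask] = moved
    return np.moveaxis(expanded, -1, channel_dim)


# Two visible channels out of three; the middle channel is hidden.
logits = np.arange(6, dtype=np.float32).reshape(1, 2, 3)
expanded = expand_to_all_channels(logits, "batch channel logit", [True, False, True])
print(expanded.shape)  # (1, 3, 3)
```

A large negative fill value, rather than negative infinity, keeps a later softmax over the logits finite while making the hidden channels effectively unselectable.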
- _restrict_decisions(decision_restriction: Int[Tensor, '...'], decision_logits: Float[Tensor, '... agents 3']) -> TensorDictBase
Make sure the agent’s decisions comply with the restrictions.
- Parameters:
decision_restriction (Int[Tensor, "..."]) – The restrictions on the agents’ decisions. The possible values are:
0: The verifier can decide anything.
1: The verifier can only decide to continue interacting.
2: The verifier can only make a guess.
decision_logits (Float[Tensor, "... agents 3"]) – The logits for the agents’ decisions.
- Returns:
decision_logits (Float[Tensor, "... agents 3"]) – The logits for the agents’ decisions, with the restricted decisions set to -1e9.
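As a rough illustration of the masking described above, the following sketch uses NumPy arrays in place of torch tensors. The function name `restrict_decisions` is hypothetical, and the layout of the final dimension is an assumption that the source does not state: here indices 0 and 1 are taken to be the two possible guesses and index 2 to be "continue interacting".

```python
import numpy as np


def restrict_decisions(decision_restriction, decision_logits, fill_value=-1e9):
    """Hypothetical stand-in for _restrict_decisions.

    Assumed layout of the last dimension (not confirmed by the source):
    indices 0 and 1 are guesses, index 2 is "continue interacting".
    """
    restricted = np.array(decision_logits, dtype=float, copy=True)
    restriction = np.asarray(decision_restriction)

    # Add an "agents" axis so the restriction broadcasts over all agents.
    only_continue = (restriction == 1)[..., None]
    only_guess = (restriction == 2)[..., None]

    # Restriction 1: rule out both guesses, leaving only "continue".
    restricted[..., :2] = np.where(
        only_continue[..., None], fill_value, restricted[..., :2]
    )
    # Restriction 2: rule out "continue", leaving only the guesses.
    restricted[..., 2] = np.where(only_guess, fill_value, restricted[..., 2])
    return restricted


# Batch of three items, two agents each, one restriction value per item.
restriction = np.array([0, 1, 2])
logits = np.zeros((3, 2, 3))
restricted = restrict_decisions(restriction, logits)
```

With these inputs, the first batch item is left unchanged, the second keeps only the "continue" logit, and the third keeps only the two guess logits.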
- _restrict_input_to_visible_channels(agent_name: str, input_array: Tensor | ndarray[Any, dtype[_ScalarType_co]], shape_spec: str) -> Tensor
Restrict an agent’s input to its visible message channels.
Agents only receive messages from the channels they can see. This function restricts the input to the agent to only the visible message channels.
- Parameters:
agent_name (str) – The name of the agent.
input_array (Tensor | NDArray) – The input to restrict.
shape_spec (str) – The shape of the input. This is a space-separated string of the dimensions of the input.
- Returns:
restricted_input (Tensor | NDArray) – The input restricted to the visible message channels.
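The restriction is the inverse of the logit expansion: it selects a subset of the channel dimension. A minimal sketch, assuming NumPy input and a hypothetical `visible_mask` argument in place of the lookup by `agent_name` that the real method presumably performs, might look like this. As before, it assumes any leading batch dimensions precede "channel" in `shape_spec`.

```python
import numpy as np


def restrict_to_visible_channels(input_array, shape_spec, visible_mask):
    """Hypothetical stand-in for _restrict_input_to_visible_channels.

    visible_mask is an assumed boolean sequence over all message channels.
    """
    dims = shape_spec.split()
    # Negative index, so extra leading batch dimensions are tolerated.
    channel_dim = dims.index("channel") - len(dims)
    visible_indices = np.flatnonzero(np.asarray(visible_mask, dtype=bool))
    # Keep only the visible channels along the channel axis.
    return np.take(input_array, visible_indices, axis=channel_dim)


# Three channels, of which the first and last are visible.
messages = np.arange(12).reshape(2, 3, 2)
restricted = restrict_to_visible_channels(
    messages, "batch channel position", [True, False, True]
)
print(restricted.shape)  # (2, 2, 2)
```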
- abstract forward(data: TensorDictBase) -> TensorDict
Forward pass through the combined policy head.
- Parameters:
data (TensorDict) – The input to the combined policy head.
- Returns:
policy_output (TensorDict) – The output of the combined policy head. This must contain the key (“agents”, “main_message_logits”), which has shape “… agents channel position logit” and contains the logits for the agents’ messages in the main message space.