nip.scenario_base.agents.CombinedValueHead
- class nip.scenario_base.agents.CombinedValueHead(*args, **kwargs)
A module that combines the value heads of all agents.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
settings (ExperimentSettings) – The settings of the experiment.
protocol_handler (ProtocolHandler) – The protocol handler for the experiment.
value_heads (dict[str, AgentValueHead]) – The agent value heads to combine.
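To illustrate what combining the heads can mean for the module's key bookkeeping, here is a minimal plain-Python sketch. It assumes, hypothetically, that the combined module reads the union of the keys its agent heads read, and that it exposes each head's outputs under a key prefixed with the agent name (tuple keys, in the style of TensorDict nested keys). ToyValueHead, ToyCombinedValueHead and the agent names are illustrative stand-ins, not part of nip.

```python
class ToyValueHead:
    """Hypothetical stand-in for an AgentValueHead that only records its keys."""

    def __init__(self, in_keys: list[str], out_keys: list[str]):
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)


class ToyCombinedValueHead:
    """Hypothetical combined head: merges the key sets of per-agent heads."""

    def __init__(self, value_heads: dict[str, ToyValueHead]):
        self.value_heads = value_heads

    @property
    def in_keys(self) -> list[str]:
        # Union of the keys read by any agent's head, in first-seen order.
        keys: list[str] = []
        for head in self.value_heads.values():
            for key in head.in_keys:
                if key not in keys:
                    keys.append(key)
        return keys

    @property
    def out_keys(self) -> list[tuple[str, str]]:
        # Prefix each head's output keys with the agent name so they cannot collide.
        return [
            (name, key)
            for name, head in self.value_heads.items()
            for key in head.out_keys
        ]
```

For example, two heads that both read "x" (one also reading "message") yield in_keys ["x", "message"] and out_keys [("agent_a", "value"), ("agent_b", "value")]. Namespacing by agent name keeps outputs with identical per-agent names distinct.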
Methods Summary
__init__(hyper_params, settings, ...) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
_restrict_input_to_visible_channels(agent_name, ...) – Restrict an agent's input to its visible message channels.
forward(data) – Forward pass through the combined value head.
Attributes
T_destination
additional_in_keys
additional_out_keys
call_super_init
device – The device used by the agent part.
dump_patches
excluded_in_keys
excluded_out_keys
in_keys – The keys required by the module.
out_keys – The keys produced by the module.
out_keys_source
training
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, protocol_handler: ProtocolHandler, value_heads: dict[str, AgentValueHead])
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- _restrict_input_to_visible_channels(agent_name: str, input_array: Tensor | ndarray[Any, dtype[_ScalarType_co]], shape_spec: str) → Tensor
Restrict an agent’s input to its visible message channels.
Agents only receive messages from the channels they can see. This function restricts an agent's input to the message channels visible to it.
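- Parameters:
agent_name (str) – The name of the agent whose input is restricted.
input_array (Tensor | NDArray) – The input to restrict.
shape_spec (str) – The shape specification of input_array.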
- Returns:
restricted_input (Tensor | NDArray) – The input restricted to the visible message channels.
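The channel restriction can be sketched as follows, under assumptions the documentation above does not confirm: visibility is a hypothetical per-agent boolean mask over channels (CHANNEL_VISIBILITY), shape_spec is assumed to be a space-separated list of dimension names that includes "channel", and nested Python lists stand in for a Tensor or NDArray.

```python
# Hypothetical visibility map: which message channels each agent can see.
CHANNEL_VISIBILITY: dict[str, list[bool]] = {
    "agent_a": [True, False, True],
    "agent_b": [False, True, True],
}


def restrict_to_visible_channels(agent_name: str, input_array: list, shape_spec: str) -> list:
    """Keep only the entries along the "channel" axis that the agent can see.

    Assumes shape_spec names the axes of input_array, e.g. "batch channel feature".
    """
    dims = shape_spec.split()
    if "channel" not in dims:
        raise ValueError(f"shape_spec {shape_spec!r} has no 'channel' dimension")
    channel_axis = dims.index("channel")
    visible = CHANNEL_VISIBILITY[agent_name]

    def select(sub: list, depth: int) -> list:
        # At the channel axis, drop the channels the agent cannot see.
        if depth == channel_axis:
            if len(sub) != len(visible):
                raise ValueError("channel dimension does not match visibility mask")
            return [item for item, keep in zip(sub, visible) if keep]
        # Above the channel axis, recurse into every sub-array.
        return [select(item, depth + 1) for item in sub]

    return select(input_array, 0)
```

With a real tensor, the same selection would more typically be a boolean index along the channel axis than a recursive walk; the recursion here only keeps the sketch dependency-free.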
- abstract forward(data: TensorDictBase) → TensorDict
Forward pass through the combined value head.
- Parameters:
data (TensorDictBase) – The input to the combined value head.
- Returns:
value_output (TensorDict) – The output of the combined value head.
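Since forward is marked abstract, it is presumably meant to be supplied by scenario-specific subclasses. The sketch below shows that pattern with Python's abc module, using plain dicts in place of TensorDict. The class names and the "features" key are hypothetical, and the real class may not use ABCMeta in exactly this way.

```python
from abc import ABC, abstractmethod


class CombinedValueHeadBase(ABC):
    """Hypothetical stand-in for an abstract combined value head."""

    def __init__(self, value_heads: dict):
        self.value_heads = value_heads

    @abstractmethod
    def forward(self, data: dict) -> dict:
        """Each scenario supplies its own forward pass."""


class ToyScenarioValueHead(CombinedValueHeadBase):
    """Toy scenario: each agent's value is its head applied to the shared features."""

    def forward(self, data: dict) -> dict:
        features = data["features"]
        return {name: head(features) for name, head in self.value_heads.items()}
```

Instantiating CombinedValueHeadBase directly raises TypeError, while ToyScenarioValueHead({"agent_a": sum, "agent_b": max}).forward({"features": [1, 2, 3]}) returns {"agent_a": 6, "agent_b": 3}.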