nip.scenario_base.agents

Base classes for building agents.

An agent is composed of a body and one or more heads. The body computes a representation of the environment state, and the heads use this representation to compute the agent’s policy, value function, etc.
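The body/head split can be illustrated with a small dependency-free sketch. Everything here (`ToyAgent`, `toy_body`, `toy_policy_head`, `toy_value_head`) is a hypothetical stand-in invented for illustration and is not part of `nip.scenario_base.agents`; the point is only that the representation is computed once by the body and then shared by every head.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List

Representation = List[float]


@dataclass
class ToyAgent:
    """Hypothetical stand-in for an agent: one body and several named heads."""

    body: Callable[[List[float]], Representation]
    heads: Dict[str, Callable[[Representation], object]]

    def forward(self, observation: List[float]) -> Dict[str, object]:
        # The body runs once; every head consumes the same representation.
        representation = self.body(observation)
        return {name: head(representation) for name, head in self.heads.items()}


def toy_body(observation: List[float]) -> Representation:
    # Toy representation of the environment state: [mean, max].
    return [sum(observation) / len(observation), max(observation)]


def toy_policy_head(representation: Representation) -> int:
    # Toy policy: pick action 1 if the max exceeds the mean by more than 1.
    return 1 if representation[1] - representation[0] > 1.0 else 0


def toy_value_head(representation: Representation) -> float:
    # Toy value estimate: the mean of the observation.
    return representation[0]


agent = ToyAgent(
    body=toy_body,
    heads={"policy": toy_policy_head, "value": toy_value_head},
)
outputs = agent.forward([1.0, 2.0, 6.0])
```

The real base classes are neural network modules rather than plain functions, but the data flow (one body feeding several heads) is the same idea.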

All modules are TensorDictModules, which means they take and return TensorDicts. Input and output keys are specified in the module’s input_keys and output_keys attributes.

Classes

Agent(hyper_params, agent_name[, whole, ...])

A base class for holding all the parts of an agent for an experiment.

AgentBody(hyper_params, settings, ...)

Base class for all agent bodies, which compute representations for heads.

AgentHead(hyper_params, settings, ...)

Base class for all agent heads.

AgentHooks()

Holder for hooks to run at various points in the agent forward pass.

AgentPart(hyper_params, settings, ...)

Base class for all agent parts: bodies and heads.

AgentPolicyHead(hyper_params, settings, ...)

Base class for all agent policy heads.

AgentState()

Base class for storing all the data needed to restore an agent.

AgentValueHead(hyper_params, settings, ...)

Base class for all agent value heads, which estimate the value of a state.

CombinedAgentPart(hyper_params, settings, ...)

Base class for modules which combine agent parts together.

CombinedBody(*args, **kwargs)

A module which combines all the agent bodies together.

CombinedPolicyHead(*args, **kwargs)

A module which combines all the agent policy heads together.

CombinedTensorDictAgentPart(*args, **kwargs)

Base class for modules which combine agent parts together and use TensorDicts.

CombinedValueHead(*args, **kwargs)

A module which combines all the agent value heads together.

CombinedWhole(hyper_params, settings, ...)

Base class for modules which combine whole agents together.

ConstantAgentValueHead(hyper_params, ...)

A value head which returns a constant value.

DummyAgentBody(hyper_params, settings, ...)

A dummy agent body which does nothing.

PureTextCombinedWhole(hyper_params, ...)

Base class for modules which combine whole pure-text agents together.

PureTextSharedModelGroup(hyper_params, ...)

A class representing a group of pure text agents which share the same model.

PureTextSharedModelGroupState()

Base class for storing all the data needed to restore a shared model group.

PureTextWholeAgent(hyper_params, settings, ...)

Base class for whole agents which process text input and call APIs.

RandomAgentPolicyHead(hyper_params, ...)

A policy head which samples actions randomly.

RandomWholeAgent(hyper_params, settings, ...)

Base class for whole random agents.

SoloAgentHead(hyper_params, settings, ...)

Base class for all solo agent heads, which attempt the task on their own.

TensorDictAgentPartMixin(*args, **kwargs)

Mixin for agent parts which use TensorDicts as input and output.

TensorDictDummyAgentPartMixin(*args, **kwargs)

A tensordict mixin for agent parts which are dummy (e.g. random or constant).

WholeAgent(hyper_params, settings, ...)

Base class for agents which are not split into parts.