nip.scenario_base.environment.TensorDictEnvironment

class nip.scenario_base.environment.TensorDictEnvironment(*args, **kwargs)

The base class for all Prover-Verifier RL environments which use tensordicts.

To implement a new environment, subclass this class and implement the following attribute and methods:

  • _message_history_shape: The shape of the message history and ‘x’ tensors.

  • _get_observation_spec: The specification of the agent observations.

  • _get_action_spec: The specification of the agent actions.

  • _get_state_spec (optional): The specification of the state space.

  • _get_reward_spec (optional): The specification of the agent rewards.

  • _get_done_spec (optional): The specification of the agent done signals.

  • _step: Perform a step in the environment.

  • _compute_message_history_and_next_message: Compute the new message history and next message.

  • _masked_reset: Reset the environment for a subset of the episodes.

Parameters:
  • hyper_params (HyperParameters) – The parameters of the experiment.

  • settings (ExperimentSettings) – The settings of the experiment.

  • dataset (TensorDictDataset) – The dataset for the environment.

  • protocol_handler (ProtocolHandler) – The protocol handler for the environment.

  • train (bool, optional) – Whether the environment is used for training or evaluation. Defaults to True.

Methods Summary

__init__(hyper_params, settings, dataset, ...)

Initialize the environment.

_compute_message_history_and_next_message(...)

Compute the new message history and next message for given keys.

_get_action_spec()

Get the specification of the agent actions.

_get_done_spec()

Get the specification of the agent done signals.

_get_observation_spec()

Get the specification of the agent observations.

_get_reward_spec()

Get the specification of the agent rewards.

_get_state_spec()

Get the specification of the state space.

_masked_reset(env_td, mask, data_batch)

Reset the environment for a subset of the episodes.

_reset([env_td])

Reset the environment (partially).

_set_seed(seed)

_step(env_td)

Perform a step in the environment.

Attributes

T_destination

_filtered_reset_keys

Returns only the effective reset keys, discarding nested resets if they're not being used.

_simple_done

_step_mdp

action_key

The action key of an environment.

action_keys

The action keys of an environment.

action_spec

The action spec.

batch_locked

Whether the environment is locked to the batch size it was initialized with. If True, the environment can only be used with tensordicts of that batch size.

batch_size

Number of envs batched in this environment instance organised in a torch.Size() object.

call_super_init

device

done_key

The done key of an environment.

done_keys

The done keys of an environment.

done_keys_groups

A list of done keys, grouped as the reset keys.

done_spec

The done spec.

dump_patches

frames_per_batch

The number of frames to sample per training iteration.

full_action_spec

The full action spec.

full_done_spec

The full done spec.

full_observation_spec

full_reward_spec

The full reward spec.

full_state_spec

The full state spec.

input_spec

Input spec.

main_message_out_key

The tensordict key which contains the main message sent by each agent.

main_message_space_shape

The shape of the main message space used by the agents to communicate.

message_history_shape

The shape of the message history and 'x' tensors.

ndim

num_envs

The number of batched environments.

observation_spec

Observation spec.

output_spec

Output spec.

reset_keys

Returns a list of reset keys.

reward_key

The reward key of an environment.

reward_keys

The reward keys of an environment.

reward_spec

The reward spec.

run_type_checks

shape

Equivalent to batch_size.

specs

Returns a Composite container where all the environment specs are present.

state_spec

State spec.

steps_per_env_per_iteration

The number of steps per batched environment in each iteration.

dataset

training
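The batching attributes frames_per_batch, num_envs and steps_per_env_per_iteration are related. The sketch below assumes (this reference does not state the formula) that each iteration's frames_per_batch frames are split evenly across the num_envs batched environments. The function shares its name with the attribute for readability, but it is a standalone illustration, not the property's implementation.

```python
def steps_per_env_per_iteration(frames_per_batch: int, num_envs: int) -> int:
    """Steps each batched environment takes per training iteration.

    Assumes frames are split evenly across environments; this is an
    illustration derived from the attribute descriptions, not library code.
    """
    if num_envs <= 0:
        raise ValueError("num_envs must be positive")
    if frames_per_batch % num_envs != 0:
        raise ValueError("frames_per_batch should be divisible by num_envs")
    return frames_per_batch // num_envs


# Example: 1000 frames per iteration collected from 8 batched environments
print(steps_per_env_per_iteration(1000, 8))  # 125
```

Requiring exact divisibility keeps the invariant frames_per_batch == num_envs * steps_per_env_per_iteration explicit rather than silently dropping frames.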

Methods

__init__(hyper_params: HyperParameters, settings: ExperimentSettings, dataset: TensorDictDataset, protocol_handler: ProtocolHandler, *, train: bool = True)

Initialize the environment from the experiment hyperparameters and settings, the dataset and the protocol handler.

_compute_message_history_and_next_message(env_td: TensorDictBase, next_td: TensorDictBase, *, message_out_key: str, message_in_key: str, message_history_key: str, message_shape: tuple[int, ...]) -> TensorDictBase

Compute the new message history and next message for given keys.

This is a generic method for updating the one-hot encoded next-message and message-history tensors, given the message chosen by each agent.

Used in the _step method of the environment.

Parameters:
  • env_td (TensorDictBase) – The current observation and state.

  • next_td (TensorDictBase) – The ‘next’ tensordict, to be updated with the message history and next message.

  • message_out_key (str) – The key in the ‘agents’ sub-tensordict which contains the message selected by each agent. This results from the output of the agent’s forward pass.

  • message_in_key (str) – The key which contains the next message to be sent, which is used as input to each agent.

  • message_history_key (str) – The key which contains the message history tensor.

  • message_shape (tuple[int, ...]) – The shape of the message space.

Returns:

next_td (TensorDictBase) – The updated ‘next’ tensordict.
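A minimal pure-Python sketch of the kind of update described above, assuming a one-dimensional message space (so message_shape is (message_size,)), one message per episode per round, and histories stored as nested lists rather than tensors. The helpers one_hot and update_message_history are hypothetical names, not part of the library.

```python
from copy import deepcopy


def one_hot(index: int, size: int) -> list[int]:
    """Return a one-hot vector of length `size` with a 1 at `index`."""
    if not 0 <= index < size:
        raise ValueError(f"index {index} out of range for size {size}")
    return [1 if i == index else 0 for i in range(size)]


def update_message_history(
    histories: list[list[list[int]]],
    rounds: list[int],
    chosen_messages: list[int],
    message_size: int,
) -> tuple[list[list[list[int]]], list[list[int]]]:
    """Hypothetical batched update of one-hot message histories.

    histories[b][r] is the one-hot message sent in round r of episode b.
    Returns (new_histories, next_messages). The inputs are left untouched,
    much as the environment writes into a separate 'next' tensordict.
    """
    next_messages = [one_hot(m, message_size) for m in chosen_messages]
    new_histories = deepcopy(histories)
    for b, (r, message) in enumerate(zip(rounds, next_messages)):
        new_histories[b][r] = message
    return new_histories, next_messages


# Two episodes, 3 rounds each, a message space of size 4
history = [[[0] * 4 for _ in range(3)] for _ in range(2)]
new_history, next_message = update_message_history(
    history, rounds=[0, 1], chosen_messages=[2, 3], message_size=4
)
print(next_message)    # [[0, 0, 1, 0], [0, 0, 0, 1]]
print(new_history[1])  # [[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 0]]
```

The same one-hot vector serves both as the next message fed to the agents and as the new row of the history, which is why a single method can compute both.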

abstract _get_action_spec() -> TensorSpec

Get the specification of the agent actions.

Subclasses should call this method (via super()) and add any additional action spaces to the result.

Returns:

action_spec (TensorSpec) – The action specification.
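The extension pattern described above can be sketched with Python's super(). BaseEnvSketch, MessagingEnvSketch and the 'decision' and 'message' keys are hypothetical, and plain dicts stand in for the TensorSpec objects the real method returns.

```python
class BaseEnvSketch:
    """Hypothetical stand-in for the base class; specs are plain dicts."""

    def _get_action_spec(self) -> dict[str, tuple[int, ...]]:
        # Action spaces shared by every scenario (illustrative keys only).
        # A fresh dict is built on each call, so subclasses can extend it safely.
        return {"decision": (3,)}


class MessagingEnvSketch(BaseEnvSketch):
    """Hypothetical subclass that adds its own message action space."""

    def __init__(self, message_size: int) -> None:
        self.message_size = message_size

    def _get_action_spec(self) -> dict[str, tuple[int, ...]]:
        # Start from the parent's spec, then add the extra action space
        spec = super()._get_action_spec()
        spec["message"] = (self.message_size,)
        return spec


print(MessagingEnvSketch(message_size=8)._get_action_spec())
# {'decision': (3,), 'message': (8,)}
```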

_get_done_spec() -> TensorSpec

Get the specification of the agent done signals.

There are both shared and agent-specific done signals. The shared done signal is provided for convenience: it indicates that all relevant agents are done, so the environment should be reset.

Returns:

done_spec (TensorSpec) – The done specification.
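The relationship between the two kinds of done signal can be sketched as follows, assuming agent-specific done flags stored per episode in a dict keyed by agent name. This layout is hypothetical; the real signals live in tensordicts. The shared flag is True for an episode only when every relevant agent is done in that episode.

```python
def shared_done(
    agent_done: dict[str, list[bool]], relevant_agents: list[str]
) -> list[bool]:
    """Per-episode shared done: True once every relevant agent is done.

    agent_done maps agent name -> per-episode done flags (hypothetical layout).
    """
    num_episodes = len(next(iter(agent_done.values())))
    return [
        all(agent_done[name][b] for name in relevant_agents)
        for b in range(num_episodes)
    ]


agent_done = {
    "prover": [True, False, True],
    "verifier": [True, True, False],
}
print(shared_done(agent_done, ["prover", "verifier"]))  # [True, False, False]
```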

abstract _get_observation_spec() -> TensorSpec

Get the specification of the agent observations.

The observation space has the following elements:

  • round: The current round of the interaction.

  • decision_restriction: The restriction on what the verifier can decide.

    • 0: The verifier can decide anything.

    • 1: The verifier can only decide to continue interacting.

    • 2: The verifier can only make a guess.

  • x: The message history.

  • seed: A shared seed for the environment.

  • message: The next message.

  • pretrained_embeddings: The pretrained embeddings, if any. This is a nested specification, where the sub-keys are the pretrained model names.

  • linear_message_history: The linear message history, if it is included.

Returns:

observation_spec (TensorSpec) – The observation specification.
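To make the decision_restriction codes concrete, here is a small sketch mapping each code to the verifier decisions it permits. The decision names in DECISIONS ('reject', 'accept', 'continue') are hypothetical labels; only the meaning of codes 0, 1 and 2 comes from the list above.

```python
# Hypothetical verifier decisions; the real action encoding may differ.
DECISIONS = ("reject", "accept", "continue")


def allowed_decisions(decision_restriction: int) -> list[str]:
    """Map a decision_restriction code to the verifier decisions it permits."""
    if decision_restriction == 0:
        return list(DECISIONS)       # the verifier can decide anything
    if decision_restriction == 1:
        return ["continue"]          # it can only continue interacting
    if decision_restriction == 2:
        return ["reject", "accept"]  # it can only make a guess
    raise ValueError(f"unknown decision_restriction: {decision_restriction}")


print(allowed_decisions(2))  # ['reject', 'accept']
```

In a policy, a list like this would typically be turned into a mask over action logits so that disallowed decisions cannot be sampled.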

_get_reward_spec() -> TensorSpec

Get the specification of the agent rewards.

Returns:

reward_spec (TensorSpec) – The reward specification.

_get_state_spec() -> TensorSpec

Get the specification of the state space.

By default, the state consists of the true label.

Returns:

state_spec (TensorSpec) – The state specification.

abstract _masked_reset(env_td: TensorDictBase, mask: Tensor, data_batch: TensorDict) -> TensorDictBase

Reset the environment for a subset of the episodes.

Inserts new samples from the dataset (given as data_batch) into the episodes selected by mask, and resets the other elements of those episodes.

Parameters:
  • env_td (TensorDictBase) – The current observation, state and done signal.

  • mask (torch.Tensor) – A boolean mask of the episodes to reset.

  • data_batch (TensorDict) – The data batch to insert into the episodes.

Returns:

env_td (TensorDictBase) – The reset environment tensordict.
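The masking logic can be sketched in pure Python, with each episode represented as a dict rather than a slice of a tensordict. The sketch assumes that data_batch holds exactly one new sample per masked episode, in episode order, and that resetting the other elements means setting the round to 0 and clearing the done flag. The function name masked_reset and the keys used are hypothetical.

```python
from copy import deepcopy


def masked_reset(
    episodes: list[dict], mask: list[bool], data_batch: list[dict]
) -> list[dict]:
    """Hypothetical masked reset over a batch of episode dicts.

    Masked episodes receive the next sample from data_batch plus fresh
    per-episode state; unmasked episodes are copied through unchanged.
    """
    if sum(mask) != len(data_batch):
        raise ValueError("data_batch must contain one sample per masked episode")
    samples = iter(data_batch)
    result = []
    for episode, reset in zip(episodes, mask):
        if reset:
            result.append({**next(samples), "round": 0, "done": False})
        else:
            result.append(deepcopy(episode))
    return result


episodes = [
    {"sample": "a", "round": 3, "done": True},
    {"sample": "b", "round": 1, "done": False},
]
print(masked_reset(episodes, [True, False], [{"sample": "c"}]))
# [{'sample': 'c', 'round': 0, 'done': False}, {'sample': 'b', 'round': 1, 'done': False}]
```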

_reset(env_td: TensorDictBase | None = None) -> TensorDictBase

Reset the environment (partially).

For each episode which is done, takes a new sample from the dataset and resets the episode.

Parameters:

env_td (Optional[TensorDictBase]) – The current observation, state and done signal.

Returns:

env_td (TensorDictBase) – The reset environment tensordict.
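A sketch of how a partial reset might decide which episodes to reset and how much data to draw, assuming the per-episode done flags form the mask and that a missing env_td means every episode should be reset. partial_reset and its arguments are hypothetical; its (mask, data_batch) output corresponds to the mask and data_batch arguments of _masked_reset.

```python
from typing import Iterator, Optional


def partial_reset(
    done: Optional[list[bool]],
    num_envs: int,
    samples: Iterator[dict],
) -> tuple[list[bool], list[dict]]:
    """Work out which episodes to reset and draw one new sample for each.

    When no current state is given (done is None), every episode is treated
    as needing a reset. Returns (mask, data_batch).
    """
    mask = [True] * num_envs if done is None else list(done)
    data_batch = [next(samples) for _ in range(sum(mask))]
    return mask, data_batch


samples = iter([{"sample": i} for i in range(10)])
print(partial_reset(None, 2, samples))
# ([True, True], [{'sample': 0}, {'sample': 1}])
print(partial_reset([False, True], 2, samples))
# ([False, True], [{'sample': 2}])
```

Drawing only as many samples as there are done episodes keeps unfinished episodes running while consuming the dataset one new sample at a time.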

_set_seed(seed: int | None)

_step(env_td: TensorDictBase) -> TensorDictBase

Perform a step in the environment.

Parameters:

env_td (TensorDictBase) – The current observation and state.

Returns:

next_td (TensorDictBase) – The next observation, state, reward, and done signal.