nip.scenario_base.environment.TensorDictEnvironment
- class nip.scenario_base.environment.TensorDictEnvironment(*args, **kwargs)[source]
The base class for all Prover-Verifier RL environments which use tensordicts.
To implement a new environment, subclass this class and implement the following attribute and methods:
_message_history_shape
    The shape of the message history and 'x' tensors.
_get_observation_spec
    The specification of the agent observations.
_get_action_spec
    The specification of the agent actions.
_get_state_spec (optional)
    The specification of the state space.
_get_reward_spec (optional)
    The specification of the agent rewards.
_get_done_spec (optional)
    The specification of the agent done signals.
_step
    Perform a step in the environment.
_compute_message_history
    Compute the new message history and next message.
_masked_reset
    Reset the environment for a subset of the episodes.
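Taken together, the list above amounts to the following skeleton. This is an illustrative sketch only: in real code the class would subclass TensorDictEnvironment and return actual TensorSpec objects, the class name `MyEnvironment` is hypothetical, and the bodies are placeholders.

```python
class MyEnvironment:  # in real code: subclass of TensorDictEnvironment
    # Required attribute: the shape of the message history and 'x' tensors.
    _message_history_shape = (8, 2, 4)  # placeholder shape

    # Required: specifications of the agent observations and actions.
    def _get_observation_spec(self):
        raise NotImplementedError  # return a TensorSpec in real code

    def _get_action_spec(self):
        raise NotImplementedError

    # Optional: state, reward and done specs have defaults in the base class.
    def _get_state_spec(self):
        raise NotImplementedError

    # Required: environment dynamics.
    def _step(self, env_td):
        raise NotImplementedError

    def _compute_message_history(self, env_td):
        raise NotImplementedError

    def _masked_reset(self, env_td, mask, data_batch):
        raise NotImplementedError
```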
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
settings (ExperimentSettings) – The settings of the experiment.
dataset (TensorDictDataset) – The dataset for the environment.
protocol_handler (ProtocolHandler) – The protocol handler for the environment.
train (bool, optional) – Whether the environment is used for training or evaluation.
Methods Summary
__init__(hyper_params, settings, dataset, ...)
    Initialize internal Module state, shared by both nn.Module and ScriptModule.
_compute_message_history_and_next_message(...)
    Compute the new message history and next message for given keys.
_get_action_spec()
    Get the specification of the agent actions.
_get_done_spec()
    Get the specification of the agent done signals.
_get_observation_spec()
    Get the specification of the agent observations.
_get_reward_spec()
    Get the specification of the agent rewards.
_get_state_spec()
    Get the specification of the state space.
_masked_reset(env_td, mask, data_batch)
    Reset the environment for a subset of the episodes.
_reset([env_td])
    Reset the environment (partially).
_set_seed(seed)
_step(env_td)
    Perform a step in the environment.
Attributes
T_destination
_filtered_reset_keys
    Returns only the effective reset keys, discarding nested resets if they're not being used.
_simple_done
_step_mdp
action_key
    The action key of an environment.
action_keys
    The action keys of an environment.
action_spec
    The action spec.
batch_locked
    Whether the environment can be used with a batch size different from the one it was initialized with or not.
batch_size
    Number of envs batched in this environment instance organised in a torch.Size() object.
call_super_init
device
done_key
    The done key of an environment.
done_keys
    The done keys of an environment.
done_keys_groups
    A list of done keys, grouped as the reset keys.
done_spec
    The done spec.
dump_patches
frames_per_batch
    The number of frames to sample per training iteration.
full_action_spec
    The full action spec.
full_done_spec
    The full done spec.
full_observation_spec
full_reward_spec
    The full reward spec.
full_state_spec
    The full state spec.
input_spec
    Input spec.
main_message_out_key
    The tensordict key which contains the main message sent by each agent.
main_message_space_shape
    The shape of the main message space used by the agents to communicate.
message_history_shape
    The shape of the message history and 'x' tensors.
ndim
num_envs
    The number of batched environments.
observation_spec
    Observation spec.
output_spec
    Output spec.
reset_keys
    Returns a list of reset keys.
reward_key
    The reward key of an environment.
reward_keys
    The reward keys of an environment.
reward_spec
    The reward spec.
run_type_checks
shape
    Equivalent to batch_size.
specs
    Returns a Composite container where all the environment specs are present.
state_spec
    State spec.
steps_per_env_per_iteration
    The number of steps per batched environment in each iteration.
dataset
training
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, dataset: TensorDictDataset, protocol_handler: ProtocolHandler, *, train: bool = True)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- _compute_message_history_and_next_message(env_td: TensorDictBase, next_td: TensorDictBase, *, message_out_key: str, message_in_key: str, message_history_key: str, message_shape: tuple[int, ...]) → TensorDictBase[source]
Compute the new message history and next message for given keys.
This is a generic method for updating one-hot encoded next message and message history tensors given a choice of message for each agent.
Used in the _step method of the environment.
- Parameters:
env_td (TensorDictBase) – The current observation and state.
next_td (TensorDictBase) – The ‘next’ tensordict, to be updated with the message history and next message.
message_out_key (str) – The key in the ‘agents’ sub-tensordict which contains the message selected by each agent. This results from the output of the agent’s forward pass.
message_in_key (str) – The key which contains the next message to be sent, which is used as input to each agent.
message_history_key (str) – The key which contains the message history tensor.
message_shape (tuple[int, ...]) – The shape of the message space.
- Returns:
next_td (TensorDictBase) – The updated ‘next’ tensordict.
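The one-hot bookkeeping this method performs can be sketched in plain Python. This is an illustrative stand-in for the tensordict operations, not the library's code; the function and variable names are assumptions:

```python
def update_message_history(message_history, round_idx, chosen_message, num_symbols):
    """Write the agent's chosen message into the history as a one-hot row.

    message_history: list of num_rounds rows, each a one-hot list of length
    num_symbols. Returns the updated history and the one-hot next message.
    """
    # One-hot encode the symbol the agent selected this round.
    next_message = [0.0] * num_symbols
    next_message[chosen_message] = 1.0
    # Copy the history and record the message at the current round.
    new_history = [row[:] for row in message_history]
    new_history[round_idx] = next_message
    return new_history, next_message

# 4 rounds, 5 possible message symbols, agent picks symbol 2 in round 0.
history = [[0.0] * 5 for _ in range(4)]
history, next_msg = update_message_history(history, round_idx=0, chosen_message=2, num_symbols=5)
```

In the real method the same update is written into `next_td` under `message_history_key` and `message_in_key`, batched over all environments.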
- abstract _get_action_spec() → TensorSpec[source]
Get the specification of the agent actions.
Subclasses should call this method and add any additional action spaces.
- Returns:
action_spec (TensorSpec) – The action specification.
- _get_done_spec() → TensorSpec[source]
Get the specification of the agent done signals.
We have both shared and agent-specific done signals. This is for convenience, where the shared done signal indicates that all relevant agents are done and so the environment should be reset.
- Returns:
done_spec (TensorSpec) – The done specification.
- abstract _get_observation_spec() → TensorSpec[source]
Get the specification of the agent observations.
The observation space has the following elements:
round
    The current round of the interaction.
decision_restriction
    The restriction on what the verifier can decide.
    0: The verifier can decide anything.
    1: The verifier can only decide to continue interacting.
    2: The verifier can only make a guess.
x
    The message history.
seed
    A shared seed for the environment.
message
    The next message.
pretrained_embeddings
    The pretrained embeddings, if any. This is a nested specification, where the sub-keys are the pretrained model names.
linear_message_history
    The linear message history, if it is included.
- Returns:
observation_spec (TensorSpec) – The observation specification.
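The `decision_restriction` codes can be read as a mask over the verifier's options. A minimal sketch, in which only the 0/1/2 codes come from the spec above; the option names and the function itself are hypothetical:

```python
def allowed_verifier_options(decision_restriction):
    """Map a decision_restriction code to the verifier's allowed options."""
    if decision_restriction == 0:
        # The verifier can decide anything.
        return {"continue", "accept", "reject"}
    if decision_restriction == 1:
        # The verifier can only decide to continue interacting.
        return {"continue"}
    if decision_restriction == 2:
        # The verifier can only make a guess.
        return {"accept", "reject"}
    raise ValueError(f"Unknown restriction code: {decision_restriction}")
```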
- _get_reward_spec() → TensorSpec[source]
Get the specification of the agent rewards.
- Returns:
reward_spec (TensorSpec) – The reward specification.
- _get_state_spec() → TensorSpec[source]
Get the specification of the state space.
Defaults to the true label.
- Returns:
state_spec (TensorSpec) – The state specification.
- abstract _masked_reset(env_td: TensorDictBase, mask: Tensor, data_batch: TensorDict) → TensorDictBase[source]
Reset the environment for a subset of the episodes.
Takes a new sample from the dataset and inserts it into the given episodes. Also resets the other elements of the episodes.
- Parameters:
env_td (TensorDictBase) – The current observation, state and done signal.
mask (torch.Tensor) – A boolean mask of the episodes to reset.
data_batch (TensorDict) – The data batch to insert into the episodes.
- Returns:
env_td (TensorDictBase) – The reset environment tensordict.
- _reset(env_td: TensorDictBase | None = None) → TensorDictBase[source]
Reset the environment (partially).
For each episode which is done, takes a new sample from the dataset and resets the episode.
- Parameters:
env_td (Optional[TensorDictBase]) – The current observation, state and done signal.
- Returns:
env_td (TensorDictBase) – The reset environment tensordict.
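The partial-reset pattern that `_reset` delegates to `_masked_reset` can be sketched without tensordicts: episodes whose done flag is set are overwritten with fresh dataset samples, while the rest are left untouched. All names here are illustrative, not the library's:

```python
def masked_reset(episodes, done_mask, sample_fn):
    """Replace finished episodes with fresh samples; keep the rest as-is."""
    return [sample_fn() if done else ep
            for ep, done in zip(episodes, done_mask)]

# Two batched episodes: the second has finished and needs resetting.
episodes = [{"x": 1, "done": False}, {"x": 2, "done": True}]
done_mask = [ep["done"] for ep in episodes]
fresh = masked_reset(episodes, done_mask, lambda: {"x": 0, "done": False})
```

In the real environment the mask is a boolean tensor over the batch dimension and `sample_fn` corresponds to drawing a new `data_batch` from the dataset.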
- _step(env_td: TensorDictBase) → TensorDictBase[source]
Perform a step in the environment.
- Parameters:
env_td (TensorDictBase) – The current observation and state.
- Returns:
next_td (TensorDictBase) – The next observation, state, reward, and done signal.