nip.graph_isomorphism.environment.GraphIsomorphismEnvironment#
- class nip.graph_isomorphism.environment.GraphIsomorphismEnvironment(*args, **kwargs)[source]#
- The graph isomorphism RL environment.

  Agents see the adjacency matrix and the messages sent so far.

  Parameters:
- hyper_params (HyperParameters) – The parameters of the experiment. 
- settings (ExperimentSettings) – The settings of the experiment. 
- dataset (TensorDictDataset) – The dataset for the environment. 
- protocol_handler (ProtocolHandler) – The protocol handler for the environment. 
- train (bool, optional) – Whether the environment is used for training or evaluation. 
 
- Methods Summary

  - __init__(hyper_params, settings, dataset, ...) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
  - _compute_message_history_and_next_message(env_td, next_td, ...) – Compute the new message history and next message for given keys.
  - _get_action_spec() – Get the specification of the agent actions.
  - _get_done_spec() – Get the specification of the agent done signals.
  - _get_observation_spec() – Get the specification of the agent observations.
  - _get_reward_spec() – Get the specification of the agent rewards.
  - _get_state_spec() – Get the specification of the state space.
  - _masked_reset(env_td, mask, data_batch) – Reset the environment for a subset of the episodes.
  - _reset([env_td]) – Reset the environment (partially).
  - _set_seed(seed)
  - _step(env_td) – Perform a step in the environment.

- Attributes

  - T_destination
  - _filtered_reset_keys – Returns only the effective reset keys, discarding nested resets if they're not being used.
  - _has_dynamic_specs
  - _simple_done
  - _step_mdp
  - action_key – The action key of an environment.
  - action_keys – The action keys of an environment.
  - action_spec – The action spec.
  - batch_locked – Whether the environment can be used with a batch size different from the one it was initialized with or not.
  - batch_size – Number of envs batched in this environment instance organised in a torch.Size() object.
  - call_super_init
  - device
  - done_key – The done key of an environment.
  - done_keys – The done keys of an environment.
  - done_keys_groups – A list of done keys, grouped as the reset keys.
  - done_spec – The done spec.
  - dump_patches
  - frames_per_batch – The number of frames to sample per training iteration.
  - full_action_spec – The full action spec.
  - full_done_spec – The full done spec.
  - full_observation_spec
  - full_reward_spec – The full reward spec.
  - full_state_spec – The full state spec.
  - input_spec – Input spec.
  - main_message_out_key
  - main_message_space_shape – The shape of the main message space.
  - max_num_nodes – The maximum number of nodes in a graph.
  - message_history_shape – The shape of the message history and 'x' tensors.
  - ndim
  - num_envs – The number of batched environments.
  - observation_spec – Observation spec.
  - output_spec – Output spec.
  - reset_keys – Returns a list of reset keys.
  - reward_key – The reward key of an environment.
  - reward_keys – The reward keys of an environment.
  - reward_spec – The reward spec.
  - run_type_checks
  - shape – Equivalent to batch_size.
  - specs – Returns a Composite container in which all the environment specs are present.
  - split – The split of the dataset used for the environment.
  - state_keys – The state keys of an environment.
  - state_spec – State spec.
  - steps_per_env_per_iteration – The number of steps per batched environment in each iteration.
  - dataset
  - training

- Methods

- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, dataset: TensorDictDataset, protocol_handler: ProtocolHandler)[source]#
- Initialize internal Module state, shared by both nn.Module and ScriptModule. 
 - _compute_message_history_and_next_message(env_td: TensorDictBase, next_td: TensorDictBase, *, message_out_key: str, message_in_key: str, message_history_key: str, message_shape: tuple[int, ...]) TensorDictBase[source]#
- Compute the new message history and next message for given keys.

  This is a generic method for updating the one-hot encoded next message and message history tensors given a choice of message for each agent.

  Used in the _step method of the environment.

  Parameters:
- env_td (TensorDictBase) – The current observation and state. 
- next_td (TensorDictBase) – The ‘next’ tensordict, to be updated with the message history and next message. 
- message_out_key (str) – The key in the ‘agents’ sub-tensordict which contains the message selected by each agent. This results from the output of the agent’s forward pass. 
- message_in_key (str) – The key which contains the next message to be sent, which is used as input to each agent. 
- message_history_key (str) – The key which contains the message history tensor. 
- message_shape (tuple[int, ...]) – The shape of the message space. 
 
- Returns:
- next_td (TensorDictBase) – The updated ‘next’ tensordict. 
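The update described above can be sketched as follows. The helper below is a simplified stand-in: it uses NumPy instead of TensorDict, and the argument names, shapes, and round-indexed history layout are assumptions for illustration, not the library's API.

```python
import numpy as np

def update_message_history(message_history, round_idx, chosen_message):
    """Sketch: one-hot encode each agent's chosen message and record it.

    message_history: (batch, max_rounds, message_size) array of 0/1 floats.
    chosen_message:  (batch,) integer index into the flattened message space.
    Returns (next_message, new_history); the input history is not modified.
    """
    batch = chosen_message.shape[0]
    # One-hot encode the chosen message for each environment in the batch.
    next_message = np.zeros_like(message_history[:, 0])
    next_message[np.arange(batch), chosen_message] = 1.0
    # Write it into the slot for the current round of the interaction.
    new_history = message_history.copy()
    new_history[:, round_idx] = np.maximum(new_history[:, round_idx], next_message)
    return next_message, new_history
```

The real method reads and writes these tensors under the given keys of the 'next' tensordict rather than returning plain arrays.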
 
 - _get_action_spec() Composite[source]#
- Get the specification of the agent actions.

  Each action space has shape (batch_size, num_agents). Each agent chooses both a node and a decision: reject, accept or continue (represented as 0, 1 or 2). The node is a number between 0 and 2 * max_num_nodes - 1. If it is less than max_num_nodes, it is a node in the first graph; otherwise it is a node in the second graph.

  Returns:
- action_spec (Composite) – The action specification. 
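Concretely, the combined node index can be decoded as below. The helper and its name are illustrative, not part of the class; only the encodings (decision 0/1/2, node index relative to max_num_nodes) come from the spec described above.

```python
def decode_node_action(node_action, decision, max_num_nodes):
    """Split the combined node index into (graph, node) and name the decision.

    Encodings follow the action spec above: decisions 0/1/2 are
    reject/accept/continue, and node indices below max_num_nodes refer to
    the first graph, the rest to the second.
    """
    decisions = {0: "reject", 1: "accept", 2: "continue"}
    if node_action < max_num_nodes:
        return 0, node_action, decisions[decision]
    return 1, node_action - max_num_nodes, decisions[decision]
```

For example, with max_num_nodes = 5, a node action of 7 refers to node 2 of the second graph.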
 
 - _get_done_spec() TensorSpec[source]#
- Get the specification of the agent done signals.

  We have both shared and agent-specific done signals. This is for convenience: the shared done signal indicates that all relevant agents are done and so the environment should be reset.

  Returns:
- done_spec (TensorSpec) – The done specification. 
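The relationship between the two kinds of done signal can be sketched as follows. This is a NumPy illustration with an assumed per-agent relevance mask, not the class's implementation.

```python
import numpy as np

def compute_shared_done(agent_done, relevant):
    """Shared done is true once every *relevant* agent is done.

    agent_done: (batch, num_agents) booleans, one flag per agent.
    relevant:   (num_agents,) booleans marking the agents that must finish.
    """
    # Treat irrelevant agents as already done, then require all to be done.
    return np.logical_or(agent_done, ~relevant).all(axis=-1)
```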
 
 - _get_observation_spec() Composite[source]#
- Get the specification of the agent observations.

  Agents see the adjacency matrix and the messages sent so far. The “message” field contains the most recent message.

  Returns:
- observation_spec (Composite) – The observation specification. 
 
 - _get_reward_spec() TensorSpec[source]#
- Get the specification of the agent rewards.

  Returns:
- reward_spec (TensorSpec) – The reward specification. 
 
 - _get_state_spec() TensorSpec[source]#
- Get the specification of the state space.

  Defaults to the true label.

  Returns:
- state_spec (TensorSpec) – The state specification. 
 
 - _masked_reset(env_td: TensorDictBase, mask: Tensor, data_batch: TensorDict) TensorDictBase[source]#
- Reset the environment for a subset of the episodes.

  Takes a new sample from the dataset and inserts it into the given episodes. Also resets the other elements of the episodes.

  Parameters:
- env_td (TensorDictBase) – The current observation, state and done signal. 
- mask (torch.Tensor) – A boolean mask of the episodes to reset. 
- data_batch (TensorDict) – The data batch to insert into the episodes. 
 
- Returns:
- env_td (TensorDictBase) – The reset environment tensordict. 
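A minimal sketch of the masked-reset pattern, using a plain dict of NumPy arrays in place of a TensorDict. The field names ("adjacency", "message_history", "round", "done") are assumptions for illustration.

```python
import numpy as np

def masked_reset(env_state, mask, data_batch):
    """Reset only the episodes selected by `mask`, leaving the rest running.

    env_state:  dict of (batch, ...) arrays for the whole batch of episodes.
    mask:       (batch,) boolean array of episodes to reset.
    data_batch: dict with fresh samples, one per masked episode.
    """
    out = {key: value.copy() for key, value in env_state.items()}
    # Insert the freshly sampled problem instances.
    out["adjacency"][mask] = data_batch["adjacency"]
    # Clear the per-episode interaction state.
    out["message_history"][mask] = 0.0
    out["round"][mask] = 0
    out["done"][mask] = False
    return out
```

The key point is that unmasked episodes keep their state, so the batch can contain episodes at different stages of the interaction.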
 
 - _reset(env_td: TensorDictBase | None = None) TensorDictBase[source]#
- Reset the environment (partially).

  For each episode which is done, takes a new sample from the dataset and resets the episode.

  Parameters:
- env_td (Optional[TensorDictBase]) – The current observation, state and done signal. 
- Returns:
- env_td (TensorDictBase) – The reset environment tensordict. 
 
 - _step(env_td: TensorDictBase) TensorDictBase[source]#
- Perform a step in the environment.

  Parameters:
- env_td (TensorDictBase) – The current observation and state. 
- Returns:
- next_td (TensorDictBase) – The next observation, state, reward, and done signal.
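One piece of a step is resolving the decision component of the action. The sketch below uses the decision encoding from the action spec above (0 = reject, 1 = accept, 2 = continue); the reward scheme (+1 for a correct guess, -1 for an incorrect one, 0 while continuing) is an illustrative assumption, not the library's actual reward function.

```python
import numpy as np

def verifier_decision_outcome(decision, true_label):
    """Hypothetical sketch of how a step might resolve the verifier's decision.

    decision:   (batch,) ints in {0, 1, 2} = reject / accept / continue.
    true_label: (batch,) ints, 1 when the claim being checked is true, else 0.
    The +1/-1/0 reward scheme here is an assumption for illustration.
    """
    # Accept or reject terminates the episode; continue keeps it going.
    done = decision != 2
    # A terminal decision matching the true label counts as a correct guess.
    guessed_correctly = decision == true_label
    reward = np.where(done, np.where(guessed_correctly, 1.0, -1.0), 0.0)
    return done, reward
```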