nip.image_classification.environment.ImageClassificationEnvironment#

class nip.image_classification.environment.ImageClassificationEnvironment(*args, **kwargs)[source]#

The image classification RL environment.

Parameters:
  • hyper_params (HyperParameters) – The parameters of the experiment.

  • settings (ExperimentSettings) – The settings of the experiment.

  • dataset (TensorDictDataset) – The dataset for the environment.

  • protocol_handler (ProtocolHandler) – The protocol handler for the environment.

  • train (bool, optional) – Whether the environment is used for training or evaluation.
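Since the environment follows the standard TorchRL EnvBase interface, a constructed instance can be driven with the usual reset and rollout calls. The following is a minimal usage sketch; it assumes env is an already-built ImageClassificationEnvironment (constructing the hyper-parameters, settings, dataset and protocol handler is project-specific and elided here):

    # `env` is assumed to be an already-constructed ImageClassificationEnvironment.
    # Reset all batched episodes and inspect the initial tensordict.
    td = env.reset()
    print(td.batch_size)  # torch.Size([num_envs])

    # Roll out a few steps with random actions, using the generic
    # TorchRL EnvBase.rollout helper.
    rollout = env.rollout(max_steps=4, break_when_any_done=False)
    print(rollout["next", "done"].shape)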

Methods Summary

__init__(hyper_params, settings, dataset, ...)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_compute_message_history_and_next_message(...)

Compute the new message history and next message for given keys.

_get_action_spec()

Get the specification of the agent actions.

_get_done_spec()

Get the specification of the agent done signals.

_get_observation_spec()

Get the specification of the agent observations.

_get_reward_spec()

Get the specification of the agent rewards.

_get_state_spec()

Get the specification of the state space.

_masked_reset(env_td, mask, data_batch)

Reset the environment for a subset of the episodes.

_reset([env_td])

Reset the environment (partially).

_set_seed(seed)

_step(env_td)

Perform a step in the environment.

Attributes

T_destination

_filtered_reset_keys

Returns only the effective reset keys, discarding nested resets if they're not being used.

_simple_done

_step_mdp

action_key

The action key of an environment.

action_keys

The action keys of an environment.

action_spec

The action spec.

batch_locked

Whether the environment can be used with a batch size different from the one it was initialized with.

batch_size

Number of envs batched in this environment instance organised in a torch.Size() object.

call_super_init

dataset_num_channels

The number of image channels in the dataset.

device

done_key

The done key of an environment.

done_keys

The done keys of an environment.

done_keys_groups

A list of done keys, grouped as the reset keys.

done_spec

The done spec.

dump_patches

frames_per_batch

The number of frames to sample per training iteration.

full_action_spec

The full action spec.

full_done_spec

The full done spec.

full_observation_spec

full_reward_spec

The full reward spec.

full_state_spec

The full state spec.

image_height

The height of the images in the dataset.

image_width

The width of the images in the dataset.

initial_num_channels

The initial number of image channels in the network.

input_spec

Input spec.

latent_height

The height of the latent space.

latent_num_channels

The number of channels in the latent space.

latent_width

The width of the latent space.

main_message_out_key

main_message_space_shape

The shape of the main message space.

message_history_shape

The shape of the message history and 'x' tensors.

ndim

num_block_groups

The number of block groups in the network.

num_envs

The number of batched environments.

observation_spec

Observation spec.

output_spec

Output spec.

reset_keys

Returns a list of reset keys.

reward_key

The reward key of an environment.

reward_keys

The reward keys of an environment.

reward_spec

The reward spec.

run_type_checks

shape

Equivalent to batch_size.

specs

Returns a Composite container where all the environment specs are present.

state_spec

State spec.

steps_per_env_per_iteration

The number of steps per batched environment in each iteration.

dataset

training

Methods

__init__(hyper_params: HyperParameters, settings: ExperimentSettings, dataset: TensorDictDataset, protocol_handler: ProtocolHandler, *, train: bool = True)[source]#

Initialize internal Module state, shared by both nn.Module and ScriptModule.

_compute_message_history_and_next_message(env_td: TensorDictBase, next_td: TensorDictBase, *, message_out_key: str, message_in_key: str, message_history_key: str, message_shape: tuple[int, ...]) → TensorDictBase[source]#

Compute the new message history and next message for given keys.

This is a generic method for updating the one-hot-encoded next-message and message-history tensors, given each agent's choice of message.

Used in the _step method of the environment.

Parameters:
  • env_td (TensorDictBase) – The current observation and state.

  • next_td (TensorDictBase) – The ‘next’ tensordict, to be updated with the message history and next message.

  • message_out_key (str) – The key in the ‘agents’ sub-tensordict which contains the message selected by each agent. This results from the output of the agent’s forward pass.

  • message_in_key (str) – The key which contains the next message to be sent, which is used as input to each agent.

  • message_history_key (str) – The key which contains the message history tensor.

  • message_shape (tuple[int, ...]) – The shape of the message space.

Returns:

next_td (TensorDictBase) – The updated ‘next’ tensordict.
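The update described here can be illustrated with a standalone snippet. This is a sketch of the general pattern only, not the class's actual implementation, and all tensor names and shapes below are hypothetical:

    import torch
    import torch.nn.functional as F
    from tensordict import TensorDict

    batch, num_agents, num_rounds, num_positions = 4, 2, 8, 16
    round_id = torch.zeros(batch, dtype=torch.long)  # current round per episode

    # Hypothetical per-agent message choices, as output by the agents' forward pass.
    chosen = torch.randint(num_positions, (batch, num_agents))

    # One-hot encode the choices and merge them into a single next message.
    next_message = F.one_hot(chosen, num_positions).sum(dim=1).clamp(max=1).float()

    # Write the new message into the history at the current round.
    message_history = torch.zeros(batch, num_rounds, num_positions)
    message_history[torch.arange(batch), round_id] = next_message

    next_td = TensorDict(
        {"message": next_message, "message_history": message_history},
        batch_size=[batch],
    )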

_get_action_spec() → CompositeSpec[source]#

Get the specification of the agent actions.

Each component of the action space has shape (batch_size, num_agents). Each agent chooses both a latent pixel and a decision: reject, accept or continue (represented as 0, 1 or 2 respectively).

Returns:

action_spec (CompositeSpec) – The action specification.
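For illustration, a composite action spec with this shape and structure might be built with TorchRL spec classes roughly as follows. The key names and nesting are assumptions for the sketch, not necessarily those used by the class:

    import torch
    from torchrl.data import CompositeSpec, DiscreteTensorSpec

    batch_size, num_agents, num_latent_pixels = 8, 2, 49

    action_spec = CompositeSpec(
        {
            "agents": CompositeSpec(
                {
                    # Which latent pixel each agent points at.
                    "latent_pixel_selected": DiscreteTensorSpec(
                        num_latent_pixels,
                        shape=(batch_size, num_agents),
                        dtype=torch.long,
                    ),
                    # 0 = reject, 1 = accept, 2 = continue.
                    "decision": DiscreteTensorSpec(
                        3, shape=(batch_size, num_agents), dtype=torch.long
                    ),
                },
                shape=(batch_size, num_agents),
            )
        },
        shape=(batch_size,),
    )
    print(action_spec.rand())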

_get_done_spec() → TensorSpec[source]#

Get the specification of the agent done signals.

We have both shared and agent-specific done signals. The shared signal is provided for convenience: it indicates that all relevant agents are done and that the environment should therefore be reset.

Returns:

done_spec (TensorSpec) – The done specification.
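As a rough illustration of a shared-plus-per-agent done layout (key names and shapes are assumed for the sketch):

    import torch
    from torchrl.data import CompositeSpec, DiscreteTensorSpec

    batch_size, num_agents = 8, 2

    done_spec = CompositeSpec(
        {
            # Shared signal: the episode as a whole is finished.
            "done": DiscreteTensorSpec(2, shape=(batch_size, 1), dtype=torch.bool),
            # Per-agent signals.
            "agents": CompositeSpec(
                {
                    "done": DiscreteTensorSpec(
                        2, shape=(batch_size, num_agents), dtype=torch.bool
                    )
                },
                shape=(batch_size, num_agents),
            ),
        },
        shape=(batch_size,),
    )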

_get_observation_spec() → CompositeSpec[source]#

Get the specification of the agent observations.

Agents see the image and the messages sent so far. The “message” field contains the most recent message.

Returns:

observation_spec (CompositeSpec) – The observation specification.
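For illustration, an observation layout of this kind could be specified roughly as follows (shapes and key names are assumptions for the sketch):

    from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

    batch_size, channels, height, width = 8, 3, 32, 32
    num_rounds, num_positions = 8, 49

    observation_spec = CompositeSpec(
        {
            # The image to be classified.
            "image": UnboundedContinuousTensorSpec(
                shape=(batch_size, channels, height, width)
            ),
            # All messages sent so far, one slot per round.
            "message_history": UnboundedContinuousTensorSpec(
                shape=(batch_size, num_rounds, num_positions)
            ),
            # The most recent message only.
            "message": UnboundedContinuousTensorSpec(
                shape=(batch_size, num_positions)
            ),
        },
        shape=(batch_size,),
    )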

_get_reward_spec() → TensorSpec[source]#

Get the specification of the agent rewards.

Returns:

reward_spec (TensorSpec) – The reward specification.

_get_state_spec() → TensorSpec[source]#

Get the specification of the state space.

Defaults to the true label.

Returns:

state_spec (TensorSpec) – The state specification.

_masked_reset(env_td: TensorDictBase, mask: Tensor, data_batch: TensorDict) → TensorDictBase[source]#

Reset the environment for a subset of the episodes.

Takes a new sample from the dataset and inserts it into the given episodes. Also resets the other elements of the episodes.

Parameters:
  • env_td (TensorDictBase) – The current observation, state and done signal.

  • mask (torch.Tensor) – A boolean mask of the episodes to reset.

  • data_batch (TensorDict) – The data batch to insert into the episodes.

Returns:

env_td (TensorDictBase) – The reset environment tensordict.
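The masked-insertion pattern described here can be sketched with plain tensordict indexing (the keys and shapes below are hypothetical):

    import torch
    from tensordict import TensorDict

    batch_size = 8
    env_td = TensorDict(
        {
            "image": torch.zeros(batch_size, 3, 32, 32),
            "y": torch.zeros(batch_size, dtype=torch.long),
            "round": torch.full((batch_size,), 5, dtype=torch.long),
        },
        batch_size=[batch_size],
    )

    # Episodes flagged for reset.
    mask = torch.tensor([True, False, True, False, False, True, False, False])
    num_reset = int(mask.sum())

    # Fresh samples for exactly the masked episodes.
    data_batch = TensorDict(
        {
            "image": torch.randn(num_reset, 3, 32, 32),
            "y": torch.randint(10, (num_reset,)),
        },
        batch_size=[num_reset],
    )

    # Insert the new data and reset the remaining per-episode state.
    env_td["image"][mask] = data_batch["image"]
    env_td["y"][mask] = data_batch["y"]
    env_td["round"][mask] = 0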

_reset(env_td: TensorDictBase | None = None) → TensorDictBase[source]#

Reset the environment (partially).

For each episode that is done, takes a new sample from the dataset and resets the episode.

Parameters:

env_td (Optional[TensorDictBase]) – The current observation, state and done signal.

Returns:

env_td (TensorDictBase) – The reset environment tensordict.

_set_seed(seed: int | None)[source]#
_step(env_td: TensorDictBase) → TensorDictBase[source]#

Perform a step in the environment.

Parameters:

env_td (TensorDictBase) – The current observation and state.

Returns:

next_td (TensorDictBase) – The next observation, state, reward, and done signal.
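A manual step loop over the batched environments can be written with the generic TorchRL EnvBase calls. This is a sketch; env is assumed to be an already-constructed ImageClassificationEnvironment:

    from torchrl.envs.utils import step_mdp

    td = env.reset()
    for _ in range(4):
        td = env.rand_action(td)  # sample a random action for every agent
        td = env.step(td)         # calls _step; fills td["next"] with obs, reward and done
        td = step_mdp(td)         # promote td["next"] to the root for the next step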