nip.image_classification.environment.ImageClassificationEnvironment
- class nip.image_classification.environment.ImageClassificationEnvironment(*args, **kwargs)
The image classification RL environment.
- Parameters:
hyper_params (HyperParameters) – The parameters of the experiment.
settings (ExperimentSettings) – The settings of the experiment.
dataset (TensorDictDataset) – The dataset for the environment.
protocol_handler (ProtocolHandler) – The protocol handler for the environment.
train (bool, optional) – Whether the environment is used for training or evaluation.
Methods Summary
__init__(hyper_params, settings, dataset, ...) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
_compute_message_history_and_next_message(env_td, ...) – Compute the new message history and next message for given keys.
_get_action_spec() – Get the specification of the agent actions.
_get_done_spec() – Get the specification of the agent done signals.
_get_observation_spec() – Get the specification of the agent observations.
_get_reward_spec() – Get the specification of the agent rewards.
_get_state_spec() – Get the specification of the state space.
_masked_reset(env_td, mask, data_batch) – Reset the environment for a subset of the episodes.
_reset([env_td]) – Reset the environment (partially).
_set_seed(seed)
_step(env_td) – Perform a step in the environment.
Attributes
T_destination
_filtered_reset_keys – Returns only the effective reset keys, discarding nested resets if they're not being used.
_has_dynamic_specs
_simple_done
_step_mdp
action_key – The action key of an environment.
action_keys – The action keys of an environment.
action_spec – The action spec.
batch_locked – Whether the environment can be used with a batch size different from the one it was initialized with or not.
batch_size – Number of envs batched in this environment instance organised in a torch.Size() object.
call_super_init
dataset_num_channels – The number of image channels in the dataset.
device
done_key – The done key of an environment.
done_keys – The done keys of an environment.
done_keys_groups – A list of done keys, grouped as the reset keys.
done_spec – The done spec.
dump_patches
frames_per_batch – The number of frames to sample per training iteration.
full_action_spec – The full action spec.
full_done_spec – The full done spec.
full_observation_spec
full_reward_spec – The full reward spec.
full_state_spec – The full state spec.
image_height – The height of the images in the dataset.
image_width – The width of the images in the dataset.
initial_num_channels – The initial number of image channels in the network.
input_spec – Input spec.
latent_height – The height of the latent space.
latent_num_channels – The number of channels in the latent space.
latent_width – The width of the latent space.
main_message_out_key
main_message_space_shape – The shape of the main message space.
message_history_shape – The shape of the message history and 'x' tensors.
ndim
num_block_groups – The number of block groups in the network.
num_envs – The number of batched environments.
observation_spec – Observation spec.
output_spec – Output spec.
reset_keys – Returns a list of reset keys.
reward_key – The reward key of an environment.
reward_keys – The reward keys of an environment.
reward_spec – The reward spec.
run_type_checks
shape – Equivalent to batch_size.
specs – Returns a Composite container where all the environment specs are present.
split – The split of the dataset used for the environment.
state_keys – The state keys of an environment.
state_spec – State spec.
steps_per_env_per_iteration – The number of steps per batched environment in each iteration.
dataset
training
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, dataset: TensorDictDataset, protocol_handler: ProtocolHandler)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- _compute_message_history_and_next_message(env_td: TensorDictBase, next_td: TensorDictBase, *, message_out_key: str, message_in_key: str, message_history_key: str, message_shape: tuple[int, ...]) → TensorDictBase
Compute the new message history and next message for given keys.
This is a generic method for updating one-hot encoded next message and message history tensors given a choice of message for each agent.
Used in the _step method of the environment.
- Parameters:
env_td (TensorDictBase) – The current observation and state.
next_td (TensorDictBase) – The ‘next’ tensordict, to be updated with the message history and next message.
message_out_key (str) – The key in the ‘agents’ sub-tensordict which contains the message selected by each agent. This results from the output of the agent’s forward pass.
message_in_key (str) – The key which contains the next message to be sent, which is used as input to each agent.
message_history_key (str) – The key which contains the message history tensor.
message_shape (tuple[int, ...]) – The shape of the message space.
- Returns:
next_td (TensorDictBase) – The updated ‘next’ tensordict.
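To make the one-hot bookkeeping concrete, here is a framework-free sketch using plain Python lists instead of tensordicts. It assumes one flattened history row per round and a single chosen message per round (the real method handles a message per agent); the helpers `ravel_index` and `update_message_history` are hypothetical and not part of nip.

```python
import math


def ravel_index(multi_index: tuple[int, ...], shape: tuple[int, ...]) -> int:
    """Convert a multi-dimensional index into a flat, row-major index."""
    flat = 0
    for i, dim in zip(multi_index, shape):
        flat = flat * dim + i
    return flat


def update_message_history(
    history: list[list[int]],
    round_id: int,
    message: tuple[int, ...],
    message_shape: tuple[int, ...],
) -> tuple[list[list[int]], list[int]]:
    """Record one chosen message in a one-hot message history (hypothetical helper).

    `history` holds one flattened one-hot row per round. Returns a new history
    (the input is left unchanged) and the one-hot next message for the agents.
    """
    size = math.prod(message_shape)
    next_message = [0] * size
    next_message[ravel_index(message, message_shape)] = 1
    new_history = [row[:] for row in history]  # copy so the input is not mutated
    new_history[round_id] = next_message[:]
    return new_history, next_message


# Usage: 3 rounds, message space of shape (2, 3), message chosen at (1, 2).
history = [[0] * 6 for _ in range(3)]
new_history, next_message = update_message_history(history, 0, (1, 2), (2, 3))
print(next_message)  # [0, 0, 0, 0, 0, 1]
```

Returning a fresh history mirrors the documented contract of writing results into the separate 'next' tensordict rather than editing the current one in place.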
- _get_action_spec() → Composite
Get the specification of the agent actions.
Each action space has shape (batch_size, num_agents). Each agent chooses both a latent pixel and a decision: reject, accept or continue (represented as 0, 1 or 2).
- Returns:
action_spec (Composite) – The action specification.
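The decision encoding above can be sketched as follows. The codes 0, 1, 2 for reject, accept and continue come from the description; representing the latent pixel as a flat, row-major index into a latent_height × latent_width grid is an assumption, and `Decision` and `decode_action` are hypothetical names.

```python
from enum import IntEnum


class Decision(IntEnum):
    """Decision codes as documented: reject=0, accept=1, continue=2."""

    REJECT = 0
    ACCEPT = 1
    CONTINUE = 2


def decode_action(
    decision: int, latent_pixel: int, latent_width: int
) -> tuple[Decision, tuple[int, int]]:
    """Decode one agent's action (hypothetical helper).

    Assumes the latent pixel is a flat, row-major index; an out-of-range
    decision code raises ValueError when converted to Decision.
    """
    row, col = divmod(latent_pixel, latent_width)
    return Decision(decision), (row, col)


decision, pixel = decode_action(2, 7, latent_width=4)
print(decision.name, pixel)  # CONTINUE (1, 3)
```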
- _get_done_spec() → TensorSpec
Get the specification of the agent done signals.
We have both shared and agent-specific done signals. The shared signal is a convenience: it indicates that all relevant agents are done, so the environment should be reset.
- Returns:
done_spec (TensorSpec) – The done specification.
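The relationship between the two kinds of done signal can be sketched in plain Python. This assumes agent-specific done flags laid out as (batch_size, num_agents) and an explicit list of relevant agent indices; `shared_done` is a hypothetical helper, not the nip implementation.

```python
def shared_done(agent_done: list[list[bool]], relevant_agents: list[int]) -> list[bool]:
    """Per-episode shared done flag: True once every relevant agent is done.

    agent_done[i][a] is the done flag of agent a in episode i (hypothetical layout).
    """
    return [all(row[a] for a in relevant_agents) for row in agent_done]


agent_done = [[True, True], [True, False], [False, False]]
print(shared_done(agent_done, [0, 1]))  # [True, False, False]
print(shared_done(agent_done, [0]))     # [True, True, False]
```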
- _get_observation_spec() → Composite
Get the specification of the agent observations.
Agents see the image and the messages sent so far. The “message” field contains the most recent message.
- Returns:
observation_spec (Composite) – The observation specification.
- _get_reward_spec() → TensorSpec
Get the specification of the agent rewards.
- Returns:
reward_spec (TensorSpec) – The reward specification.
- _get_state_spec() → TensorSpec
Get the specification of the state space.
By default, the state is the true label.
- Returns:
state_spec (TensorSpec) – The state specification.
- _masked_reset(env_td: TensorDictBase, mask: Tensor, data_batch: TensorDict) → TensorDictBase
Reset the environment for a subset of the episodes.
Takes a new sample from the dataset and inserts it into the given episodes. Also resets the other elements of the episodes.
- Parameters:
env_td (TensorDictBase) – The current observation, state and done signal.
mask (torch.Tensor) – A boolean mask of the episodes to reset.
data_batch (TensorDict) – The data batch to insert into the episodes.
- Returns:
env_td (TensorDictBase) – The reset environment tensordict.
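A framework-free sketch of a masked reset, with the tensordict replaced by a dict of per-episode lists. The keys "image", "label", "round" and "done" are hypothetical placeholders, not the actual nip keys, and `masked_reset` is an illustrative helper. With tensors, the loop would typically become a boolean-mask assignment.

```python
from copy import deepcopy


def masked_reset(
    env_td: dict[str, list], mask: list[bool], data_batch: dict[str, list]
) -> dict[str, list]:
    """Return a copy of env_td in which the episodes selected by mask are reset.

    data_batch holds one new sample per True entry in mask, in order.
    Keys are hypothetical placeholders.
    """
    if sum(mask) != len(data_batch["image"]):
        raise ValueError("data_batch must contain one sample per masked episode")
    new_td = deepcopy(env_td)
    sample_idx = 0
    for episode, reset in enumerate(mask):
        if not reset:
            continue
        # Insert the new sample into this episode.
        new_td["image"][episode] = data_batch["image"][sample_idx]
        new_td["label"][episode] = data_batch["label"][sample_idx]
        # Reset the remaining per-episode elements.
        new_td["round"][episode] = 0
        new_td["done"][episode] = False
        sample_idx += 1
    return new_td


env = {"image": ["a", "b", "c"], "label": [0, 1, 0], "round": [3, 1, 4], "done": [True, False, True]}
out = masked_reset(env, [True, False, True], {"image": ["d", "e"], "label": [1, 1]})
print(out["image"], out["round"])  # ['d', 'b', 'e'] [0, 1, 0]
```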
- _reset(env_td: TensorDictBase | None = None) → TensorDictBase
Reset the environment (partially).
For each episode which is done, takes a new sample from the dataset and resets the episode.
- Parameters:
env_td (Optional[TensorDictBase]) – The current observation, state and done signal.
- Returns:
env_td (TensorDictBase) – The reset environment tensordict.
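The partial-reset behaviour can be sketched with plain Python data. This sketch assumes that when no env_td is passed (the first reset), every episode is treated as done; the dict keys and the `partial_reset` helper are hypothetical, and the dataset is stood in for by an iterator of (image, label) pairs.

```python
from typing import Iterator, Optional


def partial_reset(
    env_td: Optional[dict], batch_size: int, samples: Iterator[tuple[str, int]]
) -> dict:
    """Give every done episode a fresh (image, label) sample (hypothetical helper).

    If env_td is None (first reset), every episode is treated as done.
    """
    if env_td is None:
        env_td = {
            "image": [None] * batch_size,
            "label": [None] * batch_size,
            "round": [0] * batch_size,
            "done": [True] * batch_size,
        }
    new_td = {key: list(values) for key, values in env_td.items()}  # shallow copies
    for episode, done in enumerate(env_td["done"]):
        if done:
            image, label = next(samples)
            new_td["image"][episode] = image
            new_td["label"][episode] = label
            new_td["round"][episode] = 0
            new_td["done"][episode] = False
    return new_td


samples = iter([("a", 0), ("b", 1), ("c", 0), ("d", 1)])
td = partial_reset(None, 3, samples)  # first reset: all three episodes sampled
td["done"][1] = True                  # pretend episode 1 just finished
td2 = partial_reset(td, 3, samples)
print(td2["image"])  # ['a', 'd', 'c']
```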
- _step(env_td: TensorDictBase) → TensorDictBase
Perform a step in the environment.
- Parameters:
env_td (TensorDictBase) – The current observation and state.
- Returns:
next_td (TensorDictBase) – The next observation, state, reward, and done signal.
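One part of a step, advancing the round counter and computing done flags, can be sketched in plain Python. The rule here is an assumption rather than the documented nip logic: an episode ends when the decision is reject (0) or accept (1), or when a hypothetical max_rounds limit is reached. Reward computation is omitted because it is not described on this page.

```python
REJECT, ACCEPT, CONTINUE = 0, 1, 2  # decision codes as documented for the action spec


def step_done(
    decisions: list[int], rounds: list[int], max_rounds: int
) -> tuple[list[int], list[bool]]:
    """Advance each episode's round counter and compute its done flag.

    Hypothetical rule: done if the decision is not CONTINUE, or if the
    incremented round counter reaches max_rounds.
    """
    next_rounds = [r + 1 for r in rounds]
    done = [d != CONTINUE or r >= max_rounds for d, r in zip(decisions, next_rounds)]
    return next_rounds, done


next_rounds, done = step_done([CONTINUE, ACCEPT, CONTINUE], [0, 0, 4], max_rounds=5)
print(next_rounds, done)  # [1, 1, 5] [False, True, True]
```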