nip.scenario_base.environment.PureTextEnvironment#

class nip.scenario_base.environment.PureTextEnvironment(hyper_params: HyperParameters, settings: ExperimentSettings, dataset: Dataset, protocol_handler: ProtocolHandler, *, train: bool = True)[source]#

Base for environments which handle non-tokenised text with nested array dicts.

Methods Summary

__init__(hyper_params, settings, dataset, ...)

_masked_reset(env_state, mask, data_batch)

Reset the environment for a subset of the episodes.

add_dummy_actions_and_next_to_state(state_env)

Complete a done state with dummy actions and dummy next state.

get_datapoint_from_env_state_as_dict(env_state)

Get the datapoint from a single-element environment state as a dictionary.

get_next_state_from_state(state_env)

Get the next state environment from the current state environment.

prompt_array_to_list(prompt_array)

Convert a prompt in the form of a numpy array to a list of dictionaries.

prompt_list_to_array(prompt_list)

Convert a prompt in the form of a list of dictionaries to a numpy array.

reset([env_state, data_batch])

Reset the pure text environment.

step(env_state)

Take a step in the environment.

zero()

Return a zeroed environment state.

Attributes

action_spec

The specification for the action keys.

batch_size

The batch size of the environment.

done_spec

The specification for the done keys (done and terminated).

frames_per_batch

The number of frames to sample per training iteration.

max_prompt_messages

The maximum number of messages which can be sent in a prompt to an agent.

num_envs

The number of batched environments.

observation_spec

The specification for the observation keys.

reward_spec

The specification for the agent reward keys.

state_spec

The specification for the state keys.

steps_per_env_per_iteration

The number of steps per batched environment in each iteration.

dataset

Methods

__init__(hyper_params: HyperParameters, settings: ExperimentSettings, dataset: Dataset, protocol_handler: ProtocolHandler, *, train: bool = True)[source]#
_masked_reset(env_state: NestedArrayDict, mask: ndarray[Any, dtype[bool]], data_batch: NestedArrayDict) NestedArrayDict[source]#

Reset the environment for a subset of the episodes.

Takes a new sample from the dataset and inserts it into the given episodes. Also resets the other elements of the episodes.

Parameters:
  • env_state (NestedArrayDict) – The current observation, state and done signal.

  • mask (ndarray[Any, dtype[bool]]) – A boolean mask of the episodes to reset.

  • data_batch (NestedArrayDict) – The data batch to insert into the episodes.

Returns:

env_state (NestedArrayDict) – The reset environment state.
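
For illustration, a minimal sketch of how the masked reset might be driven by the done flags (assuming env is an instance of a concrete PureTextEnvironment subclass and that the state exposes a per-episode boolean "done" entry, as suggested by done_spec):

    import numpy as np

    done_mask = np.asarray(env_state["done"], dtype=bool).reshape(-1)
    if done_mask.any():
        # `fresh_batch` is a hypothetical placeholder for a newly sampled
        # NestedArrayDict of datapoints for the masked episodes; in practice
        # the public `reset` method handles this sampling.
        fresh_batch = ...
        env_state = env._masked_reset(env_state, done_mask, fresh_batch)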

add_dummy_actions_and_next_to_state(state_env: NestedArrayDict) NestedArrayDict[source]#

Complete a done state with dummy actions and dummy next state.

This method adds dummy actions and copies the current state to the next state.

It is used to complete the state when the environment is done.

Parameters:

state_env (NestedArrayDict) – The current state environment. This is modified in place.

Returns:

state_env (NestedArrayDict) – The modified state environment.
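
A minimal sketch of when this might be called (assuming env and env_state exist as in the other examples):

    # Once every episode in the batch is done, pad the final state in place so
    # that it has the same structure as an ordinary transition.
    if env_state["done"].all():
        env_state = env.add_dummy_actions_and_next_to_state(env_state)
        # env_state now also carries dummy action keys and a "next" sub-dict
        # copied from the current state.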

abstract get_datapoint_from_env_state_as_dict(env_state: NestedArrayDict) dict[source]#

Get the datapoint from a single-element environment state as a dictionary.

This returns a dictionary which specifies the datapoint for the environment state.

This method should be extended by subclasses to include whatever additional fields constitute the datapoint.

Parameters:

env_state (NestedArrayDict) – The environment state.

Returns:

datapoint (dict) – The datapoint as a dictionary.
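
A hedged sketch of how a subclass might extend this abstract method (other abstract members are omitted, and the "question" and "answer" fields are purely illustrative):

    from nip.scenario_base.environment import PureTextEnvironment

    class MyTextEnvironment(PureTextEnvironment):
        def get_datapoint_from_env_state_as_dict(self, env_state):
            datapoint = super().get_datapoint_from_env_state_as_dict(env_state)
            datapoint["question"] = env_state["question"]  # hypothetical field
            datapoint["answer"] = env_state["answer"]      # hypothetical field
            return datapoint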

get_next_state_from_state(state_env: NestedArrayDict) NestedArrayDict[source]#

Get the next state environment from the current state environment.

The current state environment should contain the “next” sub-dictionary, which contains the next observation, state, reward, and done signal.

Parameters:

state_env (NestedArrayDict) – The current state environment, which should contain the “next” sub-dictionary.

Returns:

next_state_env (NestedArrayDict) – The next state environment.
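
A minimal sketch (assuming env and env_state exist as in the other examples):

    # After `step` has filled in the "next" sub-dictionary, promote it to be
    # the current state for the following iteration.
    env_state = env.step(env_state)
    env_state = env.get_next_state_from_state(env_state)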

prompt_array_to_list(prompt_array: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'message field']) list[PromptMessage][source]#

Convert a prompt in the form of a numpy array to a list of dictionaries.

Each row of the numpy array corresponds to a message in the prompt, and each column corresponds to a field of the message.

The prompt array has a fixed number of rows, but the prompt may be shorter. If any required field is None in a row, we take that to indicate the end of the prompt.

Parameters:

prompt_array (String[NDArray, "message field"]) – The numpy array to convert.

Returns:

prompt_list (list[PromptMessage]) – The prompt as a list of message dictionaries.
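
A hedged sketch of the conversion (assuming env exists as in the other examples; the two columns shown, "role" and "content", are only illustrative, as the real columns are the keys of PromptMessage):

    import numpy as np

    prompt_array = np.array(
        [
            ["system", "You are an agent in a text-based game."],
            ["user", "Send your next message."],
            [None, None],  # padding row: marks the end of the prompt
        ],
        dtype=object,
    )
    prompt_list = env.prompt_array_to_list(prompt_array)
    # prompt_list should contain two PromptMessage dictionaries, one per
    # non-padding row.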

prompt_list_to_array(prompt_list: list[PromptMessage]) Annotated[ndarray[Any, dtype[_ScalarType_co]], 'message field'][source]#

Convert a prompt in the form of a list of dictionaries to a numpy array.

Each element of the list is a dictionary with keys defined in PromptMessage. We convert this to a numpy array with columns corresponding to the keys in PromptMessage.

Parameters:

prompt_list (list[PromptMessage]) – The prompt, as a list of message dictionaries, to convert.

Returns:

prompt_array (String[NDArray, "message field"]) – The prompt as a numpy array.
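
A hedged sketch of the inverse conversion (same caveats as above; the keys shown are illustrative):

    prompt_list = [
        {"role": "system", "content": "You are an agent in a text-based game."},
        {"role": "user", "content": "Send your next message."},
    ]
    prompt_array = env.prompt_list_to_array(prompt_list)
    # prompt_array has one row per message and one column per PromptMessage
    # key, presumably padded to max_prompt_messages rows.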

reset(env_state: NestedArrayDict | None = None, data_batch: NestedArrayDict | None = None) NestedArrayDict[source]#

Reset the pure text environment.

This method resets the environment for the episodes which are done. It samples a new batch of data for these episodes and calls _masked_reset to reset the episodes.

Parameters:
  • env_state (Optional[NestedArrayDict]) – The current environment state.

  • data_batch (Optional[NestedArrayDict]) – The data batch to use for the episodes that are done.

Returns:

env_state (NestedArrayDict) – The reset environment state.
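
A minimal sketch (assuming env is an instance of a concrete PureTextEnvironment subclass):

    # Calling with no arguments starts all batched episodes from the dataset;
    # passing the current state later re-samples data only for the episodes
    # that are done.
    env_state = env.reset()
    # ... interact with the environment ...
    env_state = env.reset(env_state)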

step(env_state: NestedArrayDict) NestedArrayDict[source]#

Take a step in the environment.

Parameters:

env_state (NestedArrayDict) – The current observation, state and done signal.

Returns:

env_state (NestedArrayDict) – The input dict with a “next” sub-dict with the next observation, state, reward, and done signal.
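
A hedged sketch of a rollout loop (assuming env is an instance of a concrete PureTextEnvironment subclass, and that a policy fills in the action keys required by action_spec before each step, which is elided here):

    env_state = env.reset()
    for _ in range(env.steps_per_env_per_iteration):
        # ... write the agents' actions into env_state ...
        env_state = env.step(env_state)                        # adds the "next" sub-dict
        env_state = env.get_next_state_from_state(env_state)   # advance to the next state
        env_state = env.reset(env_state)                       # re-sample any done episodes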

zero() NestedArrayDict[source]#

Return a zeroed environment state.

Returns:

env_state (NestedArrayDict) – The zeroed environment state.