nip.code_validation.environment.CodeValidationEnvironment
- class nip.code_validation.environment.CodeValidationEnvironment(hyper_params: HyperParameters, settings: ExperimentSettings, dataset: Dataset, protocol_handler: ProtocolHandler, *, train: bool = True)
The RL environment for code validation.
Methods Summary
- __init__(hyper_params, settings, dataset, ...)
- _masked_reset(env_state, mask, data_batch) – Reset the environment for a subset of the episodes.
- add_dummy_actions_and_next_to_state(state_env) – Complete a done state with dummy actions and dummy next state.
- get_datapoint_from_env_state_as_dict(env_state) – Get the datapoint from a single-element environment state as a dictionary.
- get_next_state_from_state(state_env) – Get the next state environment from the current state environment.
- prompt_array_to_list(prompt_array) – Convert a prompt in the form of a numpy array to a list of dictionaries.
- prompt_list_to_array(prompt_list) – Convert a prompt in the form of a list of dictionaries to a numpy array.
- reset([env_state, data_batch]) – Reset the pure text environment.
- step(env_state) – Take a step in the environment.
- zero() – Return a zeroed environment state.
Attributes
- action_spec – The specification for the action keys.
- batch_size
- done_spec – The specification for the done keys (done and terminated).
- frames_per_batch – The number of frames to sample per training iteration.
- max_prompt_messages – The maximum number of messages which can be sent in a prompt to an agent.
- num_envs – The number of batched environments.
- observation_spec – The specification for the observation keys.
- reward_spec – The specification for the agent reward keys.
- state_spec – The specification for the state keys.
- steps_per_env_per_iteration – The number of steps per batched environment in each iteration.
- dataset
Methods
- __init__(hyper_params: HyperParameters, settings: ExperimentSettings, dataset: Dataset, protocol_handler: ProtocolHandler, *, train: bool = True)
- _masked_reset(env_state: NestedArrayDict, mask: ndarray[Any, dtype[bool]], data_batch: NestedArrayDict) -> NestedArrayDict
Reset the environment for a subset of the episodes.
Takes a new sample from the dataset and inserts it into the given episodes. Also resets the other elements of the episodes.
- Parameters:
env_state (NestedArrayDict) – The current observation, state and done signal.
mask (ndarray[Any, dtype[bool]]) – A boolean mask of the episodes to reset.
data_batch (NestedArrayDict) – The data batch to insert into the episodes.
- Returns:
env_state (NestedArrayDict) – The reset environment tensordict.
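The masked reset described above can be illustrated with a minimal sketch. It uses plain dictionaries of numpy arrays as a stand-in for NestedArrayDict, and the keys "question" and "round" are hypothetical. It also assumes that data_batch has the same batch shape as env_state, which the documentation does not state.

```python
import numpy as np


def masked_reset(env_state: dict, mask: np.ndarray, data_batch: dict) -> dict:
    """Overwrite the masked episodes with new data and clear their other fields."""
    # Work on copies so the caller's arrays are left untouched.
    new_state = {key: value.copy() for key, value in env_state.items()}
    # Insert the new datapoints into the episodes selected by the mask.
    new_state["question"][mask] = data_batch["question"][mask]
    # Reset the remaining per-episode fields for those episodes.
    new_state["round"][mask] = 0
    new_state["done"][mask] = False
    return new_state


env_state = {
    "question": np.array(["q0", "q1", "q2"], dtype=object),  # hypothetical key
    "round": np.array([3, 1, 4]),  # hypothetical key
    "done": np.array([True, False, True]),
}
data_batch = {"question": np.array(["n0", "n1", "n2"], dtype=object)}

# Reset only the episodes that are done (episodes 0 and 2 here).
reset_state = masked_reset(env_state, env_state["done"], data_batch)
```

Episode 1 is not selected by the mask, so its question and round counter are left as they were.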
- add_dummy_actions_and_next_to_state(state_env: NestedArrayDict) -> NestedArrayDict
Complete a done state with dummy actions and dummy next state.
This method adds dummy actions and copies the current state to the next state.
It is used to complete the state when the environment is done.
- Parameters:
state_env (NestedArrayDict) – The current state environment. This is modified in place.
- Returns:
state_env (NestedArrayDict) – The modified state environment.
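As a rough illustration of completing a done state, the sketch below modifies a plain dictionary in place and returns it. The action key "message" and the use of None as the dummy action are assumptions made for the example and may not match the real environment.

```python
import numpy as np


def add_dummy_actions_and_next(state_env: dict) -> dict:
    """Add placeholder actions and a "next" entry that mirrors the current state."""
    batch_size = state_env["done"].shape[0]
    # Snapshot the current entries first, so "next" mirrors the current state only.
    current = {key: value.copy() for key, value in state_env.items()}
    # Dummy action: no message is sent ("message" is a hypothetical action key).
    state_env["message"] = np.full(batch_size, None, dtype=object)
    state_env["next"] = current
    return state_env


state_env = {"round": np.array([2, 2]), "done": np.array([True, True])}
completed = add_dummy_actions_and_next(state_env)
```

Because the dictionary is modified in place, `completed` and `state_env` refer to the same object.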
- get_datapoint_from_env_state_as_dict(env_state: NestedArrayDict) -> dict
Get the datapoint from a single-element environment state as a dictionary.
This returns a dictionary which specifies the datapoint for the environment state.
- Parameters:
env_state (NestedArrayDict) – The environment state.
- Returns:
datapoint (dict) – The datapoint.
- get_next_state_from_state(state_env: NestedArrayDict) -> NestedArrayDict
Get the next state environment from the current state environment.
The current state environment should contain the “next” sub-dictionary, which contains the next observation, state, reward, and done signal.
- Parameters:
state_env (NestedArrayDict) – The current state environment, which should contain the “next” sub-dictionary.
- Returns:
next_state_env (NestedArrayDict) – The next state environment.
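The "next" convention can be sketched with plain dictionaries. This is one plausible reading, in which the next state is simply the contents of the "next" sub-dictionary; the real method may carry over additional keys. The keys "round" and "reward" are hypothetical.

```python
def get_next_state(state_env: dict) -> dict:
    """Promote the contents of the "next" sub-dictionary to the top level."""
    if "next" not in state_env:
        raise KeyError('state_env must contain a "next" sub-dictionary')
    # Shallow copy so the caller's dictionary is not modified.
    return dict(state_env["next"])


# A state after one step: the current values plus a "next" sub-dictionary.
state_env = {
    "round": 0,
    "done": False,
    "next": {"round": 1, "done": False, "reward": 0.5},
}
next_state_env = get_next_state(state_env)
```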
- prompt_array_to_list(prompt_array: Annotated[ndarray[Any, dtype[_ScalarType_co]], 'message field']) -> list[PromptMessage]
Convert a prompt in the form of a numpy array to a list of dictionaries.
Each row of the numpy array corresponds to a message in the prompt, and each column corresponds to a field of the message.
The prompt array has a fixed number of rows, but the prompt may be shorter. If any required field is None in a row, we take that to indicate the end of the prompt.
- Parameters:
prompt_array (String[NDArray, "message field"]) – The numpy array to convert.
- Returns:
prompt_list (list[PromptMessage]) – The prompt as a list of messages.
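The row-by-row conversion and the end-of-prompt rule can be sketched as follows. The field names in FIELDS are hypothetical stand-ins for the keys of PromptMessage, and all fields are treated as required for simplicity.

```python
import numpy as np

FIELDS = ("role", "content")  # hypothetical PromptMessage fields


def prompt_array_to_list(prompt_array: np.ndarray) -> list[dict]:
    """Convert a (message, field) object array to a list of message dictionaries."""
    prompt_list = []
    for row in prompt_array:
        # A None in any required field marks the end of the prompt.
        if any(value is None for value in row):
            break
        prompt_list.append(dict(zip(FIELDS, row)))
    return prompt_list


# A four-row array holding a two-message prompt; the unused rows are None.
prompt_array = np.array(
    [
        ["system", "Check the solution."],
        ["user", "def f(): ..."],
        [None, None],
        [None, None],
    ],
    dtype=object,
)
messages = prompt_array_to_list(prompt_array)
```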
- prompt_list_to_array(prompt_list: list[PromptMessage]) -> Annotated[ndarray[Any, dtype[_ScalarType_co]], 'message field']
Convert a prompt in the form of a list of dictionaries to a numpy array.
Each element of the list is a dictionary with keys defined in PromptMessage. We convert this to a numpy array with columns corresponding to the keys in PromptMessage.
- Parameters:
prompt_list (list[PromptMessage]) – The prompt to convert, as a list of messages.
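A minimal sketch of the list-to-array direction is given below. The field names and the row count MAX_MESSAGES (standing in for max_prompt_messages) are hypothetical, and padding the unused rows with None is an assumption based on the fixed-row layout described for prompt_array_to_list.

```python
import numpy as np

FIELDS = ("role", "content")  # hypothetical PromptMessage fields
MAX_MESSAGES = 4  # hypothetical stand-in for max_prompt_messages


def prompt_list_to_array(prompt_list: list[dict]) -> np.ndarray:
    """Convert a list of message dictionaries to a (message, field) object array."""
    if len(prompt_list) > MAX_MESSAGES:
        raise ValueError("prompt has more messages than the array has rows")
    # Start with every cell set to None, so unused rows mark the end of the prompt.
    prompt_array = np.full((MAX_MESSAGES, len(FIELDS)), None, dtype=object)
    for row, message in enumerate(prompt_list):
        for column, field in enumerate(FIELDS):
            prompt_array[row, column] = message[field]
    return prompt_array


prompt_list = [
    {"role": "system", "content": "Check the solution."},
    {"role": "user", "content": "def f(): ..."},
]
prompt_array = prompt_list_to_array(prompt_list)
```

An object dtype is used here so that None can sit alongside strings in the same array.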
- reset(env_state: NestedArrayDict | None = None, data_batch: NestedArrayDict | None = None) -> NestedArrayDict
Reset the pure text environment.
This method resets the environment for the episodes which are done. It samples a new batch of data for these episodes and calls _masked_reset to reset the episodes.
- Parameters:
env_state (Optional[NestedArrayDict]) – The current environment state.
data_batch (Optional[NestedArrayDict]) – The data batch to use for the episodes that are done.
- Returns:
env_state (NestedArrayDict) – The reset environment state.
- step(env_state: NestedArrayDict) -> NestedArrayDict
Take a step in the environment.
- Parameters:
env_state (NestedArrayDict) – The current observation, state and done signal.
- Returns:
env_state (NestedArrayDict) – The input dict with a “next” sub-dict with the next observation, state, reward, and done signal.
- zero() -> NestedArrayDict
Return a zeroed environment state.
- Returns:
env_state (NestedArrayDict) – The zeroed environment state.