nip.code_validation.agents.CodeValidationAgent
- class nip.code_validation.agents.CodeValidationAgent(hyper_params: dataclasses.InitVar[HyperParameters], agent_name: dataclasses.InitVar[str], whole: WholeAgent | None = None, body: AgentBody | None = None, policy_body: AgentBody | None = None, value_body: AgentBody | None = None, policy_head: AgentPolicyHead | None = None, value_head: AgentValueHead | None = None, solo_head: SoloAgentHead | None = None)
A class representing a code validation agent.
This is a dataclass which holds all the agent parts.
Methods Summary
- __eq__(other) – Return self==value.
- __init__(hyper_params, agent_name[, whole, ...])
- __post_init__(hyper_params, agent_name)
- __repr__() – Return repr(self).
- _append_filtered_params(model_param_dict, ...) – Filter the parameters and set their learning rate, and append them to a list.
- _body_param_regex(part)
- _non_body_param_regex(part)
- eval() – Set the agent to evaluation mode.
- filter_actor_named_parameters(named_parameters) – Filter the actor parameters from an iterable of named parameters.
- get_model_parameter_dicts(base_lr[, ...]) – Get the Torch parameters of the agent, and their learning rates.
- get_state() – Get a checkpoint of the agent's state.
- set_state(checkpoint) – Set the state of the agent from a checkpoint.
- train() – Set the agent to training mode.
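The summary describes eval() and train() as switching the mode of the whole agent. A minimal sketch of how that delegation could work, assuming each part that is set behaves like a torch.nn.Module with train()/eval() methods; Part and SketchAgent are hypothetical stand-ins, not the nip classes.

```python
from dataclasses import dataclass, fields
from typing import Iterator, Optional


class Part:
    """Hypothetical stand-in for an agent part (e.g. a body or a head module)."""

    def __init__(self) -> None:
        self.training = True

    def train(self, mode: bool = True) -> "Part":
        self.training = mode
        return self

    def eval(self) -> "Part":
        return self.train(False)


@dataclass
class SketchAgent:
    # Hypothetical subset of the agent parts; parts that are not used stay None.
    body: Optional[Part] = None
    policy_head: Optional[Part] = None
    value_head: Optional[Part] = None

    def _parts(self) -> Iterator[Part]:
        # Yield every part that has actually been set.
        for f in fields(self):
            part = getattr(self, f.name)
            if part is not None:
                yield part

    def train(self) -> None:
        for part in self._parts():
            part.train()

    def eval(self) -> None:
        for part in self._parts():
            part.eval()
```

Skipping parts that are None mirrors the constructor signature above, where every part defaults to None.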
Attributes
_body_named_parameters
_body_parameters
body
policy_body
policy_head
solo_head
value_body
value_head
whole
agent_params
hyper_params
agent_name
message_logits_key
Methods
- __eq__(other)
Return self==value.
- __init__(hyper_params: dataclasses.InitVar[HyperParameters], agent_name: dataclasses.InitVar[str], whole: WholeAgent | None = None, body: AgentBody | None = None, policy_body: AgentBody | None = None, value_body: AgentBody | None = None, policy_head: AgentPolicyHead | None = None, value_head: AgentValueHead | None = None, solo_head: SoloAgentHead | None = None) -> None
- __post_init__(hyper_params: HyperParameters, agent_name: str)
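Because hyper_params and agent_name are declared as dataclasses.InitVar, the generated __init__ accepts them but does not store them as fields; they are forwarded to __post_init__ instead. The sketch below shows that standard dataclass mechanism on a hypothetical InitVarDemo class; what nip's own __post_init__ does with the values is not documented here, so storing them as plain attributes is an assumption.

```python
from dataclasses import InitVar, dataclass, fields
from typing import Optional


@dataclass
class InitVarDemo:
    # InitVar pseudo-fields are accepted by the generated __init__ and passed
    # on to __post_init__, but they are not stored automatically.
    hyper_params: InitVar[dict]
    agent_name: InitVar[str]
    body: Optional[object] = None

    def __post_init__(self, hyper_params: dict, agent_name: str) -> None:
        # Hypothetical: keep the init-only values as plain attributes.
        self.hyper_params = hyper_params
        self.agent_name = agent_name


demo = InitVarDemo({"lr": 0.001}, "prover")
# Only real fields are reported; the InitVar values live outside fields().
print([f.name for f in fields(demo)])  # ['body']
```

In this sketch, because the InitVar values are not fields, the generated __eq__ and __repr__ ignore them: two instances with different agent names but the same parts compare equal.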
- __repr__()
Return repr(self).
- static _append_filtered_params(model_param_dict: list[dict[str, Any]], named_parameters: list[tuple[str, Parameter]], filter: Callable[[str], bool], lr: float)
Filter the parameters and set their learning rate, and append them to a list.
Normally appends a dictionary with the keys params and lr, consisting of the filtered parameters and their learning rate. If the learning rate is 0, the parameters are frozen instead.
- Parameters:
model_param_dict (list[dict[str, Any]]) – The list of parameter dictionaries to append to.
named_parameters (list[tuple[str, TorchParameter]]) – A list of the named parameters.
filter (Callable[[str], bool]) – A function which returns True for the parameters to include.
lr (float) – The learning rate for the parameters.
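A minimal sketch of the filter-then-append behaviour described above. It assumes that "frozen" means setting requires_grad to False and leaving the parameters out of the optimizer groups, and it uses the key "params", which is what torch.optim optimizers expect in a parameter group. FakeParam is a hypothetical stand-in for torch.nn.Parameter so the example runs without PyTorch.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class FakeParam:
    """Hypothetical stand-in for torch.nn.Parameter (only requires_grad is modelled)."""
    requires_grad: bool = True


def append_filtered_params(
    model_param_dict: list[dict[str, Any]],
    named_parameters: list[tuple[str, FakeParam]],
    filter: Callable[[str], bool],
    lr: float,
) -> None:
    # Keep only the parameters whose names pass the filter.
    selected = [param for name, param in named_parameters if filter(name)]
    if lr == 0:
        # Assumption: freezing switches gradients off and adds no group.
        for param in selected:
            param.requires_grad = False
    else:
        model_param_dict.append({"params": selected, "lr": lr})
```

Freezing instead of appending a group with lr=0 avoids computing gradients that would never be applied.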
- filter_actor_named_parameters(named_parameters: Iterable[tuple[str, Parameter]]) -> Iterator[tuple[str, Parameter]]
Filter the actor parameters from an iterable of named parameters.
This is useful for extracting the agent’s actor parameters from an iterable of named parameters obtained by calling named_parameters() on a loss module.
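A sketch of a generator-based filter with the same input and output shapes as filter_actor_named_parameters. The "actor." name prefix is a hypothetical naming scheme chosen for illustration; the patterns nip actually uses (presumably built by _body_param_regex and _non_body_param_regex) are not documented on this page.

```python
import re
from typing import Iterable, Iterator, Tuple, TypeVar

P = TypeVar("P")

# Hypothetical naming scheme: the "actor." prefix is purely illustrative.
ACTOR_PARAM_PATTERN = re.compile(r"^actor\.")


def filter_actor_named_parameters(
    named_parameters: Iterable[Tuple[str, P]],
) -> Iterator[Tuple[str, P]]:
    # Lazily yield only the (name, parameter) pairs whose names match.
    for name, param in named_parameters:
        if ACTOR_PARAM_PATTERN.match(name):
            yield name, param
```

Returning an iterator means the output of loss_module.named_parameters() can be consumed lazily without building an intermediate list.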
- get_model_parameter_dicts(base_lr: float, named_parameters: Iterable[tuple[str, Parameter]] | None = None, body_lr_factor_override: bool = False) -> Iterable[dict[str, Any]]
Get the Torch parameters of the agent, and their learning rates.
- Parameters:
base_lr (float) – The base learning rate for the trainer.
named_parameters (Iterable[tuple[str, TorchParameter]], optional) – The named parameters of the loss module, usually obtained by loss_module.named_parameters(). If not given, the parameters of all the agent parts are used.
body_lr_factor_override (bool, optional) – If True, this overrides the learning rate factor for the body (for both the actor and critic), effectively setting it to 1.
- Returns:
param_dict (Iterable[dict[str, Any]]) – The Torch parameters of the agent, and their learning rates. This is an iterable of dictionaries with the keys params and lr.
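A minimal sketch of how body and non-body parameters could be split into groups with different learning rates, including the effect of body_lr_factor_override. The body_lr_factor argument and the "body." name prefix are hypothetical; nip's real hyperparameter names and matching rules are not documented here.

```python
from typing import Any, Iterable


def get_model_parameter_dicts(
    named_parameters: Iterable[tuple[str, Any]],
    base_lr: float,
    body_lr_factor: float,
    body_lr_factor_override: bool = False,
) -> list[dict[str, Any]]:
    """Sketch: split parameters into body and non-body groups with their own LRs."""
    named_parameters = list(named_parameters)
    # The override effectively sets the body learning rate factor to 1.
    factor = 1.0 if body_lr_factor_override else body_lr_factor
    groups: list[dict[str, Any]] = []
    for is_body, lr in ((True, base_lr * factor), (False, base_lr)):
        # Hypothetical rule: a parameter belongs to the body if its name starts with "body.".
        params = [p for name, p in named_parameters if name.startswith("body.") == is_body]
        if params:
            groups.append({"params": params, "lr": lr})
    return groups
```

Because each group carries its own "lr", a list like this can be passed straight to a torch.optim optimizer, for example torch.optim.Adam(groups), which applies per-group learning rates.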
- get_state() -> AgentState
Get a checkpoint of the agent’s state.
This method gets a checkpoint of the state of all the agent parts.
- Returns:
checkpoint (AgentCheckpoint) – The checkpoint of the agent’s state.
- set_state(checkpoint: AgentState)
Set the state of the agent from a checkpoint.
This method restores the state of all the agent parts from the checkpoint.
- Parameters:
checkpoint (AgentCheckpoint) – The checkpoint to restore the state from.
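get_state() and set_state() are described as snapshotting and restoring all agent parts. The round trip below is a sketch under stated assumptions: each part exposes state_dict()/load_state_dict() like a torch.nn.Module, and a plain dict keyed by part name stands in for the undocumented AgentState type. StatefulPart and SketchAgent are hypothetical.

```python
import copy
from dataclasses import dataclass, fields
from typing import Any, Optional


class StatefulPart:
    """Hypothetical stand-in for a module exposing state_dict/load_state_dict."""

    def __init__(self, weight: float) -> None:
        self.weight = weight

    def state_dict(self) -> dict[str, Any]:
        return {"weight": self.weight}

    def load_state_dict(self, state: dict[str, Any]) -> None:
        self.weight = state["weight"]


@dataclass
class SketchAgent:
    body: Optional[StatefulPart] = None
    policy_head: Optional[StatefulPart] = None

    def get_state(self) -> dict[str, dict[str, Any]]:
        # Snapshot every part that is set, keyed by its field name.
        return {
            f.name: copy.deepcopy(getattr(self, f.name).state_dict())
            for f in fields(self)
            if getattr(self, f.name) is not None
        }

    def set_state(self, checkpoint: dict[str, dict[str, Any]]) -> None:
        # Restore each part from its entry in the checkpoint.
        for name, state in checkpoint.items():
            getattr(self, name).load_state_dict(state)
```

Deep-copying the state dicts keeps the checkpoint independent of later in-place updates to the parts.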