nip.code_validation.agents.CodeValidationAgent
- class nip.code_validation.agents.CodeValidationAgent(hyper_params: dataclasses.InitVar[HyperParameters], agent_name: dataclasses.InitVar[str], whole: WholeAgent | None = None, body: AgentBody | None = None, policy_body: AgentBody | None = None, value_body: AgentBody | None = None, policy_head: AgentPolicyHead | None = None, value_head: AgentValueHead | None = None, solo_head: SoloAgentHead | None = None)
A class representing a code validation agent.
This is a dataclass which holds all the agent parts.
Methods Summary
- __eq__(other): Return self==value.
- __init__(hyper_params, agent_name[, whole, ...])
- __post_init__(hyper_params, agent_name)
- __repr__(): Return repr(self).
- _append_filtered_params(model_param_dict, ...): Filter the parameters and set their learning rate, and append them to a list.
- _body_param_regex(part)
- _non_body_param_regex(part)
- eval(): Set the agent to evaluation mode.
- filter_actor_named_parameters(named_parameters): Filter the actor parameters from an iterable of named parameters.
- get_model_parameter_dicts(base_lr[, ...]): Get the Torch parameters of the agent, and their learning rates.
- get_state(): Get a checkpoint of the agent's state.
- set_state(checkpoint): Set the state of the agent from a checkpoint.
- train(): Set the agent to training mode.
Attributes
- _body_named_parameters
- _body_parameters
- body
- policy_body
- policy_head
- solo_head
- value_body
- value_head
- whole
- agent_params
- hyper_params
- agent_name
- message_logits_key
Methods
- __eq__(other)
Return self==value.
- __init__(hyper_params: dataclasses.InitVar[HyperParameters], agent_name: dataclasses.InitVar[str], whole: WholeAgent | None = None, body: AgentBody | None = None, policy_body: AgentBody | None = None, value_body: AgentBody | None = None, policy_head: AgentPolicyHead | None = None, value_head: AgentValueHead | None = None, solo_head: SoloAgentHead | None = None) -> None
- __post_init__(hyper_params: HyperParameters, agent_name: str)
- __repr__()
Return repr(self).
- static _append_filtered_params(model_param_dict: list[dict[str, Any]], named_parameters: list[tuple[str, Parameter]], filter: Callable[[str], bool], lr: float)
Filter the parameters and set their learning rate, and append them to a list.
Normally appends a dictionary with the keys hyper_params and lr, consisting of the filtered parameters and their learning rate. If the learning rate is 0, the parameters are frozen instead.
- Parameters:
model_param_dict (list[dict[str, Any]]) – The list of parameter dictionaries to append to.
named_parameters (list[tuple[str, TorchParameter]]) – A list of the named parameters.
filter (Callable[[str], bool]) – A function which returns True for the parameters to include.
lr (float) – The learning rate for the parameters.
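The filter-then-append behaviour described above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: FakeParameter is a hypothetical stand-in for a Torch parameter, "frozen" is assumed to mean that gradients are switched off, and the dictionary keys follow the names given in the text above.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class FakeParameter:
    """Hypothetical stand-in for a Torch parameter (only the flag we need)."""

    requires_grad: bool = True


def append_filtered_params(
    model_param_dict: list[dict[str, Any]],
    named_parameters: list[tuple[str, FakeParameter]],
    filter: Callable[[str], bool],
    lr: float,
) -> None:
    # Keep only the parameters whose name passes the filter.
    selected = [param for name, param in named_parameters if filter(name)]
    if lr == 0:
        # Assumption: "frozen" means gradients are disabled, and no
        # dictionary is appended for these parameters.
        for param in selected:
            param.requires_grad = False
    else:
        # Key names as given in the description above.
        model_param_dict.append({"hyper_params": selected, "lr": lr})


named = [
    ("body.weight", FakeParameter()),
    ("policy_head.weight", FakeParameter()),
]
groups: list[dict[str, Any]] = []
append_filtered_params(groups, named, lambda n: n.startswith("body"), lr=1e-3)
append_filtered_params(groups, named, lambda n: n.startswith("policy_head"), lr=0)
```

After these two calls, groups holds a single entry for the body parameter, and the policy head parameter has requires_grad set to False.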
- filter_actor_named_parameters(named_parameters: Iterable[tuple[str, Parameter]]) -> Iterator[tuple[str, Parameter]]
Filter the actor parameters from an iterable of named parameters.
This is useful for extracting the agent’s actor parameters from an iterable of named parameters obtained by calling named_parameters() on a loss module.
- get_model_parameter_dicts(base_lr: float, named_parameters: Iterable[tuple[str, Parameter]] | None = None, body_lr_factor_override: bool = False) -> Iterable[dict[str, Any]]
Get the Torch parameters of the agent, and their learning rates.
- Parameters:
base_lr (float) – The base learning rate for the trainer.
named_parameters (Iterable[tuple[str, TorchParameter]], optional) – The named parameters of the loss module, usually obtained by loss_module.named_parameters(). If not given, the parameters of all the agent parts are used.
body_lr_factor_override (bool) – If true, this overrides the learning rate factor for the body (for both the actor and critic), effectively setting it to 1.
- Returns:
param_dict (Iterable[dict[str, Any]]) – The Torch parameters of the agent, and their learning rates. This is an iterable of dictionaries with the keys hyper_params and lr.
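The effect of body_lr_factor_override on the body's learning rate might be sketched as follows. The helper body_learning_rate and its body_lr_factor argument are hypothetical; the sketch assumes the body's rate is the base rate multiplied by a configured factor, which the page implies but does not state.

```python
def body_learning_rate(
    base_lr: float,
    body_lr_factor: float,
    body_lr_factor_override: bool = False,
) -> float:
    """Hypothetical helper: learning rate applied to the body parameters."""
    # With the override set, the factor is treated as 1, so the body
    # trains at the base learning rate.
    factor = 1.0 if body_lr_factor_override else body_lr_factor
    return base_lr * factor


scaled = body_learning_rate(0.01, body_lr_factor=0.1)
overridden = body_learning_rate(0.01, body_lr_factor=0.1, body_lr_factor_override=True)
```

Under this assumption, overridden equals the base rate regardless of the configured factor.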
- get_state() -> AgentState
Get a checkpoint of the agent’s state.
This method gets a checkpoint of the state of all the agent parts.
- Returns:
checkpoint (AgentCheckpoint) – The checkpoint of the agent’s state.
- set_state(checkpoint: AgentState)
Set the state of the agent from a checkpoint.
This method restores the state of all the agent parts from the checkpoint.
- Parameters:
checkpoint (AgentCheckpoint) – The checkpoint to restore the state from.
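A checkpoint round trip with get_state and set_state can be illustrated with a toy stand-in. ToyPart and ToyAgent are hypothetical and only mimic the shape of the interface: the real agent returns an AgentState, whose structure is not shown on this page, whereas the sketch uses a plain dictionary.

```python
import copy
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class ToyPart:
    """Hypothetical agent part holding a few named weights."""

    weights: dict[str, float] = field(default_factory=dict)


@dataclass
class ToyAgent:
    """Hypothetical agent; parts may be None, as in the class signature."""

    body: Optional[ToyPart] = None
    policy_head: Optional[ToyPart] = None

    def get_state(self) -> dict[str, Any]:
        # Snapshot every part that is present. Deep-copy so that later
        # updates to the agent do not mutate the checkpoint.
        return {
            name: copy.deepcopy(part.weights)
            for name, part in vars(self).items()
            if part is not None
        }

    def set_state(self, checkpoint: dict[str, Any]) -> None:
        # Restore each saved part from the checkpoint.
        for name, weights in checkpoint.items():
            getattr(self, name).weights = copy.deepcopy(weights)


agent = ToyAgent(body=ToyPart({"w": 1.0}))
checkpoint = agent.get_state()
agent.body.weights["w"] = 5.0  # simulate a training update
agent.set_state(checkpoint)  # roll back to the saved state
```

The copy on both save and restore keeps the checkpoint independent of the live agent, so the same checkpoint can be restored more than once.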