nip.graph_isomorphism.agents.GraphIsomorphismAgent

class nip.graph_isomorphism.agents.GraphIsomorphismAgent(hyper_params: dataclasses.InitVar[HyperParameters], agent_name: dataclasses.InitVar[str], whole: WholeAgent | None = None, body: AgentBody | None = None, policy_body: AgentBody | None = None, value_body: AgentBody | None = None, policy_head: AgentPolicyHead | None = None, value_head: AgentValueHead | None = None, solo_head: SoloAgentHead | None = None)

An agent for the graph isomorphism task.

This is a dataclass which contains all the parts of the agent.
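Since the signature above marks hyper_params and agent_name as dataclasses.InitVar, the following minimal sketch shows how init-only arguments behave in a standard Python dataclass. ToyAgent and its string-valued parts are hypothetical stand-ins, not the library's classes, and storing the init-only values as attributes in __post_init__ is an assumption made for illustration.

```python
from dataclasses import InitVar, dataclass, fields
from typing import Optional


@dataclass
class ToyAgent:
    """Hypothetical stand-in for the agent dataclass; parts are plain strings."""

    # Init-only arguments: accepted by __init__ and forwarded to __post_init__,
    # but not registered as dataclass fields.
    hyper_params: InitVar[dict]
    agent_name: InitVar[str]

    # Optional agent parts, all defaulting to None.
    body: Optional[str] = None
    policy_head: Optional[str] = None

    def __post_init__(self, hyper_params: dict, agent_name: str) -> None:
        # Assumption: the init-only values are kept as plain attributes.
        self.hyper_params = hyper_params
        self.agent_name = agent_name


agent = ToyAgent({"lr": 0.5}, "prover", body="gnn_body")

# Only the agent parts are dataclass fields; the InitVar arguments are not.
field_names = [f.name for f in fields(agent)]
print(field_names)
print(repr(agent))
```

Because init-only arguments are not fields, the generated __repr__ and __eq__ only consider the agent parts in this sketch.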

Methods Summary

__eq__(other) – Return self==value.

__init__(hyper_params, agent_name[, whole, ...])

__post_init__(hyper_params, agent_name)

__repr__() – Return repr(self).

_append_filtered_params(model_param_dict, ...) – Filter the parameters and set their learning rate, and append them to a list.

_body_param_regex(part)

_non_body_param_regex(part)

eval() – Set the agent to evaluation mode.

filter_actor_named_parameters(named_parameters) – Filter the actor parameters from an iterable of named parameters.

get_model_parameter_dicts(base_lr[, ...]) – Get the Torch parameters of the agent, and their learning rates.

get_state() – Get a checkpoint of the agent's state.

set_state(checkpoint) – Set the state of the agent from a checkpoint.

train() – Set the agent to training mode.

Attributes

_body_named_parameters

_body_parameters

body

message_logits_key

policy_body

policy_head

solo_head

value_body

value_head

whole

agent_params

hyper_params

agent_name

Methods

__eq__(other)

Return self==value.

__init__(hyper_params: dataclasses.InitVar[HyperParameters], agent_name: dataclasses.InitVar[str], whole: WholeAgent | None = None, body: AgentBody | None = None, policy_body: AgentBody | None = None, value_body: AgentBody | None = None, policy_head: AgentPolicyHead | None = None, value_head: AgentValueHead | None = None, solo_head: SoloAgentHead | None = None) → None

__post_init__(hyper_params: HyperParameters, agent_name: str)

__repr__()

Return repr(self).

static _append_filtered_params(model_param_dict: list[dict[str, Any]], named_parameters: list[tuple[str, Parameter]], filter: Callable[[str], bool], lr: float)

Filter the parameters and set their learning rate, and append them to a list.

Normally, this appends a dictionary with the keys params and lr, holding the filtered parameters and their learning rate. If the learning rate is 0, the parameters are frozen instead and nothing is appended.

Parameters:
  • model_param_dict (list[dict[str, Any]]) – The list of parameter dictionaries to append to.

  • named_parameters (list[tuple[str, TorchParameter]]) – A list of the named parameters.

  • filter (Callable[[str], bool]) – A function which returns True for the parameters to include.

  • lr (float) – The learning rate for the parameters.
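The behaviour described above can be sketched as follows. This is a minimal illustration, not the library's implementation: ToyParameter is a hypothetical stand-in for a Torch parameter, the key name params follows the torch.optim parameter-group convention and is an assumption here, and "frozen" is assumed to mean setting requires_grad to False.

```python
from typing import Any, Callable, Dict, List, Tuple


class ToyParameter:
    """Hypothetical stand-in for a Torch parameter (only the flag used here)."""

    def __init__(self) -> None:
        self.requires_grad = True


def append_filtered_params_sketch(
    model_param_dict: List[Dict[str, Any]],
    named_parameters: List[Tuple[str, ToyParameter]],
    filter: Callable[[str], bool],
    lr: float,
) -> None:
    # Keep only the parameters whose name passes the filter.
    selected = [param for name, param in named_parameters if filter(name)]
    if lr == 0:
        # Assumption: freezing means disabling gradients for these parameters.
        for param in selected:
            param.requires_grad = False
    else:
        # Assumption: "params" is the key expected by the optimizer.
        model_param_dict.append({"params": selected, "lr": lr})


body_w, head_w = ToyParameter(), ToyParameter()
named = [("body.weight", body_w), ("policy_head.weight", head_w)]
groups: List[Dict[str, Any]] = []

# A learning rate of 0 freezes the body instead of adding a group.
append_filtered_params_sketch(groups, named, lambda n: n.startswith("body."), 0.0)
# A non-zero learning rate appends one group for the remaining parameters.
append_filtered_params_sketch(groups, named, lambda n: not n.startswith("body."), 0.01)
```

In this sketch the list is mutated in place, which matches the "append them to a list" wording in the summary.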

_body_param_regex(part: str) → str

_non_body_param_regex(part: str) → str

eval()

Set the agent to evaluation mode.

filter_actor_named_parameters(named_parameters: Iterable[tuple[str, Parameter]]) → Iterator[tuple[str, Parameter]]

Filter the actor parameters from an iterable of named parameters.

This is useful for extracting the agent’s actor parameters from an iterable of named parameters obtained by calling named_parameters() on a loss module.

Parameters:

named_parameters (Iterable[tuple[str, TorchParameter]]) – The named parameters to filter.

Yields:

actor_named_parameter (tuple[str, TorchParameter]) – A tuple of the name and the parameter of the agent’s actor parameter.
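As a rough illustration of this kind of filtering, the sketch below yields only the entries whose name matches a per-agent actor prefix. The naming scheme ("actor.<agent_name>. ..."), the function name, and the float stand-ins for parameters are all hypothetical; the names produced by a real loss module may look different.

```python
import re
from typing import Any, Iterable, Iterator, Tuple


def filter_actor_named_parameters_sketch(
    named_parameters: Iterable[Tuple[str, Any]], agent_name: str
) -> Iterator[Tuple[str, Any]]:
    # Hypothetical naming scheme: actor parameters of an agent start with
    # "actor.<agent_name>.".
    pattern = re.compile(rf"^actor\.{re.escape(agent_name)}\.")
    for name, param in named_parameters:
        if pattern.match(name):
            yield name, param


# Floats stand in for the Torch parameters of a loss module.
loss_module_params = [
    ("actor.prover.body.weight", 1.0),
    ("critic.prover.value_head.weight", 2.0),
    ("actor.verifier.body.weight", 3.0),
]
prover_actor_params = list(
    filter_actor_named_parameters_sketch(loss_module_params, "prover")
)
```

Writing the filter as a generator keeps it lazy, so it can consume the iterator returned by named_parameters() without building an intermediate list.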

get_model_parameter_dicts(base_lr: float, named_parameters: Iterable[tuple[str, Parameter]] | None = None, body_lr_factor_override: bool = False) → Iterable[dict[str, Any]]

Get the Torch parameters of the agent, and their learning rates.

Parameters:
  • base_lr (float) – The base learning rate for the trainer.

  • named_parameters (Iterable[tuple[str, TorchParameter]], optional) – The named parameters of the loss module, usually obtained by loss_module.named_parameters(). If not given, the parameters of all the agent parts are used.

  • body_lr_factor_override (bool) – If true, this overrides the learning rate factor for the body (for both the actor and critic), effectively setting it to 1.

Returns:

param_dict (Iterable[dict[str, Any]]) – The Torch parameters of the agent, and their learning rates. This is an iterable of dictionaries with the keys params and lr.
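To make the role of the body learning rate factor concrete, here is a minimal sketch under stated assumptions: body parameters are identified by a "body." name prefix, the factor is passed in as body_lr_factor (a hypothetical argument; in the library it presumably comes from the hyper-parameters), and the dictionary key params follows the torch.optim parameter-group convention. Strings stand in for Torch parameters.

```python
from typing import Any, Dict, Iterable, List, Tuple


def parameter_dicts_sketch(
    named_parameters: Iterable[Tuple[str, Any]],
    base_lr: float,
    body_lr_factor: float,
    body_lr_factor_override: bool = False,
) -> List[Dict[str, Any]]:
    # With the override, the body trains at the base learning rate.
    if body_lr_factor_override:
        body_lr_factor = 1.0

    named = list(named_parameters)
    # Assumption: body parameters are recognisable by their name prefix.
    body = [p for n, p in named if n.startswith("body.")]
    rest = [p for n, p in named if not n.startswith("body.")]

    return [
        {"params": body, "lr": base_lr * body_lr_factor},
        {"params": rest, "lr": base_lr},
    ]


named = [("body.weight", "w_body"), ("policy_head.weight", "w_head")]
scaled = parameter_dicts_sketch(named, base_lr=0.5, body_lr_factor=0.25)
overridden = parameter_dicts_sketch(
    named, base_lr=0.5, body_lr_factor=0.25, body_lr_factor_override=True
)
```

A list of dictionaries of this shape is what torch.optim optimizers accept as per-parameter-group options.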

get_state() → AgentState

Get a checkpoint of the agent’s state.

This method gets a checkpoint of the state of all the agent parts.

Returns:

checkpoint (AgentState) – The checkpoint of the agent’s state.

set_state(checkpoint: AgentState)

Set the state of the agent from a checkpoint.

This method restores the state of all the agent parts from the checkpoint.

Parameters:

checkpoint (AgentState) – The checkpoint to restore the state from.
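The get_state() / set_state() round trip can be pictured with the toy sketch below. The structure of the real checkpoint object is not shown on this page, so a plain dictionary keyed by part name is assumed; ToyPart and ToyAgent are hypothetical stand-ins whose state_dict() and load_state_dict() methods mimic the usual Torch module interface.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class ToyPart:
    """Hypothetical agent part holding a small dictionary of weights."""

    weights: Dict[str, float]

    def state_dict(self) -> Dict[str, float]:
        return dict(self.weights)

    def load_state_dict(self, state: Dict[str, float]) -> None:
        self.weights = dict(state)


@dataclass
class ToyAgent:
    body: Optional[ToyPart] = None
    policy_head: Optional[ToyPart] = None

    def get_state(self) -> Dict[str, Any]:
        # Collect the state of every part that is present.
        return {
            name: part.state_dict()
            for name, part in vars(self).items()
            if part is not None
        }

    def set_state(self, checkpoint: Dict[str, Any]) -> None:
        # Restore each part from its entry in the checkpoint.
        for name, state in checkpoint.items():
            getattr(self, name).load_state_dict(state)


source = ToyAgent(body=ToyPart({"w": 1.0}), policy_head=ToyPart({"w": 2.0}))
target = ToyAgent(body=ToyPart({"w": 0.0}), policy_head=ToyPart({"w": 0.0}))

checkpoint = source.get_state()
target.set_state(checkpoint)
```

After the round trip, the target's parts hold copies of the source's weights, so later changes to one agent do not affect the other in this sketch.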

train()

Set the agent to training mode.