nip.parameters.agents.GraphIsomorphismAgentParameters

class nip.parameters.agents.GraphIsomorphismAgentParameters(agent_lr_factor: nip.parameters.agents.LrFactors | dict | None = None, body_lr_factor: nip.parameters.agents.LrFactors | dict | None = None, update_schedule: nip.parameters.update_schedule.AgentUpdateSchedule = ConstantUpdateSchedule(), use_manual_architecture: bool = False, normalize_message_history: bool = False, load_checkpoint_and_parameters: bool = False, checkpoint_entity: str = <factory>, checkpoint_project: str = <factory>, checkpoint_run_id: str | None = None, checkpoint_version: str = 'latest', use_orthogonal_initialisation: bool = True, orthogonal_initialisation_gain: float = 1.0, activation_function: typing.Literal['relu', 'tanh', 'sigmoid'] = 'tanh', num_gnn_layers: int = 5, d_gnn: int = 16, d_gin_mlp: int = 64, gnn_output_digits: int | None = None, use_dual_gnn: bool = True, num_heads: int = 4, num_transformer_layers: int = 4, d_transformer: int = 16, d_transformer_mlp: int = 64, transformer_dropout: float = 0.0, d_node_selector: int = 16, num_node_selector_layers: int = 2, d_decider: int = 16, num_decider_layers: int = 2, include_round_in_decider: bool = True, d_linear_message_selector: int = 16, num_linear_message_selector_layers: int = 2, d_value: int = 16, num_value_layers: int = 2, include_round_in_value: bool = True, use_batch_norm: bool = True, noise_sigma: float = 0.0, use_pair_invariant_pooling: bool = True, gnn_lr_factor: nip.parameters.agents.LrFactors | dict | None = None)

Additional parameters for agents in the graph isomorphism experiment.

Parameters:
  • activation_function (ActivationType) – The activation function to use: one of 'relu', 'tanh' or 'sigmoid'.

  • num_gnn_layers (int) – The number of layers in the agent’s GNN.

  • d_gnn (int) – The dimension of the hidden layers in the agent’s GNN and of the attention embedding.

  • d_gin_mlp (int) – The dimension of the hidden layers in the agent’s Graph Isomorphism Network MLP.

  • gnn_output_digits (int, optional) – The number of digits to which the output of the agent’s GNN is rounded. If not provided, the output is not rounded.

  • use_dual_gnn (bool) – Whether to run two copies of the GNN in parallel. The first copy takes the message history as its input features; the second takes all-zero features.

  • num_heads (int) – The number of attention heads in the agent’s transformer.

  • num_transformer_layers (int) – The number of transformer layers.

  • d_transformer (int) – The dimensionality of the transformer.

  • d_transformer_mlp (int) – The hidden dimension of the transformer MLP.

  • transformer_dropout (float) – The dropout rate for the transformer.

  • d_node_selector (int) – The dimension of the hidden layer in the agent’s MLP which selects a node to send as a message.

  • num_node_selector_layers (int) – The number of layers in the agent’s node selector MLP.

  • d_decider (int) – The dimension of the hidden layer in the agent’s MLP which decides whether to accept or reject.

  • num_decider_layers (int) – The number of layers in the agent’s decider MLP.

  • include_round_in_decider (bool) – Whether to include the round number as an input to the agent’s decider MLP.

  • d_linear_message_selector (int) – The dimension of the hidden layer in the agent’s MLP which selects a linear message, if the linear message space is in use.

  • num_linear_message_selector_layers (int) – The number of layers in the agent’s linear message selector MLP.

  • d_value (int) – The dimension of the hidden layer in the agent’s MLP which estimates the value function.

  • num_value_layers (int) – The number of layers in the agent’s value MLP.

  • include_round_in_value (bool) – Whether to include the round number as an input to the agent’s value MLP.

  • use_batch_norm (bool) – Whether to use batch normalization in the agent’s global pooling layer.

  • noise_sigma (float) – The relative standard deviation of the Gaussian noise added to the agent’s graph-level representations.

  • use_pair_invariant_pooling (bool) – Whether to use pair-invariant pooling in the agent’s global pooling layer. This makes the agent’s graph-level representations invariant to the order of the graphs in the pair.

  • body_lr_factor ([LrFactors | dict], optional) – The learning rate factor for the body part of the model. The final LR for the body is obtained by multiplying this factor by the agent LR factor and the base LR. This allows updating the body at a different rate to the rest of the model.

  • gnn_lr_factor ([LrFactors | dict], optional) – The learning rate factor for the GNN part of the model (split across the actor and the critic). The final LR for the GNN is obtained by multiplying this factor by the body LR. This allows updating the GNN at a different rate to the rest of the model.
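
The way the learning rate factors compose can be illustrated with a little arithmetic. The sketch below is not part of the library: effective_learning_rates is a hypothetical helper, and it assumes each factor is a single float, whereas the real fields accept an LrFactors object or a dict (or None) whose structure is not shown on this page.

```python
def effective_learning_rates(
    base_lr: float,
    agent_lr_factor: float = 1.0,
    body_lr_factor: float = 1.0,
    gnn_lr_factor: float = 1.0,
) -> dict:
    """Hypothetical helper: compose scalar LR factors as described above.

    Assumes scalar factors; the real parameters take LrFactors or dict values.
    """
    # Body LR = body factor * agent factor * base LR.
    body_lr = body_lr_factor * agent_lr_factor * base_lr
    # GNN LR = GNN factor * body LR.
    gnn_lr = gnn_lr_factor * body_lr
    return {"body": body_lr, "gnn": gnn_lr}


lrs = effective_learning_rates(
    base_lr=3e-4, agent_lr_factor=1.0, body_lr_factor=0.5, gnn_lr_factor=0.1
)
print(lrs)
```

Under these assumptions, a GNN factor below 1 makes the GNN train more slowly than the rest of the body, which in turn trains at the body factor times the agent's rate.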

Methods Summary

__eq__(other)

Return self==value.

__init__([agent_lr_factor, body_lr_factor, ...])

__post_init__()

__repr__()

Return repr(self).

_get_param_class_from_dict(param_dict)

Try to get the parameter class from a dictionary of serialised parameters.

construct_test_params()

Construct test parameters for the agent.

from_dict(params_dict[, ignore_extra_keys])

Create a parameters object from a dictionary.

get(address)

Get a value from the parameters object using a dot-separated address.

load_from_wandb_config(wandb_config)

Load the parameters from a W&B config dictionary.

to_dict()

Convert the parameters object to a dictionary.

Attributes

LOAD_PRESERVED_PARAMETERS

activation_function

agent_lr_factor

body_lr_factor

checkpoint_run_id

checkpoint_version

d_decider

d_gin_mlp

d_gnn

d_linear_message_selector

d_node_selector

d_transformer

d_transformer_mlp

d_value

gnn_lr_factor

gnn_output_digits

include_round_in_decider

include_round_in_value

is_random

load_checkpoint_and_parameters

noise_sigma

normalize_message_history

num_decider_layers

num_gnn_layers

num_heads

num_linear_message_selector_layers

num_node_selector_layers

num_transformer_layers

num_value_layers

orthogonal_initialisation_gain

transformer_dropout

update_schedule

use_batch_norm

use_dual_gnn

use_manual_architecture

use_orthogonal_initialisation

use_pair_invariant_pooling

checkpoint_entity

checkpoint_project

Methods

__eq__(other)

Return self==value.

__init__(agent_lr_factor: nip.parameters.agents.LrFactors | dict | None = None, body_lr_factor: nip.parameters.agents.LrFactors | dict | None = None, update_schedule: nip.parameters.update_schedule.AgentUpdateSchedule = ConstantUpdateSchedule(), use_manual_architecture: bool = False, normalize_message_history: bool = False, load_checkpoint_and_parameters: bool = False, checkpoint_entity: str = <factory>, checkpoint_project: str = <factory>, checkpoint_run_id: str | None = None, checkpoint_version: str = 'latest', use_orthogonal_initialisation: bool = True, orthogonal_initialisation_gain: float = 1.0, activation_function: typing.Literal['relu', 'tanh', 'sigmoid'] = 'tanh', num_gnn_layers: int = 5, d_gnn: int = 16, d_gin_mlp: int = 64, gnn_output_digits: int | None = None, use_dual_gnn: bool = True, num_heads: int = 4, num_transformer_layers: int = 4, d_transformer: int = 16, d_transformer_mlp: int = 64, transformer_dropout: float = 0.0, d_node_selector: int = 16, num_node_selector_layers: int = 2, d_decider: int = 16, num_decider_layers: int = 2, include_round_in_decider: bool = True, d_linear_message_selector: int = 16, num_linear_message_selector_layers: int = 2, d_value: int = 16, num_value_layers: int = 2, include_round_in_value: bool = True, use_batch_norm: bool = True, noise_sigma: float = 0.0, use_pair_invariant_pooling: bool = True, gnn_lr_factor: nip.parameters.agents.LrFactors | dict | None = None) -> None

__post_init__()

__repr__()

Return repr(self).

classmethod _get_param_class_from_dict(param_dict: dict) -> type[ParameterValue] | None

Try to get the parameter class from a dictionary of serialised parameters.

Parameters:

param_dict (dict) – A dictionary of parameters, which may have come from a to_dict method. This dictionary may contain a _type key, which is used to determine the class of the parameter.

Returns:

param_class (type[ParameterValue] | None) – The class of the parameter, if it can be determined.

Raises:

ValueError – If the class specified in the dictionary is not a valid parameter class.
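
A minimal sketch of this kind of lookup, assuming the class is resolved by name from a registry. The registry, the register decorator and ExampleParameters are all hypothetical names introduced for illustration; how the library actually maps the _type key to a class is not shown on this page.

```python
from dataclasses import dataclass

# Hypothetical registry mapping class names to parameter classes.
PARAM_CLASS_REGISTRY: dict = {}


def register(cls):
    """Hypothetical decorator: record a parameter class under its name."""
    PARAM_CLASS_REGISTRY[cls.__name__] = cls
    return cls


@register
@dataclass
class ExampleParameters:
    d_gnn: int = 16


def get_param_class_from_dict(param_dict: dict):
    """Return the class named by the "_type" key, or None if the key is absent."""
    type_name = param_dict.get("_type")
    if type_name is None:
        # No type information, so the class cannot be determined.
        return None
    if type_name not in PARAM_CLASS_REGISTRY:
        raise ValueError(f"{type_name!r} is not a valid parameter class")
    return PARAM_CLASS_REGISTRY[type_name]
```

Returning None when the key is missing, rather than raising, lets a caller fall back to a default class for dictionaries that carry no type information.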

classmethod construct_test_params() -> GraphIsomorphismAgentParameters

Construct test parameters for the agent.

We use a simple architecture with one GNN layer and one transformer layer.

Returns:

test_params (GraphIsomorphismAgentParameters) – The test parameters.

classmethod from_dict(params_dict: dict, ignore_extra_keys: bool = False) -> AgentsParameters

Create a parameters object from a dictionary.

Parameters:
  • params_dict (dict) – A dictionary of the parameters.

  • ignore_extra_keys (bool, default=False) – If True, ignore keys in the dictionary that do not correspond to fields in the parameters object.

Returns:

hyper_params (AgentsParameters) – The parameters object.
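
The effect of ignore_extra_keys can be sketched with a small stand-in dataclass. ExampleAgentParameters is hypothetical and has only two of the real fields. Raising ValueError on unknown keys is an assumption made for this sketch; the page does not say what the library does when extra keys are present and ignore_extra_keys is False.

```python
from dataclasses import dataclass, fields


@dataclass
class ExampleAgentParameters:
    """Hypothetical stand-in with two of the fields documented above."""

    num_gnn_layers: int = 5
    d_gnn: int = 16

    @classmethod
    def from_dict(cls, params_dict: dict, ignore_extra_keys: bool = False):
        field_names = {f.name for f in fields(cls)}
        extra_keys = set(params_dict) - field_names
        if extra_keys and not ignore_extra_keys:
            # Assumed behaviour: reject keys that match no field.
            raise ValueError(f"Unexpected keys: {sorted(extra_keys)}")
        # Keep only the keys that correspond to fields; others are dropped.
        kwargs = {k: v for k, v in params_dict.items() if k in field_names}
        return cls(**kwargs)


params = ExampleAgentParameters.from_dict(
    {"num_gnn_layers": 1, "unknown_key": 0}, ignore_extra_keys=True
)
```

Here params.num_gnn_layers is taken from the dictionary, d_gnn keeps its default, and unknown_key is silently dropped.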

get(address: str) -> Any

Get a value from the parameters object using a dot-separated address.

Parameters:

address (str) – The path to the value in the parameters object, separated by dots.

Returns:

value (Any) – The value at the address.

Raises:

KeyError – If the address does not exist.
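
A dot-separated lookup of this kind can be sketched as a walk over nested attributes. The code below is an illustration rather than the library's implementation: get_by_address, ExampleLrFactors and ExampleParameters are hypothetical names, and the actor and critic fields are assumptions about what a nested parameters object might contain.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ExampleLrFactors:
    """Hypothetical nested parameters object."""

    actor: float = 1.0
    critic: float = 0.5


@dataclass
class ExampleParameters:
    """Hypothetical top-level parameters object."""

    d_gnn: int = 16
    body_lr_factor: ExampleLrFactors = field(default_factory=ExampleLrFactors)


def get_by_address(params: Any, address: str) -> Any:
    """Follow a dot-separated address through nested objects and dicts."""
    value = params
    for part in address.split("."):
        if isinstance(value, dict) and part in value:
            value = value[part]
        elif not isinstance(value, dict) and hasattr(value, part):
            value = getattr(value, part)
        else:
            # Mirror the documented behaviour: a missing address is a KeyError.
            raise KeyError(address)
    return value


critic_factor = get_by_address(ExampleParameters(), "body_lr_factor.critic")
```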

load_from_wandb_config(wandb_config: dict)

Load the parameters from a W&B config dictionary.

Parameters:

wandb_config (dict) – The W&B config dictionary for this agent (e.g. wandb_run.config["agents"][agent_name]).

to_dict() -> dict

Convert the parameters object to a dictionary.

Adds the is_random parameter to the dictionary. This is not a field of the parameters object, but we want to include it in the dictionary for logging.

Returns:

params_dict (dict) – A dictionary of the parameters.
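
One way to picture this is a dataclass whose to_dict adds a non-field attribute to the serialised fields. The sketch assumes is_random is a class-level constant and that the fields serialise with dataclasses.asdict; both are assumptions, and ExampleAgentParameters is a hypothetical stand-in for the real class.

```python
from dataclasses import asdict, dataclass
from typing import ClassVar


@dataclass
class ExampleAgentParameters:
    """Hypothetical stand-in with one field and a non-field is_random flag."""

    # ClassVar attributes are not dataclass fields, so asdict() skips them.
    is_random: ClassVar[bool] = False

    d_gnn: int = 16

    def to_dict(self) -> dict:
        params_dict = asdict(self)
        # Add the non-field flag explicitly so that it appears in logs.
        params_dict["is_random"] = self.is_random
        return params_dict


logged = ExampleAgentParameters(d_gnn=8).to_dict()
```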