nip.parameters.agents.GraphIsomorphismAgentParameters
- class nip.parameters.agents.GraphIsomorphismAgentParameters(agent_lr_factor: nip.parameters.agents.LrFactors | dict | None = None, body_lr_factor: nip.parameters.agents.LrFactors | dict | None = None, update_schedule: nip.parameters.update_schedule.AgentUpdateSchedule = ConstantUpdateSchedule(), use_manual_architecture: bool = False, normalize_message_history: bool = False, load_checkpoint_and_parameters: bool = False, checkpoint_entity: str = <factory>, checkpoint_project: str = <factory>, checkpoint_run_id: str | None = None, checkpoint_version: str = 'latest', use_orthogonal_initialisation: bool = True, orthogonal_initialisation_gain: float = 1.0, activation_function: typing.Literal['relu', 'tanh', 'sigmoid'] = 'tanh', num_gnn_layers: int = 5, d_gnn: int = 16, d_gin_mlp: int = 64, gnn_output_digits: int | None = None, use_dual_gnn: bool = True, num_heads: int = 4, num_transformer_layers: int = 4, d_transformer: int = 16, d_transformer_mlp: int = 64, transformer_dropout: float = 0.0, d_node_selector: int = 16, num_node_selector_layers: int = 2, d_decider: int = 16, num_decider_layers: int = 2, include_round_in_decider: bool = True, d_linear_message_selector: int = 16, num_linear_message_selector_layers: int = 2, d_value: int = 16, num_value_layers: int = 2, include_round_in_value: bool = True, use_batch_norm: bool = True, noise_sigma: float = 0.0, use_pair_invariant_pooling: bool = True, gnn_lr_factor: nip.parameters.agents.LrFactors | dict | None = None)
Additional parameters for agents in the graph isomorphism experiment.
- Parameters:
activation_function (ActivationType) – The activation function to use.
num_gnn_layers (int) – The number of layers in the agent’s GNN.
d_gnn (int) – The dimension of the hidden layers in the agent’s GNN and of the attention embedding.
d_gin_mlp (int) – The dimension of the hidden layers in the agent’s Graph Isomorphism Network MLP.
gnn_output_digits (int, optional) – The number of digits in the output of the agent’s GNN. If not provided, the output is not rounded.
use_dual_gnn (bool) – Whether to run two copies of the GNN in parallel: the first takes the message history as its features, and the second takes all-zero features.
num_heads (int) – The number of heads in the agent’s transformer.
num_transformer_layers (int) – The number of transformer layers.
d_transformer (int) – The dimensionality of the transformer.
d_transformer_mlp (int) – The hidden dimension of the transformer MLP.
transformer_dropout (float) – The dropout value for the transformer.
d_node_selector (int) – The dimension of the hidden layer in the agent’s MLP which selects a node to send as a message.
num_node_selector_layers (int) – The number of layers in the agent’s node selector MLP.
d_decider (int) – The dimension of the hidden layer in the agent’s MLP which decides whether to accept or reject.
num_decider_layers (int) – The number of layers in the agent’s decider MLP.
include_round_in_decider (bool) – Whether to include the round number in the agent’s decider MLP.
d_linear_message_selector (int) – The dimension of the hidden layer in the agent’s MLP which selects a linear message, if we’re using the linear message space.
num_linear_message_selector_layers (int) – The number of layers in the agent’s linear message selector MLP.
d_value (int) – The dimension of the hidden layer in the agent’s MLP which estimates the value function.
num_value_layers (int) – The number of layers in the agent’s value MLP.
include_round_in_value (bool) – Whether to include the round number in the agent’s value MLP.
use_batch_norm (bool) – Whether to use batch normalization in the agent’s global pooling layer.
noise_sigma (float) – The relative standard deviation of the Gaussian noise added to the agent’s graph-level representations.
use_pair_invariant_pooling (bool) – Whether to use pair-invariant pooling in the agent’s global pooling layer. This makes the agent’s graph-level representations invariant to the order of the graphs in the pair.
body_lr_factor ([LrFactors | dict], optional) – The learning rate factor for the body part of the model. The final LR for the body is obtained by multiplying this factor by the agent LR factor and the base LR. This allows updating the body at a different rate to the rest of the model.
gnn_lr_factor ([LrFactors | dict], optional) – The learning rate factor for the GNN part of the model (split across the actor and the critic). The final LR for the GNN is obtained by multiplying this factor by the body LR. This allows updating the GNN at a different rate to the rest of the model.
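The way the learning-rate factors described above compose can be illustrated with a small arithmetic sketch. It assumes each factor is a single float, whereas a real `LrFactors` object may hold more than one value. The function name `effective_lrs` and the numbers are hypothetical and not part of the library.

```python
def effective_lrs(base_lr, agent_lr_factor=1.0, body_lr_factor=1.0, gnn_lr_factor=1.0):
    """Compose the LR factors multiplicatively (illustrative only)."""
    # The agent LR scales the base LR.
    agent_lr = base_lr * agent_lr_factor
    # Body LR = body factor * agent factor * base LR.
    body_lr = agent_lr * body_lr_factor
    # GNN LR = GNN factor * body LR.
    gnn_lr = body_lr * gnn_lr_factor
    return {"agent": agent_lr, "body": body_lr, "gnn": gnn_lr}


# Hypothetical values: the body trains at half the agent's rate,
# and the GNN at a tenth of the body's rate.
lrs = effective_lrs(base_lr=0.001, agent_lr_factor=2.0, body_lr_factor=0.5, gnn_lr_factor=0.1)
```

Under these assumptions a factor of 1.0 leaves the corresponding rate unchanged, so only the parts that should train faster or slower need a non-default factor.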
Methods Summary
__eq__(other) – Return self==value.
__init__([agent_lr_factor, body_lr_factor, ...])
__repr__() – Return repr(self).
_get_param_class_from_dict(param_dict) – Try to get the parameter class from a dictionary of serialised parameters.
construct_test_params() – Construct test parameters for the agent.
from_dict(params_dict[, ignore_extra_keys]) – Create a parameters object from a dictionary.
get(address) – Get a value from the parameters object using a dot-separated address.
load_from_wandb_config(wandb_config) – Load the parameters from a W&B config dictionary.
to_dict() – Convert the parameters object to a dictionary.
Attributes
LOAD_PRESERVED_PARAMETERS
activation_function
agent_lr_factor
body_lr_factor
checkpoint_run_id
checkpoint_version
d_decider
d_gin_mlp
d_gnn
d_linear_message_selector
d_node_selector
d_transformer
d_transformer_mlp
d_value
gnn_lr_factor
gnn_output_digits
include_round_in_decider
include_round_in_value
is_random
load_checkpoint_and_parameters
noise_sigma
normalize_message_history
num_decider_layers
num_gnn_layers
num_heads
num_linear_message_selector_layers
num_node_selector_layers
num_transformer_layers
num_value_layers
orthogonal_initialisation_gain
transformer_dropout
update_schedule
use_batch_norm
use_dual_gnn
use_manual_architecture
use_orthogonal_initialisation
use_pair_invariant_pooling
checkpoint_entity
checkpoint_project
Methods
- __eq__(other)
Return self==value.
- __init__(agent_lr_factor: nip.parameters.agents.LrFactors | dict | None = None, body_lr_factor: nip.parameters.agents.LrFactors | dict | None = None, update_schedule: nip.parameters.update_schedule.AgentUpdateSchedule = ConstantUpdateSchedule(), use_manual_architecture: bool = False, normalize_message_history: bool = False, load_checkpoint_and_parameters: bool = False, checkpoint_entity: str = <factory>, checkpoint_project: str = <factory>, checkpoint_run_id: str | None = None, checkpoint_version: str = 'latest', use_orthogonal_initialisation: bool = True, orthogonal_initialisation_gain: float = 1.0, activation_function: typing.Literal['relu', 'tanh', 'sigmoid'] = 'tanh', num_gnn_layers: int = 5, d_gnn: int = 16, d_gin_mlp: int = 64, gnn_output_digits: int | None = None, use_dual_gnn: bool = True, num_heads: int = 4, num_transformer_layers: int = 4, d_transformer: int = 16, d_transformer_mlp: int = 64, transformer_dropout: float = 0.0, d_node_selector: int = 16, num_node_selector_layers: int = 2, d_decider: int = 16, num_decider_layers: int = 2, include_round_in_decider: bool = True, d_linear_message_selector: int = 16, num_linear_message_selector_layers: int = 2, d_value: int = 16, num_value_layers: int = 2, include_round_in_value: bool = True, use_batch_norm: bool = True, noise_sigma: float = 0.0, use_pair_invariant_pooling: bool = True, gnn_lr_factor: nip.parameters.agents.LrFactors | dict | None = None) -> None
- __repr__()
Return repr(self).
- classmethod _get_param_class_from_dict(param_dict: dict) -> type[ParameterValue] | None
Try to get the parameter class from a dictionary of serialised parameters.
- Parameters:
param_dict (dict) – A dictionary of parameters, which may have come from a to_dict method. This dictionary may contain a _type key, which is used to determine the class of the parameter.
- Returns:
param_class (type[ParameterValue] | None) – The class of the parameter, if it can be determined.
- Raises:
ValueError – If the class specified in the dictionary is not a valid parameter class.
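The lookup described above can be sketched with a simple name-to-class registry. This is only an illustration of the documented behaviour (return the class named by the `_type` key, return None when the key is absent, raise ValueError for an unknown class). The names `PARAM_REGISTRY`, `register`, `ExampleParameters` and `get_param_class_from_dict` are hypothetical, and the library's actual mechanism may differ.

```python
from dataclasses import dataclass

# Hypothetical registry mapping class names to parameter classes.
PARAM_REGISTRY: dict[str, type] = {}


def register(cls):
    """Class decorator that records a parameter class under its name."""
    PARAM_REGISTRY[cls.__name__] = cls
    return cls


@register
@dataclass
class ExampleParameters:
    d_gnn: int = 16


def get_param_class_from_dict(param_dict: dict):
    """Return the class named by the "_type" key, or None if the key is absent."""
    if "_type" not in param_dict:
        return None
    type_name = param_dict["_type"]
    if type_name not in PARAM_REGISTRY:
        raise ValueError(f"{type_name!r} is not a valid parameter class")
    return PARAM_REGISTRY[type_name]


found = get_param_class_from_dict({"_type": "ExampleParameters", "d_gnn": 32})
missing = get_param_class_from_dict({"d_gnn": 32})
```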
- classmethod construct_test_params() -> GraphIsomorphismAgentParameters
Construct test parameters for the agent.
We use a simple architecture with one GNN layer and one transformer layer.
- Returns:
test_params (GraphIsomorphismAgentParameters) – The test parameters.
- classmethod from_dict(params_dict: dict, ignore_extra_keys: bool = False) -> AgentsParameters
Create a parameters object from a dictionary.
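As a rough illustration of building a parameters object from a dictionary, the sketch below uses a plain dataclass. The semantics of ignore_extra_keys are not spelled out here, so the sketch assumes that unknown keys raise an error by default and are silently dropped when the flag is set. The class `SketchParameters` is hypothetical and much smaller than the real class.

```python
from dataclasses import dataclass, fields


@dataclass
class SketchParameters:
    num_gnn_layers: int = 5
    d_gnn: int = 16

    @classmethod
    def from_dict(cls, params_dict: dict, ignore_extra_keys: bool = False):
        # Names of the fields this class actually defines.
        known = {f.name for f in fields(cls)}
        extra = set(params_dict) - known
        # Assumed behaviour: reject unknown keys unless told to ignore them.
        if extra and not ignore_extra_keys:
            raise ValueError(f"Unexpected keys: {sorted(extra)}")
        return cls(**{k: v for k, v in params_dict.items() if k in known})


params = SketchParameters.from_dict({"d_gnn": 32, "unknown": 1}, ignore_extra_keys=True)
```

Fields missing from the dictionary fall back to their dataclass defaults, so a partial dictionary is enough in this sketch.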
- get(address: str) -> Any
Get a value from the parameters object using a dot-separated address.
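A dot-separated address can be resolved by walking attributes one segment at a time. The sketch below shows that idea on two small dataclasses. The classes `SketchLrFactors` and `SketchAgentParameters`, and the `actor` and `critic` field names, are hypothetical stand-ins and not the library's definitions.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class SketchLrFactors:
    actor: float = 1.0
    critic: float = 1.0


@dataclass
class SketchAgentParameters:
    d_gnn: int = 16
    body_lr_factor: SketchLrFactors = field(default_factory=SketchLrFactors)

    def get(self, address: str) -> Any:
        """Follow each dot-separated segment as an attribute lookup."""
        value: Any = self
        for part in address.split("."):
            value = getattr(value, part)
        return value


sketch = SketchAgentParameters(body_lr_factor=SketchLrFactors(actor=0.5))
top_level = sketch.get("d_gnn")
nested = sketch.get("body_lr_factor.actor")
```

In this sketch an address naming a missing attribute raises AttributeError; how the real method reports an invalid address is not stated here.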