nip.utils.torch.GIN

class nip.utils.torch.GIN(*args, **kwargs)

A graph isomorphism network (GIN) layer [XHLJ18].

This is a message-passing layer that aggregates the features of the neighbours as follows:

\[x_i' = \text{MLP}((1 + \epsilon) x_i + \sum_{j \in \mathcal{N}(i)} x_j)\]

where \(x_i\) is the feature vector of node \(i\), \(\mathcal{N}(i)\) is the set of neighbours of node \(i\), and \(\epsilon\) is a (possibly learnable) parameter.
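With a dense adjacency matrix, the neighbour sum in the formula above becomes a matrix product. The following is a minimal sketch in plain PyTorch, not this class's actual implementation: `dense_gin_update` is a hypothetical helper name, the adjacency matrix is assumed to be binary, and `nn.Identity` stands in for the MLP so the arithmetic can be checked by hand.

```python
import torch
from torch import nn


def dense_gin_update(x, adjacency, mlp, eps=0.0):
    """Compute MLP((1 + eps) * x_i + sum over neighbours j of x_j).

    x:         (..., max_nodes, feature)
    adjacency: (..., max_nodes, max_nodes), entry [i, j] is 1 if j neighbours i
    """
    # Row i of `adjacency @ x` is the sum of the features of node i's neighbours.
    neighbour_sum = adjacency @ x
    return mlp((1.0 + eps) * x + neighbour_sum)


# Path graph 0 - 1 - 2 with one scalar feature per node.
x = torch.tensor([[1.0], [2.0], [3.0]])
adjacency = torch.tensor(
    [[0.0, 1.0, 0.0],
     [1.0, 0.0, 1.0],
     [0.0, 1.0, 0.0]]
)

out = dense_gin_update(x, adjacency, mlp=nn.Identity(), eps=0.0)
# out is [[3.], [6.], [5.]]: each node's own feature plus its neighbours' features.
```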

Unlike the implementation in PyTorch Geometric, this layer takes as input a TensorDict with dense representations of the graphs and features.

Parameters:
  • mlp (nn.Module) – The MLP to apply to the aggregated features.

  • eps (float, default=0.0) – The initial value of \(\epsilon\).

  • train_eps (bool, default=False) – Whether to train \(\epsilon\) or keep it fixed.

  • feature_in_key (NestedKey, default="x") – The key of the input features in the input TensorDict.

  • feature_out_key (NestedKey, default="x") – The key of the output features in the output TensorDict.

  • adjacency_key (NestedKey, default="adjacency") – The key of the adjacency matrix in the input TensorDict.

  • node_mask_key (NestedKey, default="node_mask") – The key of the node mask in the input TensorDict.

  • vmap_compatible (bool, default=False) – Whether the module is compatible with vmap or not. If True, the node mask is only applied after the MLP, which is less efficient but allows for the use of vmap.
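The trade-off described for vmap_compatible can be illustrated with two ways of combining an MLP with a node mask. This is a sketch under assumptions, not the layer's source: it assumes that padded nodes end up with zero features in both cases, and all variable names are illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)
mlp = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))

x = torch.randn(2, 5, 4)  # (batch, max_nodes, feature)
node_mask = torch.tensor(
    [[True, True, True, False, False],
     [True, True, True, True, True]]
)

with torch.no_grad():
    # Mask first: run the MLP on existing nodes only. Boolean indexing gives a
    # data-dependent shape (8 rows here), which does not work under vmap.
    out_masked_first = torch.zeros(2, 5, 4)
    out_masked_first[node_mask] = mlp(x[node_mask])

    # Mask last: run the MLP on every slot, padding included, then zero the
    # padded slots. Shapes stay static, at the cost of wasted compute.
    out_masked_last = mlp(x).masked_fill(~node_mask[..., None], 0.0)
```

Both variants agree on the existing nodes; they differ only in how much work the MLP does and in whether the intermediate shapes depend on the data.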

Shapes

Takes as input a TensorDict with the following keys:

  • "x" - Float["... max_nodes feature"] - The features of the nodes.

  • "adjacency" - Float["... max_nodes max_nodes"] - The adjacency matrix of the graph.

  • "node_mask" - Bool["... max_nodes"] - A mask indicating which nodes exist.
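As an illustration of these shapes, the sketch below pads two small graphs to a common max_nodes and builds the three tensors with plain PyTorch. The graphs, feature values, and variable names are made up for the example, and the adjacency is assumed to be symmetric (undirected graphs).

```python
import torch

max_nodes, feature = 4, 3

# Two hypothetical graphs: a triangle (3 nodes) and a single edge (2 nodes).
edge_lists = [[(0, 1), (1, 2), (0, 2)], [(0, 1)]]
num_nodes = [3, 2]
batch = len(edge_lists)

x = torch.zeros(batch, max_nodes, feature)
adjacency = torch.zeros(batch, max_nodes, max_nodes)
node_mask = torch.zeros(batch, max_nodes, dtype=torch.bool)

for g, (edges, n) in enumerate(zip(edge_lists, num_nodes)):
    node_mask[g, :n] = True            # the first n slots hold real nodes
    x[g, :n] = torch.ones(n, feature)  # placeholder features
    for i, j in edges:
        adjacency[g, i, j] = 1.0
        adjacency[g, j, i] = 1.0       # undirected, so keep it symmetric

# These fields use the default key names documented above. Wrapping them in a
# TensorDict with batch_size=[batch] would be the last step before calling the layer.
fields = {"x": x, "adjacency": adjacency, "node_mask": node_mask}
```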

Methods Summary

__init__(mlp[, eps, train_eps, ...])

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(tensordict)

Apply the GIN layer to the input TensorDict.

reset_parameters()

Reset the parameters of the layer.

Attributes

T_destination

call_super_init

dump_patches

in_keys

The keys of the input TensorDict.

out_keys

The keys of the output TensorDict.

out_keys_source

training

Methods

__init__(mlp: Module, eps: float = 0.0, train_eps: bool = False, feature_in_key: str | Tuple[str, ...] = 'x', feature_out_key: str | Tuple[str, ...] = 'x', adjacency_key: str | Tuple[str, ...] = 'adjacency', node_mask_key: str | Tuple[str, ...] = 'node_mask', vmap_compatible: bool = False)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(tensordict: TensorDictBase) → TensorDict

Apply the GIN layer to the input TensorDict.

Parameters:

tensordict (TensorDictBase) – The input TensorDict with a dense representation of the graph. It should have a key for the features, adjacency matrix and (optionally) node mask.

Returns:

out (TensorDict) – The input TensorDict with the GIN layer applied. The updated features are stored under feature_out_key.

reset_parameters()

Reset the parameters of the layer.
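One common way to support both a fixed and a trainable \(\epsilon\), and to reset it, is sketched below. This is an assumption about how such a layer could be written, not a description of this class's source; `EpsHolder` is a hypothetical name used only for illustration.

```python
import torch
from torch import nn


class EpsHolder(nn.Module):
    """Hypothetical helper: stores epsilon as a parameter or as a buffer."""

    def __init__(self, mlp, eps=0.0, train_eps=False):
        super().__init__()
        self.mlp = mlp
        self.initial_eps = float(eps)
        if train_eps:
            # Trainable: registered as a parameter, so optimisers update it.
            self.eps = nn.Parameter(torch.tensor(self.initial_eps))
        else:
            # Fixed: a buffer moves with the module but receives no gradient.
            self.register_buffer("eps", torch.tensor(self.initial_eps))

    def reset_parameters(self):
        # Re-initialise every submodule of the MLP that knows how to reset itself.
        for module in self.mlp.modules():
            reset = getattr(module, "reset_parameters", None)
            if callable(reset):
                reset()
        # Restore epsilon to the value given at construction time.
        with torch.no_grad():
            self.eps.fill_(self.initial_eps)


fixed = EpsHolder(nn.Linear(2, 2), eps=0.1)
learned = EpsHolder(nn.Linear(2, 2), eps=0.1, train_eps=True)
```

In this sketch, `learned.eps` appears in `learned.parameters()` while `fixed.eps` does not, which is the practical difference between the two settings of train_eps.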