nip.utils.torch.GIN
- class nip.utils.torch.GIN(*args, **kwargs)
A graph isomorphism network (GIN) layer [XHLJ18].
This is a message-passing layer that aggregates the features of the neighbours as follows:
\[x_i' = \text{MLP}\left((1 + \epsilon) x_i + \sum_{j \in \mathcal{N}(i)} x_j\right)\]
where \(x_i\) is the feature vector of node \(i\), \(\mathcal{N}(i)\) is the set of neighbours of node \(i\), and \(\epsilon\) is a (possibly learnable) parameter.
Unlike the implementation in PyTorch Geometric, this layer takes as input a TensorDict containing dense representations of the graphs and their features.
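The update rule above can be sketched without any framework. This is a minimal illustration, not the library's implementation: nested lists stand in for tensors, `gin_update` is a hypothetical helper name, and the identity function replaces the MLP so that the aggregation itself is visible. With a dense 0/1 adjacency matrix, the neighbour sum is row \(i\) of the adjacency matrix multiplied by the feature matrix.

```python
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def gin_update(
    x: Matrix,
    adjacency: Matrix,
    mlp: Callable[[Vector], Vector],
    eps: float = 0.0,
) -> Matrix:
    """Return MLP((1 + eps) * x_i + sum_j A[i][j] * x_j) for every node i."""
    num_nodes = len(x)
    num_features = len(x[0])
    out = []
    for i in range(num_nodes):
        # Scaled self term: (1 + eps) * x_i.
        combined = [(1.0 + eps) * value for value in x[i]]
        # Neighbour sum: with a 0/1 adjacency row this is the sum over N(i).
        for j in range(num_nodes):
            if adjacency[i][j] != 0.0:
                for f in range(num_features):
                    combined[f] += adjacency[i][j] * x[j][f]
        out.append(mlp(combined))
    return out


# Path graph 0 - 1 - 2 with a single feature per node.
x = [[1.0], [2.0], [3.0]]
adjacency = [
    [0.0, 1.0, 0.0],
    [1.0, 0.0, 1.0],
    [0.0, 1.0, 0.0],
]


def identity(h: Vector) -> Vector:
    """Stand-in for a real MLP (assumption made for illustration only)."""
    return list(h)


result = gin_update(x, adjacency, identity, eps=0.0)
print(result)  # [[3.0], [6.0], [5.0]]
```

Node 1 has two neighbours, so its output is 2 + 1 + 3 = 6. Increasing `eps` weights each node's own features more heavily relative to its neighbours.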
- Parameters:
mlp (nn.Module) – The MLP to apply to the aggregated features.
eps (float, default=0.0) – The initial value of \(\epsilon\).
train_eps (bool, default=False) – Whether to train \(\epsilon\) or keep it fixed.
feature_in_key (NestedKey, default="x") – The key of the input features in the input TensorDict.
feature_out_key (NestedKey, default="x") – The key of the output features in the output TensorDict.
adjacency_key (NestedKey, default="adjacency") – The key of the adjacency matrix in the input TensorDict.
node_mask_key (NestedKey, default="node_mask") – The key of the node mask in the input TensorDict.
vmap_compatible (bool, default=False) – Whether the module is compatible with vmap. If True, the node mask is only applied after the MLP, which is less efficient but allows for the use of vmap.
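The trade-off described for vmap_compatible can be illustrated with a framework-free sketch. Everything here is an assumption made for illustration: `CountingMLP`, `mask_then_mlp` and `mlp_then_mask` are hypothetical names, nested lists stand in for tensors, and the library's actual masking code is not shown in this page. The sketch only demonstrates why applying the mask after the MLP costs more work while producing the same result.

```python
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


class CountingMLP:
    """Toy per-node 'MLP' (h -> 2h + 1) that counts how often it is called."""

    def __init__(self) -> None:
        self.calls = 0

    def __call__(self, row: Vector) -> Vector:
        self.calls += 1
        return [2.0 * value + 1.0 for value in row]


def mask_then_mlp(
    h: Matrix,
    node_mask: List[bool],
    mlp: Callable[[Vector], Vector],
    out_dim: int,
) -> Matrix:
    """Run the MLP only on existing nodes; padded rows are filled with zeros."""
    return [mlp(row) if keep else [0.0] * out_dim for row, keep in zip(h, node_mask)]


def mlp_then_mask(
    h: Matrix,
    node_mask: List[bool],
    mlp: Callable[[Vector], Vector],
) -> Matrix:
    """Run the MLP on every row (fixed shape), then zero out padded rows."""
    transformed = [mlp(row) for row in h]
    return [
        row if keep else [0.0] * len(row)
        for row, keep in zip(transformed, node_mask)
    ]


# Aggregated features for a graph with 2 real nodes padded to 4 slots.
h = [[1.0, 0.0], [0.5, 2.0], [0.0, 0.0], [0.0, 0.0]]
node_mask = [True, True, False, False]

early, late = CountingMLP(), CountingMLP()
out_early = mask_then_mlp(h, node_mask, early, out_dim=2)
out_late = mlp_then_mask(h, node_mask, late)

print(out_early == out_late)    # True
print(early.calls, late.calls)  # 2 4
```

Both strategies agree on the output, but the second evaluates the MLP on padded rows as well. Its advantage is that the amount of work does not depend on the contents of the mask, which is presumably what makes it usable under vmap. Note that the toy MLP maps a zero row to a non-zero row, so the final masking step is still needed.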
Shapes
Takes as input a TensorDict with the following keys:
"x" - Float["... max_nodes feature"] - The features of the nodes.
"adjacency" - Float["... max_nodes max_nodes"] - The adjacency matrix of the graph.
"node_mask" - Bool["... max_nodes"] - A mask indicating which nodes exist.
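As a rough illustration of the dense layout listed above, the following sketch pads a single small graph to max_nodes slots. It is not part of the library: `to_dense` is a hypothetical helper, a plain dict stands in for the TensorDict, nested lists stand in for tensors, and edges are assumed to be undirected so the adjacency matrix is made symmetric.

```python
from typing import Dict, List, Tuple


def to_dense(
    node_features: List[List[float]],
    edges: List[Tuple[int, int]],
    max_nodes: int,
) -> Dict[str, list]:
    """Pad one graph to `max_nodes` slots in the dense layout described above."""
    num_nodes = len(node_features)
    if num_nodes == 0 or num_nodes > max_nodes:
        raise ValueError("expected between 1 and max_nodes nodes")
    num_features = len(node_features[0])

    # "x": real node features followed by zero rows for the padding slots.
    x = [list(row) for row in node_features]
    x += [[0.0] * num_features for _ in range(max_nodes - num_nodes)]

    # "adjacency": max_nodes x max_nodes, zero wherever a padding slot is involved.
    adjacency = [[0.0] * max_nodes for _ in range(max_nodes)]
    for i, j in edges:
        adjacency[i][j] = 1.0
        adjacency[j][i] = 1.0  # assumption: undirected edges

    # "node_mask": True for slots that hold a real node.
    node_mask = [slot < num_nodes for slot in range(max_nodes)]

    return {"x": x, "adjacency": adjacency, "node_mask": node_mask}


# Path graph 0 - 1 - 2 with two features per node, padded to 4 slots.
dense = to_dense(
    node_features=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    edges=[(0, 1), (1, 2)],
    max_nodes=4,
)
print(dense["node_mask"])     # [True, True, True, False]
print(dense["adjacency"][1])  # [1.0, 0.0, 1.0, 0.0]
```

Padding every graph to the same max_nodes is what allows graphs of different sizes to share the leading batch dimensions written as "..." in the shapes above.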
Methods Summary
__init__(mlp[, eps, train_eps, ...]) – Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(tensordict) – Apply the GIN layer to the input TensorDict.
Reset the parameters of the layer.
Attributes
T_destination
call_super_init
dump_patches
in_keys – The keys of the input TensorDict.
out_keys – The keys of the output TensorDict.
out_keys_source
training
Methods
- __init__(mlp: Module, eps: float = 0.0, train_eps: bool = False, feature_in_key: str | Tuple[str, ...] = 'x', feature_out_key: str | Tuple[str, ...] = 'x', adjacency_key: str | Tuple[str, ...] = 'adjacency', node_mask_key: str | Tuple[str, ...] = 'node_mask', vmap_compatible: bool = False)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(tensordict: TensorDictBase) → TensorDict
Apply the GIN layer to the input TensorDict.
- Parameters:
tensordict (TensorDictBase) – The input TensorDict with a dense representation of the graph. It should have a key for the features, adjacency matrix and (optionally) node mask.
- Returns:
out (TensorDict) – The input TensorDict with the GIN layer applied. This includes the updated features.