nip.utils.torch.GIN#
- class nip.utils.torch.GIN(*args, **kwargs)[source]#
A graph isomorphism network (GIN) layer [XHLJ18].
This is a message-passing layer that aggregates the features of the neighbours as follows:
\[x_i' = \text{MLP}\left((1 + \epsilon) x_i + \sum_{j \in \mathcal{N}(i)} x_j\right)\]
where \(x_i\) is the feature vector of node \(i\), \(\mathcal{N}(i)\) is the set of neighbours of node \(i\), and \(\epsilon\) is a (possibly learnable) parameter.
Unlike the implementation in PyTorch Geometric, this one takes as input a TensorDict containing dense representations of the graphs and their features.
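As a rough illustration (a pure-Python sketch, not the library's actual implementation), the dense masked aggregation that feeds the MLP, \((1 + \epsilon) x_i + \sum_{j} A_{ij} x_j\) with padded nodes zeroed out, can be written as:

```python
def gin_aggregate(x, adjacency, node_mask, eps=0.0):
    """Dense GIN aggregation over one graph (pre-MLP messages).

    x:         node features, shape [max_nodes][feature]
    adjacency: 0/1 dense adjacency matrix, shape [max_nodes][max_nodes]
    node_mask: booleans, True where a node actually exists (False = padding)
    Returns (1 + eps) * x_i + sum_j A_ij * x_j for each real node i,
    and a zero vector for padded nodes.
    """
    n = len(x)
    d = len(x[0])
    out = []
    for i in range(n):
        if not node_mask[i]:
            # Padded node: contributes nothing and receives nothing.
            out.append([0.0] * d)
            continue
        # Self term, scaled by (1 + eps).
        msg = [(1.0 + eps) * v for v in x[i]]
        # Sum over existing neighbours.
        for j in range(n):
            if adjacency[i][j] and node_mask[j]:
                for k in range(d):
                    msg[k] += x[j][k]
        out.append(msg)
    return out
```

The real layer operates batch-wise on tensors and then applies the MLP to the aggregated messages; this sketch only shows the aggregation and masking logic for a single graph.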
- Parameters:
mlp (nn.Module) – The MLP to apply to the aggregated features.
eps (float, default=0.0) – The initial value of \(\epsilon\).
train_eps (bool, default=False) – Whether to train \(\epsilon\) or keep it fixed.
feature_in_key (NestedKey, default="x") – The key of the input features in the input TensorDict.
feature_out_key (NestedKey, default="x") – The key of the output features in the output TensorDict.
adjacency_key (NestedKey, default="adjacency") – The key of the adjacency matrix in the input TensorDict.
node_mask_key (NestedKey, default="node_mask") – The key of the node mask in the input TensorDict.
vmap_compatible (bool, default=False) – Whether the module is compatible with vmap or not. If True, the node mask is only applied after the MLP, which is less efficient but allows for the use of vmap.
Shapes
Takes as input a TensorDict with the following keys:
"x" - Float["… max_nodes feature"] - The features of the nodes.
"adjacency" - Float["… max_nodes max_nodes"] - The adjacency matrix of the graph.
"node_mask" - Bool["… max_nodes"] - A mask indicating which nodes exist.
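For concreteness, here is a hypothetical input with batch size 1, max_nodes = 3, and 2-dimensional features (plain nested lists stand in for the tensors a TensorDict would hold):

```python
# Hypothetical example input matching the shapes above. The leading
# batch dimension corresponds to the "…" in the shape annotations.
example = {
    "x": [[[1.0, 0.0],
           [0.0, 1.0],
           [0.0, 0.0]]],           # (batch, max_nodes, feature)
    "adjacency": [[[0.0, 1.0, 0.0],
                   [1.0, 0.0, 0.0],
                   [0.0, 0.0, 0.0]]],  # (batch, max_nodes, max_nodes)
    "node_mask": [[True, True, False]],  # (batch, max_nodes)
}
```

The third node here is padding: its mask entry is False and its row and column of the adjacency matrix are zero, so it neither sends nor receives messages.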
Methods Summary
__init__(mlp[, eps, train_eps, ...]) - Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(tensordict) - Apply the GIN layer to the input TensorDict.
reset_parameters() - Reset the parameters of the layer.
Attributes
T_destination
call_super_init
dump_patches
in_keys - The keys of the input TensorDict.
out_keys - The keys of the output TensorDict.
out_keys_source
training
Methods
- __init__(mlp: Module, eps: float = 0.0, train_eps: bool = False, feature_in_key: str | tuple[str | tuple[NestedKey, ...], ...] = 'x', feature_out_key: str | tuple[str | tuple[NestedKey, ...], ...] = 'x', adjacency_key: str | tuple[str | tuple[NestedKey, ...], ...] = 'adjacency', node_mask_key: str | tuple[str | tuple[NestedKey, ...], ...] = 'node_mask', vmap_compatible: bool = False)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(tensordict: TensorDictBase) → TensorDict[source]#
Apply the GIN layer to the input TensorDict.
- Parameters:
tensordict (TensorDictBase) – The input TensorDict with a dense representation of the graph. It should have a key for the features, a key for the adjacency matrix, and (optionally) a key for the node mask.
- Returns:
out (TensorDict) – The input TensorDict with the GIN layer applied. This includes the updated features.