nip.utils.torch


Handy PyTorch classes and utilities, including modules.

Functions

apply_orthogonal_initialisation(module, gain)

Apply orthogonal initialisation to a module's weights and set the biases to 0.
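As a rough illustration of what this does, here is a minimal sketch built on torch.nn.init. The name orthogonal_init_sketch is hypothetical, and it assumes every submodule is visited (the real function may only touch the top-level module). Only weights with at least two dimensions are initialised orthogonally, since orthogonal initialisation is undefined for 1-D tensors.

```python
import torch
from torch import nn


def orthogonal_init_sketch(module: nn.Module, gain: float = 1.0) -> None:
    """Hypothetical sketch: orthogonal weights and zero biases for every submodule."""
    for submodule in module.modules():
        weight = getattr(submodule, "weight", None)
        # Orthogonal initialisation needs a tensor with at least 2 dimensions
        if isinstance(weight, torch.Tensor) and weight.dim() >= 2:
            nn.init.orthogonal_(weight, gain=gain)
        bias = getattr(submodule, "bias", None)
        # Layers built with bias=False store None here, so they are skipped
        if isinstance(bias, torch.Tensor):
            nn.init.zeros_(bias)


layer = nn.Linear(4, 4)
orthogonal_init_sketch(layer, gain=2.0)
```

For a square weight W initialised this way, W Wᵀ equals gain² times the identity, which is an easy property to check after initialisation.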

flatten_batch_dims(x, num_batch_dims)

Return a new view of a tensor with the batch dimensions flattened.
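A minimal sketch of the idea, assuming the batch dimensions are the leading ones. The name flatten_batch_dims_sketch is hypothetical. Note that reshape only returns a view when the memory layout allows it (for example, for a contiguous input); otherwise it copies.

```python
import math

import torch


def flatten_batch_dims_sketch(x: torch.Tensor, num_batch_dims: int) -> torch.Tensor:
    """Hypothetical sketch: merge the leading `num_batch_dims` dimensions into one."""
    batch_size = math.prod(x.shape[:num_batch_dims])
    # Returns a view (no copy) whenever the memory layout permits
    return x.reshape(batch_size, *x.shape[num_batch_dims:])


x = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
flat = flatten_batch_dims_sketch(x, num_batch_dims=2)
print(flat.shape)  # torch.Size([6, 4, 5])
# Tensor.unflatten restores the original batch shape
restored = flat.unflatten(0, x.shape[:2])
```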

Classes

BatchNorm1dSimulateBatchDims(num_features[, ...])

Batch normalization layer with arbitrary batch dimensions.

CatGraphPairDim(cat_dim[, pair_dim])

Concatenate the two node sets for each graph pair.

Conv2dSimulateBatchDims(in_channels, ...[, ...])

2D convolutional layer with arbitrary batch dimensions.

DummyOptimizer(*args, **kwargs)

A dummy optimizer which does nothing.
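One way such an optimizer might look is sketched below. DummyOptimizerSketch is a hypothetical name. It accepts and ignores extra keyword arguments (such as lr) so it can stand in for a real optimizer, and its step leaves every parameter untouched. A closure, if given, is still evaluated to honour the usual torch.optim.Optimizer contract.

```python
import torch


class DummyOptimizerSketch(torch.optim.Optimizer):
    """Hypothetical sketch: an optimizer whose step never changes any parameter."""

    def __init__(self, params, **kwargs):
        # Extra hyperparameters (lr, momentum, ...) are accepted and ignored
        super().__init__(params, defaults={})

    def step(self, closure=None):
        loss = None
        if closure is not None:
            # Evaluate the closure (e.g. to compute the loss), but apply no update
            with torch.enable_grad():
                loss = closure()
        return loss


w = torch.nn.Parameter(torch.ones(3))
opt = DummyOptimizerSketch([w], lr=0.1)
(w * 2).sum().backward()
opt.step()  # w is still all ones
```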

FastForwardableBatchSampler(sampler, ...[, ...])

A batch sampler which can skip an initial number of items.
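A typical reason to skip items is resuming training part-way through an epoch. The sketch below assumes the skip is counted in individual indices (not batches) and uses hypothetical names (FastForwardBatchSamplerSketch, num_skip_items); it is not the class's actual signature. Because the underlying sampler is still iterated, a random sampler must be seeded identically to reproduce the original order.

```python
import itertools
from typing import Iterator, List

from torch.utils.data import Sampler, SequentialSampler


class FastForwardBatchSamplerSketch(Sampler[List[int]]):
    """Hypothetical sketch: batch a sampler's indices after skipping the first few."""

    def __init__(self, sampler, batch_size: int, drop_last: bool = False, num_skip_items: int = 0):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_skip_items = num_skip_items

    def __iter__(self) -> Iterator[List[int]]:
        # islice discards the first num_skip_items indices from the sampler
        indices = itertools.islice(iter(self.sampler), self.num_skip_items, None)
        batch: List[int] = []
        for idx in indices:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # Emit the final short batch unless drop_last is set
        if batch and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        remaining = max(len(self.sampler) - self.num_skip_items, 0)
        if self.drop_last:
            return remaining // self.batch_size
        return (remaining + self.batch_size - 1) // self.batch_size


batches = list(FastForwardBatchSamplerSketch(SequentialSampler(range(10)), batch_size=3, num_skip_items=4))
print(batches)  # [[4, 5, 6], [7, 8, 9]]
```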

GIN(*args, **kwargs)

A graph isomorphism network (GIN) layer [XHLJ18].
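The GIN update from [XHLJ18] is h_v' = MLP((1 + ε) · h_v + Σ_{u ∈ N(v)} h_u). Below is a minimal sketch of that update over a dense adjacency matrix. The dense-adjacency input, the two-layer MLP and the name DenseGINSketch are all assumptions; the documented class takes *args, **kwargs whose meaning is not given here.

```python
import torch
from torch import nn


class DenseGINSketch(nn.Module):
    """Hypothetical sketch of a GIN layer using a dense adjacency matrix."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, eps: float = 0.0, train_eps: bool = False):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, out_dim))
        eps_tensor = torch.tensor(float(eps))
        if train_eps:
            self.eps = nn.Parameter(eps_tensor)
        else:
            self.register_buffer("eps", eps_tensor)

    def forward(self, x: torch.Tensor, adjacency: torch.Tensor) -> torch.Tensor:
        # x: (..., num_nodes, in_dim); adjacency[v, u] = 1 if u is a neighbour of v
        neighbour_sum = adjacency @ x
        return self.mlp((1 + self.eps) * x + neighbour_sum)


layer = DenseGINSketch(in_dim=4, hidden_dim=8, out_dim=3)
out = layer(torch.randn(5, 4), torch.ones(5, 5))  # shape (5, 3)
```

Because the neighbour aggregation is a sum, relabelling the nodes permutes the output rows in the same way, which is the permutation equivariance GIN relies on.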

GlobalMaxPool([dim, keepdim])

Global max pooling layer over a dimension.
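A minimal sketch, assuming the layer reduces with a maximum over a single dimension. The name GlobalMaxPoolSketch and the defaults dim=-1, keepdim=False are hypothetical. torch.amax is used because it returns only the maxima, whereas torch.max(x, dim) also returns indices.

```python
import torch
from torch import nn


class GlobalMaxPoolSketch(nn.Module):
    """Hypothetical sketch: take the maximum over one dimension."""

    def __init__(self, dim: int = -1, keepdim: bool = False):
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.amax(x, dim=self.dim, keepdim=self.keepdim)


# Pool node features of shape (batch, nodes, features) into graph features (batch, features)
node_features = torch.tensor([[[1.0, 5.0], [3.0, 2.0]]])
graph_features = GlobalMaxPoolSketch(dim=-2)(node_features)
print(graph_features)  # tensor([[3., 5.]])
```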

MaxPool2dSimulateBatchDims(kernel_size[, ...])

2D max pool layer with arbitrary batch dimensions.

NormalizeOneHotMessageHistory(*args, **kwargs)

Normalize the history of one-hot message exchanges.

OneHot([num_classes])

One-hot encode a tensor.
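A minimal sketch as an nn.Module wrapper around torch.nn.functional.one_hot. The name OneHotSketch and the conversion to float are assumptions; F.one_hot itself requires int64 indices, returns an int64 tensor, and infers the number of classes from the input when num_classes is -1.

```python
import torch
import torch.nn.functional as F
from torch import nn


class OneHotSketch(nn.Module):
    """Hypothetical sketch: one-hot encode integer class indices."""

    def __init__(self, num_classes: int = -1):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F.one_hot needs int64 indices; cast to float so downstream layers can consume it
        return F.one_hot(x.long(), num_classes=self.num_classes).float()


encoded = OneHotSketch(num_classes=4)(torch.tensor([0, 2, 3]))
# encoded has shape (3, 4): each row has a single 1.0 at the class index
```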

PairInvariantizer([pair_dim])

Transform the input to be invariant to the order of the graphs in a pair.

PairedGaussianNoise(sigma[, pair_dim, ...])

Add Gaussian noise copied across the graph pair dimension.
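"Copied across the graph pair dimension" can be implemented by sampling noise with size 1 along that dimension and letting broadcasting repeat it for both graphs. The sketch below assumes that approach; the name PairedGaussianNoiseSketch and the default pair_dim=1 are hypothetical, and it adds noise unconditionally (whether the real module does so only in training mode is not stated here).

```python
import torch
from torch import nn


class PairedGaussianNoiseSketch(nn.Module):
    """Hypothetical sketch: Gaussian noise shared by both graphs in each pair."""

    def __init__(self, sigma: float, pair_dim: int = 1):
        super().__init__()
        self.sigma = sigma
        self.pair_dim = pair_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        noise_shape = list(x.shape)
        # A single noise sample per pair: size 1 along the pair dimension
        noise_shape[self.pair_dim] = 1
        noise = torch.randn(noise_shape, dtype=x.dtype, device=x.device) * self.sigma
        # Broadcasting copies the same noise to both graphs of the pair
        return x + noise


x = torch.zeros(3, 2, 4, 5)  # (batch, pair, nodes, features)
noisy = PairedGaussianNoiseSketch(sigma=0.1)(x)
# noisy[:, 0] and noisy[:, 1] are identical
```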

ParallelTensorDictModule(*args, **kwargs)

Apply a module to each key of a TensorDict.

Print([name, mode, transform])

Print information about an input tensor.

ResNetBasicBlockSimulateBatchDims(inplanes, ...)

ResNet basic block with arbitrary batch dimensions.

ResNetBottleneckBlockSimulateBatchDims(...)

ResNet bottleneck block with arbitrary batch dimensions.

SimulateBatchDimsMixin()

A mixin for simulating multiple batch dimensions.
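The "SimulateBatchDims" layers above share one pattern: merge all leading batch dimensions into one, run the standard PyTorch layer, then split the batch dimension back out. A minimal sketch of such a mixin follows. The names SimulateBatchDimsSketch, Conv2dBatchDimsSketch, BatchNorm1dBatchDimsSketch and the num_feature_dims attribute are hypothetical, not this module's API.

```python
import torch
from torch import nn


class SimulateBatchDimsSketch:
    """Hypothetical mixin: lets a layer accept any number of leading batch dimensions.

    Subclasses set `num_feature_dims`, the number of trailing dimensions the
    wrapped layer treats as one example (e.g. 3 for Conv2d's (C, H, W)).
    """

    num_feature_dims: int

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        split = x.dim() - self.num_feature_dims
        batch_shape = x.shape[:split]
        # Merge all batch dimensions so the base layer sees (N, *features)
        flat = x.reshape(-1, *x.shape[split:])
        out = super().forward(flat)
        # Restore the original batch dimensions on the output
        return out.reshape(*batch_shape, *out.shape[1:])


# The mixin must come first in the bases so its forward wraps the layer's forward
class Conv2dBatchDimsSketch(SimulateBatchDimsSketch, nn.Conv2d):
    num_feature_dims = 3


class BatchNorm1dBatchDimsSketch(SimulateBatchDimsSketch, nn.BatchNorm1d):
    num_feature_dims = 1


conv = Conv2dBatchDimsSketch(3, 8, kernel_size=3, padding=1)
print(conv(torch.randn(2, 5, 3, 16, 16)).shape)  # torch.Size([2, 5, 8, 16, 16])
```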

Squeeze([dim])

Squeeze a dimension.

TensorDictCat(*args, **kwargs)

Concatenate the keys of a TensorDict.

TensorDictCloneKeys(*args, **kwargs)

Clone the keys of a TensorDict.

TensorDictPrint(*args, **kwargs)

Print information about an input TensorDict.

UpsampleSimulateBatchDims([size, ...])

Upsample layer with arbitrary batch dimensions.