nip.utils.maths

Utilities for mathematical operations.

Functions

_zero_grad(params)

Zero the gradients of the parameters in a dictionary.

aggregate_mean_grouped_by_class(values, classes)

Compute the mean of values grouped by class.
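A minimal NumPy sketch of the grouping. It assumes `classes` holds non-negative integer labels aligned with a 1-D `values` array. The function name is hypothetical, NumPy arrays stand in for the tensors the module presumably uses, and returning NaN for labels with no members is an illustrative choice.

```python
import numpy as np

def mean_grouped_by_class(values, classes):
    """Return the mean of `values` for each integer class label 0..max(classes)."""
    values = np.asarray(values, dtype=float)
    classes = np.asarray(classes, dtype=int)
    # Per-class sums and counts; both arrays have length max(classes) + 1
    sums = np.bincount(classes, weights=values)
    counts = np.bincount(classes)
    # Labels with no members get NaN instead of a division-by-zero warning
    means = np.full(sums.shape, np.nan)
    np.divide(sums, counts, out=means, where=counts > 0)
    return means

# Class 0 -> mean of 1.0 and 3.0; class 1 -> mean of 2.0 and 5.0
print(mean_grouped_by_class([1.0, 2.0, 3.0, 5.0], [0, 1, 0, 1]))
```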

compute_conjugate_gradient_ihvp(...)

Approximate the inverse Hessian-vector product with the conjugate gradient method.
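Conjugate gradient approximates H^-1 v by solving H x = v for a symmetric positive-definite H, using only Hessian-vector products, so H^-1 is never formed. The sketch below assumes a hypothetical callable `hvp(p)` returning H p (in practice this would come from automatic differentiation of the loss) and uses NumPy vectors; the names, iteration cap and tolerance are illustrative, not this module's signature.

```python
import numpy as np

def conjugate_gradient_solve(hvp, v, max_iterations=100, tol=1e-10):
    """Approximate H^-1 v by solving H x = v with conjugate gradient.

    `hvp(p)` must return H @ p for a symmetric positive-definite H.
    """
    x = np.zeros_like(v, dtype=float)
    r = np.array(v, dtype=float)  # residual v - H x, with x = 0 initially
    p = r.copy()                  # first search direction
    rs_old = r @ r
    for _ in range(max_iterations):
        if np.sqrt(rs_old) < tol:  # residual small enough: converged
            break
        hp = hvp(p)
        step = rs_old / (p @ hp)   # exact line search along p
        x += step * p
        r -= step * hp
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p  # next direction, H-conjugate to the previous ones
        rs_old = rs_new
    return x

H = np.array([[4.0, 1.0], [1.0, 3.0]])
v = np.array([1.0, 2.0])
x = conjugate_gradient_solve(lambda p: H @ p, v)
```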

compute_neumann_ihvp(follower_loss, ...)

Approximate the inverse Hessian-vector product with the Neumann series method.
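The Neumann approach rests on the series H^-1 = alpha * sum_{k>=0} (I - alpha H)^k, which converges when every eigenvalue of alpha H lies in (0, 2). Truncating after a fixed number of terms gives an approximation built purely from repeated Hessian-vector products. A NumPy sketch, assuming a hypothetical `hvp` callable and illustrative values for the step size `alpha` and the number of terms:

```python
import numpy as np

def neumann_ihvp(hvp, v, alpha=0.1, num_terms=200):
    """Approximate H^-1 v with a truncated Neumann series.

    Uses H^-1 = alpha * sum_k (I - alpha H)^k, valid when every eigenvalue
    of alpha * H lies in (0, 2); alpha must be chosen with that in mind.
    """
    term = np.array(v, dtype=float)  # (I - alpha H)^0 v
    total = term.copy()
    for _ in range(num_terms):
        term = term - alpha * hvp(term)  # multiply the previous term by (I - alpha H)
        total += term
    return alpha * total

H = np.array([[4.0, 1.0], [1.0, 3.0]])  # eigenvalues ~2.38 and ~4.62
v = np.array([1.0, 2.0])
x = neumann_ihvp(lambda p: H @ p, v)
```

Each extra term costs one Hessian-vector product, and the error shrinks geometrically at the rate of the largest |1 - alpha * lambda| over the eigenvalues lambda of H.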

compute_nystrom_ihvp(follower_loss, ...[, ...])

Approximate the inverse Hessian-vector product with the Nyström method.
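One common form of the Nyström approach samples `rank` columns of H (obtainable as Hessian-vector products with standard basis vectors) and approximates H by C W^-1 C^T. Here C holds the sampled columns and W is the matching rank x rank block. The Woodbury identity then inverts the damped matrix cheaply. The sketch below assumes a hypothetical `hvp` callable and a damping term `rho`, so it approximates (H + rho I)^-1 v rather than H^-1 v. Whether this module samples columns, uses damping, or names its parameters this way is not stated here.

```python
import numpy as np

def nystrom_ihvp(hvp, v, rank, rho=1e-2, rng=None):
    """Approximate (H + rho I)^-1 v using a rank-`rank` Nystrom approximation of H."""
    rng = np.random.default_rng() if rng is None else rng
    n = v.shape[0]
    idx = rng.choice(n, size=rank, replace=False)  # sampled column indices
    basis = np.eye(n)[:, idx]
    # n x rank matrix of sampled Hessian columns, one HVP per column
    C = np.stack([hvp(basis[:, j]) for j in range(rank)], axis=1)
    W = C[idx, :]  # rank x rank block H[idx][:, idx]
    # Woodbury: (rho I + C W^-1 C^T)^-1 v = (v - C (rho W + C^T C)^-1 C^T v) / rho
    inner = rho * W + C.T @ C
    return (v - C @ np.linalg.solve(inner, C.T @ v)) / rho

H = np.array([[4.0, 1.0, 0.5], [1.0, 3.0, 0.2], [0.5, 0.2, 2.0]])
v = np.array([1.0, 2.0, 3.0])
x_full = nystrom_ihvp(lambda p: H @ p, v, rank=3, rng=np.random.default_rng(0))
x_low = nystrom_ihvp(lambda p: H @ p, v, rank=2, rng=np.random.default_rng(0))
```

With `rank` equal to the dimension, C W^-1 C^T reproduces H exactly, which makes a useful sanity check. Smaller ranks trade accuracy for fewer Hessian-vector products.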

compute_sos_update(simultaneous_grad, ...)

Compute the update for the Stable Opponent Shaping (SOS) algorithm.
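As we read Algorithm 1 of Letcher et al., "Stable Opponent Shaping in Differentiable Games" (ICLR 2019), SOS starts from the look-ahead gradient xi_0 = (I - alpha H_o) xi, where H_o holds the off-diagonal blocks of the game Hessian, and an opponent-shaping term chi. It then adds the shaping contribution -alpha * chi scaled by p = min(p1, p2). Here p1 limits how far shaping can pull the update away from xi_0, and p2 = ||xi||^2 once ||xi|| drops below a threshold b. The sketch below covers only that final combination step. It assumes all terms are already flattened NumPy vectors and uses hypothetical argument names with illustrative defaults for a, b in (0, 1); it is not this module's signature.

```python
import numpy as np

def sos_combine(xi, xi_0, chi, alpha, a=0.5, b=0.1):
    """Combine precomputed SOS terms into the final update direction.

    xi:   simultaneous gradient (each player's gradient of its own loss)
    xi_0: look-ahead gradient (I - alpha * H_o) @ xi
    chi:  opponent-shaping term, applied as -alpha * chi
    """
    shaping = -alpha * chi
    dot = shaping @ xi_0
    # p1 keeps <update, xi_0> >= (1 - a) * ||xi_0||^2 when shaping opposes xi_0
    if dot >= 0:
        p1 = 1.0
    else:
        p1 = min(1.0, -a * (xi_0 @ xi_0) / dot)
    # p2 switches shaping off smoothly near fixed points, where ||xi|| is small
    xi_norm = np.linalg.norm(xi)
    p2 = xi_norm**2 if xi_norm < b else 1.0
    p = min(p1, p2)
    return xi_0 + p * shaping
```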

dict_dot_product(dict_1, dict_2)

Calculate the dot product between two dictionaries of tensors.

dict_scalar_multiple(dictionary, scalar)

Calculate a scalar multiple of a dictionary of tensors.

dict_sum(dict_1, dict_2)

Calculate the sum of two dictionaries of tensors, element-wise.
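Dictionaries of tensors keyed by parameter name (the shape of PyTorch's `named_parameters()` output) can be treated as one long vector, which is what the three dictionary helpers above provide. A NumPy sketch with hypothetical names, assuming both dictionaries share the same keys and per-key shapes:

```python
import numpy as np

def dict_dot(dict_1, dict_2):
    """Dot product of two dicts of arrays, treating all entries as one flat vector."""
    return sum(float(np.sum(dict_1[k] * dict_2[k])) for k in dict_1)

def dict_scale(dictionary, scalar):
    """Multiply every array in the dict by the same scalar."""
    return {k: scalar * v for k, v in dictionary.items()}

def dict_add(dict_1, dict_2):
    """Element-wise sum of two dicts of arrays with matching keys."""
    return {k: dict_1[k] + dict_2[k] for k in dict_1}

grads = {"weight": np.array([1.0, 2.0]), "bias": np.array(3.0)}
direction = {"weight": np.array([4.0, 5.0]), "bias": np.array(-1.0)}
print(dict_dot(grads, direction))  # 1*4 + 2*5 + 3*(-1) = 11.0
```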

inverse_hessian_vector_product(...[, rank, ...])

Compute the inverse Hessian-vector product using the specified approximation method.

is_broadcastable(shape_1, shape_2)

Check if two shapes are broadcastable.
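Two shapes are broadcastable when, comparing dimensions from the right, each pair is equal or one of them is 1, with missing leading dimensions treated as 1 (the rule shared by NumPy and PyTorch). A pure-Python sketch with a hypothetical name:

```python
def shapes_broadcastable(shape_1, shape_2):
    """Check whether two shapes are compatible under NumPy/PyTorch broadcasting."""
    # zip stops at the shorter shape, so missing leading dimensions never conflict
    for d1, d2 in zip(reversed(shape_1), reversed(shape_2)):
        if d1 != d2 and d1 != 1 and d2 != 1:
            return False
    return True

print(shapes_broadcastable((8, 1, 6, 1), (7, 1, 5)))  # True
print(shapes_broadcastable((2, 1), (8, 4, 3)))        # False: 2 vs 4
```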

logit_entropy(logits)

Compute the entropy of the distribution defined by a set of logits.
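The summary does not say whether the logits parameterize one categorical distribution or independent Bernoulli events. Assuming the categorical case, the entropy is -sum_i p_i log p_i with p = softmax(logits). A numerically stable pure-Python sketch over a single list of logits (hypothetical name):

```python
import math

def categorical_entropy_from_logits(logits):
    """Entropy (in nats) of softmax(logits), computed via log-sum-exp."""
    m = max(logits)
    # log of the softmax normalizer, shifted by the max to avoid overflow
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    log_probs = [l - log_z for l in logits]
    return -sum(math.exp(lp) * lp for lp in log_probs)

print(categorical_entropy_from_logits([0.0, 0.0]))  # log(2), about 0.693
```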

logit_or(logits[, dim])

Compute the logit of the OR of n events given their logits.
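Assuming the n events are independent (the summary does not say), P(no event) = prod_i sigmoid(-l_i), so log P(no event) = -sum_i softplus(l_i). The logit of the OR is then log(1 - P(no event)) - log P(no event). A pure-Python sketch that stays in log space throughout (names hypothetical):

```python
import math

def softplus(x):
    """log(1 + exp(x)) without overflow for large |x|."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def logit_of_or(logits):
    """Logit of P(at least one event), assuming the events are independent."""
    # log P(no event occurs) = sum_i log sigmoid(-l_i) = -sum_i softplus(l_i)
    log_none = -sum(softplus(l) for l in logits)
    # logit(P) = log(1 - P(none)) - log P(none); expm1 keeps precision near 0
    return math.log(-math.expm1(log_none)) - log_none

print(logit_of_or([0.0, 0.0]))  # two 50% events: P(OR) = 0.75, logit = log(3)
```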

logit_or_dual(a, b)

Compute the logit of the OR of two events from their logit tensors, using the log-sum-exp trick.
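For two independent events with logits a and b, the odds of "A or B" are (1 + e^a)(1 + e^b) - 1 = e^a + e^b + e^(a+b). The logit of the OR is therefore logsumexp(a, b, a + b). A scalar sketch under that independence assumption, with a hypothetical name; the module's version operates on tensors and may differ in detail:

```python
import math

def logit_or_two(a, b):
    """Logit of P(A or B) for independent events with logits a and b.

    Evaluates logsumexp(a, b, a + b); subtracting the maximum keeps every
    exponential in (0, 1], so large logits do not overflow.
    """
    terms = (a, b, a + b)
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))

print(logit_or_two(0.0, 0.0))  # log(3): two 50% events give P(OR) = 0.75
```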

mean_episode_reward(reward, done_mask)

Compute the mean total episode reward for a batch of concatenated episodes.
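A sketch for a single flat sequence. It assumes `done_mask` is True on the last step of each episode and that a trailing, unfinished episode is left out of the mean. Both conventions, and the function name, are assumptions rather than documented behavior.

```python
def mean_completed_episode_reward(reward, done_mask):
    """Mean total reward over the episodes that end within the sequence."""
    totals = []
    running = 0.0
    for r, done in zip(reward, done_mask):
        running += r
        if done:  # episode boundary: record its return and start a new episode
            totals.append(running)
            running = 0.0
    if not totals:
        return float("nan")  # no episode finished inside the sequence
    return sum(totals) / len(totals)

# Episodes [1, 2] and [3, 4] have totals 3 and 7
print(mean_completed_episode_reward([1, 2, 3, 4], [False, True, False, True]))  # 5.0
```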

mean_for_unique_keys(data, key[, axis])

Compute the mean of values grouped by unique keys.

minstd_generate_pseudo_random_sequence(seed, ...)

Generate a pseudo-random sequence of numbers using the MINSTD algorithm.
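MINSTD is the Lehmer (Park-Miller) generator x_{n+1} = a * x_n mod (2^31 - 1). The summary does not give the multiplier: a = 16807 is the original Park-Miller choice and a = 48271 the later recommended one (these match C++'s `std::minstd_rand0` and `std::minstd_rand`). A pure-Python sketch with a hypothetical signature:

```python
def minstd_sequence(seed, length, multiplier=48271):
    """Generate `length` successive MINSTD values from `seed`.

    Recurrence: x_{n+1} = multiplier * x_n mod (2**31 - 1).
    """
    modulus = 2**31 - 1
    state = seed % modulus
    if state == 0:
        state = 1  # zero is a fixed point of the recurrence, so avoid it
    values = []
    for _ in range(length):
        state = (multiplier * state) % modulus
        values.append(state)
    return values

print(minstd_sequence(1, 3))
```

With seed 1 and a = 48271, the 10000th value is 399268537, the check value the C++ standard gives for `std::minstd_rand`, which makes a convenient correctness test.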

set_seed(seed)

Set the seed in Python, NumPy, and PyTorch.
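A sketch of what seeding all three libraries usually involves. The function name is hypothetical, and the PyTorch call is skipped when PyTorch is not installed so the example runs anywhere.

```python
import random

import numpy as np

def seed_everything(seed):
    """Seed Python's `random`, NumPy's global RNG and, if installed, PyTorch."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)  # seeds PyTorch's RNGs on all devices

seed_everything(123)
first = (random.random(), np.random.rand())
seed_everything(123)
second = (random.random(), np.random.rand())
print(first == second)  # True: re-seeding reproduces the same draws
```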