nip.utils.maths.compute_nystrom_ihvp#

nip.utils.maths.compute_nystrom_ihvp(follower_loss: Tensor, leader_loss: Tensor, follower_params: dict[str, Tensor | Parameter], leader_params: dict[str, Tensor | Parameter], rank: int = 5, rho: float = 0.1, retain_graph: bool = True, generator: Generator | None = None) → dict[str, Tensor]

Approximate the inverse Hessian-vector product with the Nyström method.

This function approximates the inverse Hessian of the follower’s loss with respect to its parameters and multiplies it by the gradients of the leader’s loss with respect to the follower’s parameters. See Hataya and Yamada [HY23] for more details.

The function computes:

\[(H_k + \rho I)^{-1} \frac{\partial g}{\partial \theta}\]

where:

  • \(f\) is the follower’s loss

  • \(g\) is the leader’s loss

  • \(\theta\) are the follower’s parameters

  • \(\phi\) are the leader’s parameters

  • \(H_k\) is the \(k\)-rank approximation of the Hessian of \(f\) with respect to \(\theta\)
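The low-rank structure is what makes the inverse cheap to apply: writing the Nyström approximation as \(H_k = C W^{+} C^\top\), where \(C\) contains the sampled columns of the Hessian and \(W\) is the corresponding core block, the Woodbury identity gives \((\rho I + C W^{+} C^\top)^{-1} v = \tfrac{1}{\rho}\bigl(v - C(\rho W + C^\top C)^{-1} C^\top v\bigr)\), so only a \(k \times k\) system needs to be solved. A minimal NumPy sketch of this step (a standalone illustration with an explicit Hessian and a hypothetical helper name, not the library implementation, which uses autograd Hessian-vector products):

```python
import numpy as np

def nystrom_ihvp(hess, v, idx, rho):
    """Apply (H_k + rho * I)^{-1} to v, where H_k = C W^+ C^T is the
    Nystrom approximation of `hess` built from sampled column indices `idx`.

    Standalone sketch only: the actual function never materialises the
    Hessian, but the algebra of the inverse is the same.
    """
    C = hess[:, idx]             # sampled columns, shape (n, k)
    W = hess[np.ix_(idx, idx)]   # core block W = H[idx, idx], shape (k, k)
    # Woodbury identity:
    #   (rho I + C W^{-1} C^T)^{-1} v
    #     = (1/rho) * (v - C (rho W + C^T C)^{-1} C^T v)
    inner = np.linalg.solve(rho * W + C.T @ C, C.T @ v)
    return (v - C @ inner) / rho
```

When `idx` covers every column, \(H_k = H\) exactly and the result matches a direct solve of \((H + \rho I) x = v\); with `rank` much smaller than the parameter count, the cost drops from cubic in \(n\) to cubic in \(k\).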

Parameters:
  • follower_loss (Tensor) – Follower objective

  • leader_loss (Tensor) – Leader objective

  • follower_params (dict[str, Tensor | Parameter]) – The parameters of the follower agent

  • leader_params (dict[str, Tensor | Parameter]) – The parameters of the leader agent

  • rank (int, default=5) – Rank of low-rank approximation

  • rho (float, default=0.1) – Additive constant to improve numerical stability

  • retain_graph (bool, default=True) – Whether to retain the computation graph for use in computing higher-order derivatives.

  • generator (torch.Generator, optional) – The PyTorch random number generator, used for sampling the rank columns of the Hessian matrix.

Returns:

ihvp (dict[str, Tensor]) – A dictionary where the keys are the follower parameter names and the values are the inverse Hessian-vector product, i.e. the inverse Hessian multiplied by the leader gradients

Notes

Adapted from moskomule/hypergrad