nip.utils.maths.compute_nystrom_ihvp
- nip.utils.maths.compute_nystrom_ihvp(follower_loss: Tensor, leader_loss: Tensor, follower_params: dict[str, Tensor | Parameter], leader_params: dict[str, Tensor | Parameter], rank: int = 5, rho: float = 0.1, retain_graph: bool = True, generator: Generator | None = None) → dict[str, Tensor]
Approximate the inverse Hessian-vector product with the Nyström method.
This function approximates the inverse of the Hessian of the follower's loss with respect to its parameters and multiplies it by the gradient of the leader's loss with respect to the follower's parameters. See Hataya and Yamada [HY23] for more details.
The function computes:
\[(H_k + \rho I)^{-1} \frac{\partial g}{\partial \theta}\]
where:
\(f\) is the follower’s loss
\(g\) is the leader’s loss
\(\theta\) are the follower’s parameters
\(\phi\) are the leader’s parameters
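\(H_k\) is the rank-\(k\) Nyström approximation of the Hessian of \(f\) with respect to \(\theta\), with \(k\) set by rank
\(\rho\) is the additive regularization constant set by rho, and \(I\) is the identity matrix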
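As a worked form of the formula above, here is the standard Nyström construction, assuming (check [HY23] for the exact variant used) that a set \(K\) of \(k\) indices is sampled, \(C = H_{:,K}\) holds the corresponding Hessian columns, \(M = H_{K,K}\) is the \(k \times k\) core block, and \(M\) is invertible (for example, when the Hessian is positive definite). Writing \(v = \partial g / \partial \theta\) and \(p\) for the number of follower parameters, the Woodbury identity replaces the \(p \times p\) inverse with a \(k \times k\) solve:

\[H_k = C M^{-1} C^\top, \qquad (H_k + \rho I)^{-1} v = \frac{1}{\rho} v - \frac{1}{\rho^2} C \left(M + \frac{1}{\rho} C^\top C\right)^{-1} C^\top v\]

Only the \(k\) sampled Hessian columns are ever materialized, which is what makes this cheaper than forming the full Hessian.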
- Parameters:
follower_loss (Tensor) – Follower objective
leader_loss (Tensor) – Leader objective
follower_params (dict[str, Tensor | Parameter]) – The parameters of the follower agent
leader_params (dict[str, Tensor | Parameter]) – The parameters of the leader agent
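rank (int, default=5) – Rank \(k\) of the low-rank approximation, i.e. the number of Hessian columns sampled
rho (float, default=0.1) – Additive constant \(\rho\) added to the diagonal of the approximate Hessian to improve numerical stability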
retain_graph (bool, default=True) – Whether to retain the computation graph for use in computing higher-order derivatives.
generator (torch.Generator, optional) – The PyTorch random number generator used to sample the rank columns of the Hessian that define the approximation.
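To illustrate how rank, rho and generator enter the computation, here is a minimal NumPy sketch of the underlying math. It is not the nip implementation, which works on PyTorch tensors with autograd. The function name nystrom_ihvp and the hessian_row_fn callback are hypothetical: in an autograd setting each sampled row would presumably come from one backward pass through the follower gradient, whereas here the caller supplies rows of an explicit symmetric positive-definite matrix. The random generator only decides which indices are sampled, so a fixed seed gives a reproducible approximation, and with rank equal to the number of parameters the result should match the exact regularized solve.

```python
import numpy as np


def nystrom_ihvp(hessian_row_fn, v, rank, rho, rng):
    """Approximate (H_k + rho * I)^{-1} v with a rank-`rank` Nystrom approximation.

    Hypothetical helper for illustration only. `hessian_row_fn(i)` must return
    row i of a symmetric positive-definite matrix H as a 1-D array of length len(v).
    """
    p = v.shape[0]
    # Sample `rank` distinct indices; a seeded rng makes the choice reproducible.
    idx = rng.choice(p, size=rank, replace=False)
    # H is symmetric, so row i equals column i: C = H[:, idx] has shape (p, k).
    C = np.stack([hessian_row_fn(i) for i in idx], axis=1)
    # Core block M = H[idx][:, idx] with shape (k, k).
    M = C[idx, :]
    # Woodbury identity for (rho * I + C M^{-1} C^T)^{-1} v:
    #   v / rho - C (M + C^T C / rho)^{-1} C^T v / rho^2
    inner = M + C.T @ C / rho
    correction = C @ np.linalg.solve(inner, C.T @ v)
    return v / rho - correction / rho**2


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    A = rng.standard_normal((6, 6))
    H = A @ A.T + 6.0 * np.eye(6)  # small symmetric positive-definite "Hessian"
    v = rng.standard_normal(6)  # stands in for the leader gradient dg/dtheta

    # With rank equal to the dimension, the Nystrom approximation is exact.
    full = nystrom_ihvp(lambda i: H[i], v, rank=6, rho=0.1, rng=np.random.default_rng(1))
    exact = np.linalg.solve(H + 0.1 * np.eye(6), v)
    print(np.allclose(full, exact))

    # A lower rank needs fewer Hessian rows but only approximates the solve.
    low = nystrom_ihvp(lambda i: H[i], v, rank=3, rho=0.1, rng=np.random.default_rng(1))
    print(low)
```

Using Woodbury keeps the only dense solve at size rank × rank, which is the reason a small rank is cheap even when the number of parameters is large.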
- Returns:
ihvp (dict[str, Tensor]) – A dictionary whose keys are the follower parameter names and whose values are the corresponding blocks of the inverse Hessian-vector product, i.e. the approximate inverse Hessian multiplied by the leader's gradients
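The returned ihvp is keyed like follower_params, so a result computed on one flat vector has to be split back into per-parameter blocks. A minimal NumPy sketch of that bookkeeping follows, assuming parameters are flattened in dictionary iteration order; flatten_params and unflatten_like are hypothetical helpers, not part of nip.

```python
import numpy as np


def flatten_params(params):
    """Concatenate a dict of arrays into one flat vector, recording names and shapes."""
    names = list(params)
    shapes = [params[name].shape for name in names]
    flat = np.concatenate([params[name].ravel() for name in names])
    return flat, names, shapes


def unflatten_like(flat, names, shapes):
    """Split a flat vector back into a dict keyed by parameter name."""
    out, offset = {}, 0
    for name, shape in zip(names, shapes):
        size = int(np.prod(shape))
        out[name] = flat[offset:offset + size].reshape(shape)
        offset += size
    return out


if __name__ == "__main__":
    grads = {"weight": np.ones((2, 3)), "bias": np.arange(3.0)}
    flat, names, shapes = flatten_params(grads)
    # Any vector operation (e.g. an approximate inverse-Hessian solve) acts on `flat`;
    # a simple scaling stands in for it here.
    result = unflatten_like(flat / 0.1, names, shapes)
    print({name: value.shape for name, value in result.items()})
    # prints {'weight': (2, 3), 'bias': (3,)}
```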
Notes
Adapted from moskomule/hypergrad