nip.utils.maths.compute_conjugate_gradient_ihvp

nip.utils.maths.compute_conjugate_gradient_ihvp(follower_loss: Tensor, leader_loss: Tensor, follower_params: dict[str, Tensor], leader_params: dict[str, Tensor], num_iterations: int, lr: float) → dict[str, Tensor]

Approximate the inverse Hessian-vector product (IHVP) using the conjugate gradient method.
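As a rough illustration, here is one common way to compute such a product in PyTorch. It assumes (this page does not say so) that the Hessian H is that of `follower_loss` with respect to `follower_params`, and that the vector v is the gradient of `leader_loss` with respect to `follower_params`. Under that assumption, conjugate gradient solves H x = v using only Hessian-vector products obtained by double backpropagation, so H is never materialized. The name `cg_ihvp_sketch` is hypothetical. It omits `leader_params` and `lr` because their role in the real function is not described here, so it should not be read as the library's implementation.

```python
import torch
from torch import Tensor


def cg_ihvp_sketch(
    follower_loss: Tensor,
    leader_loss: Tensor,
    follower_params: dict[str, Tensor],
    num_iterations: int,
) -> dict[str, Tensor]:
    """Hypothetical sketch: approximate x = H^{-1} v with conjugate gradient.

    Assumed: H = Hessian of follower_loss w.r.t. follower_params,
             v = gradient of leader_loss w.r.t. follower_params.
    """
    names = list(follower_params)
    params = [follower_params[n] for n in names]

    # Follower gradient kept in the graph so we can differentiate it again.
    follower_grads = torch.autograd.grad(follower_loss, params, create_graph=True)
    # Right-hand side v of the linear system H x = v.
    v = torch.autograd.grad(leader_loss, params, retain_graph=True)

    def hvp(vec: list[Tensor]) -> list[Tensor]:
        # Hessian-vector product: gradient of (follower_grads . vec) w.r.t. params.
        return list(
            torch.autograd.grad(follower_grads, params, grad_outputs=vec, retain_graph=True)
        )

    def dot(a: list[Tensor], b: list[Tensor]) -> Tensor:
        # Inner product across all parameter tensors.
        return sum((x * y).sum() for x, y in zip(a, b))

    x = [torch.zeros_like(t) for t in params]  # initial guess x = 0
    r = [vi.clone() for vi in v]               # residual r = v - H x
    d = [ri.clone() for ri in r]               # search direction
    rs_old = dot(r, r)
    for _ in range(num_iterations):
        hd = hvp(d)
        alpha = rs_old / dot(d, hd)
        x = [xi + alpha * di for xi, di in zip(x, d)]
        r = [ri - alpha * hdi for ri, hdi in zip(r, hd)]
        rs_new = dot(r, r)
        if rs_new < 1e-20:  # residual small enough: stop early
            break
        beta = rs_new / rs_old
        d = [ri + beta * di for ri, di in zip(r, d)]
        rs_old = rs_new
    return {n: xi.detach() for n, xi in zip(names, x)}
```

Conjugate gradient assumes H is symmetric positive definite; in exact arithmetic it converges in at most n steps for an n-dimensional system, so a fixed `num_iterations` mainly trades accuracy for compute on large models.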