nip.utils.maths.inverse_hessian_vector_product

nip.utils.maths.inverse_hessian_vector_product(follower_loss: Tensor, leader_loss: Tensor, follower_params: dict[str, Tensor | Parameter], leader_params: dict[str, Tensor | Parameter], variant: Literal['conj_grad', 'neumann', 'nystrom'], num_iterations: int, rank: int = 5, rho: float = 0.1, retain_graph: bool = True, generator: Generator | None = None) → dict[str, Tensor]

Compute the inverse Hessian-vector product (IHVP) using the specified approximation method.

Note that this method zeros the gradients of the leader and follower parameters.

Parameters:
  • follower_loss (Tensor) – The follower loss.

  • leader_loss (Tensor) – The leader loss.

  • follower_params (dict[str, Tensor | Parameter]) – The parameters of the follower.

  • leader_params (dict[str, Tensor | Parameter]) – The parameters of the leader.

  • variant (IhvpVariantType) – The approximation method for computing the IHVP: one of 'conj_grad', 'neumann', or 'nystrom'.

  • num_iterations (int) – The number of iterations for the conjugate gradient or Neumann approximation methods.

  • rank (int, default=5) – The rank parameter for the Nystrom approximation method.

  • rho (float, default=0.1) – The rho parameter for the Nystrom approximation method.

  • retain_graph (bool, default=True) – Whether to retain the computation graph for use in computing higher-order derivatives.

  • generator (torch.Generator, optional) – The PyTorch random number generator.

Returns:

ihvp (dict[str, Tensor]) – The computed IHVP.
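
To make the role of num_iterations more concrete, the following is a minimal pure-Python sketch of two iterative schemes of the kind named by the 'neumann' and 'conj_grad' variants. It is not the nip implementation: the helper names (matvec, dot, neumann_ihvp, conj_grad_ihvp), the step size alpha, and the explicit 2x2 matrix H are all assumptions made for illustration. In practice the Hessian is not formed explicitly and the hvp callable would be computed by automatic differentiation.

```python
# Illustrative only: approximate x = H^{-1} v using nothing but
# Hessian-vector products. All names here are hypothetical, not part of nip.

def matvec(H, v):
    """Dense matrix-vector product on plain Python lists."""
    return [sum(h * x for h, x in zip(row, v)) for row in H]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def neumann_ihvp(hvp, v, num_iterations, alpha):
    """Truncated Neumann series: alpha * sum_{k=0}^{K} (I - alpha H)^k v.

    Converges to H^{-1} v when the spectral radius of (I - alpha H) is < 1.
    """
    term = list(v)   # (I - alpha H)^k v, starting at k = 0
    total = list(v)  # running sum of the series terms
    for _ in range(num_iterations):
        h_term = hvp(term)
        term = [t - alpha * h for t, h in zip(term, h_term)]
        total = [s + t for s, t in zip(total, term)]
    return [alpha * s for s in total]


def conj_grad_ihvp(hvp, v, num_iterations, tol=1e-12):
    """Conjugate gradient for H x = v, assuming H is symmetric positive definite."""
    x = [0.0] * len(v)
    r = list(v)      # residual v - H x, with x = 0 initially
    p = list(r)      # search direction
    rs_old = dot(r, r)
    for _ in range(num_iterations):
        if rs_old < tol:
            break
        h_p = hvp(p)
        step = rs_old / dot(p, h_p)
        x = [xi + step * pi for xi, pi in zip(x, p)]
        r = [ri - step * hi for ri, hi in zip(r, h_p)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


# Toy symmetric positive-definite "Hessian" and vector (assumed values).
H = [[2.0, 0.5], [0.5, 1.0]]
v = [1.0, 1.0]
hvp = lambda u: matvec(H, u)

x_neumann = neumann_ihvp(hvp, v, num_iterations=200, alpha=0.4)
x_cg = conj_grad_ihvp(hvp, v, num_iterations=2)
```

In this toy problem both results agree with the exact solution of H x = v. Conjugate gradient needs at most as many iterations as the problem dimension in exact arithmetic, whereas the Neumann series converges geometrically at a rate that depends on alpha and the spectrum of H, which is one reason the two variants may call for very different num_iterations values.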