nip.utils.maths.compute_sos_update


nip.utils.maths.compute_sos_update(simultaneous_grad: dict[str, Tensor], hessian_grad_product: dict[str, Tensor], opponent_shaping: dict[str, Tensor], scaling_factor: float, threshold_factor: float) → dict[str, Tensor]

Compute the update for the Stable Opponent Shaping (SOS) algorithm.

See Algorithm 1 in Letcher et al. [LFB+19]. The update interpolates between the LOLA update (Foerster et al. [FCAS+18]) and the LookAhead update (Zhang and Lesser [ZL10]). It does so by computing a coefficient p between 0 and 1 that scales the opponent shaping term: p = 0 recovers LookAhead and p = 1 recovers LOLA.
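
For orientation, the interpolation can be written out in the notation of Letcher et al. This is a sketch under the assumption that the function follows Algorithm 1 closely. The learning rate \(\alpha\) and the matrix \(H_o\) of anti-diagonal Hessian blocks are symbols from the paper and do not appear in this function's signature, so how \(\alpha\) is folded into the arguments is not specified here. Treating a zero inner product as \(p_1 = 1\) is a choice made in this sketch so that the ratio is always defined.

\[
\xi_0 = (I - \alpha H_o)\,\xi, \qquad \xi_p = \xi_0 - p\,\alpha\,\chi, \qquad p = \min\{p_1, p_2\},
\]

\[
p_1 =
\begin{cases}
1 & \text{if } \langle -\alpha\chi, \xi_0 \rangle \ge 0, \\
\min\left\{1,\; \dfrac{-a\,\lVert \xi_0 \rVert^2}{\langle -\alpha\chi, \xi_0 \rangle}\right\} & \text{otherwise,}
\end{cases}
\qquad
p_2 =
\begin{cases}
\lVert \xi \rVert^2 & \text{if } \lVert \xi \rVert < b, \\
1 & \text{otherwise.}
\end{cases}
\]

Setting \(p = 0\) gives \(\xi_p = \xi_0\) (LookAhead), and \(p = 1\) gives \(\xi_p = \xi_0 - \alpha\chi\) (LOLA).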

Parameters:
  • simultaneous_grad (dict[str, Tensor]) – The vanilla individual updates, i.e. the gradient of each agent's loss with respect to that agent's own parameters. Named \(\xi\) in Letcher et al.

  • hessian_grad_product (dict[str, Tensor]) – The product of the anti-diagonal blocks of the Hessian matrix (named \(H_o\) in Letcher et al.) with the vector \(\xi\).

  • opponent_shaping (dict[str, Tensor]) – The opponent shaping term for each parameter. Named \(\chi\) in Letcher et al.

  • scaling_factor (float) – A scaling factor (between 0 and 1). Named \(a\) in Letcher et al.

  • threshold_factor (float) – A threshold value (between 0 and 1). Named \(b\) in Letcher et al.

Returns:

update (dict[str, Tensor]) – The update to be made to the parameters.
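
The following is a minimal sketch of how an update of this shape could be computed, following Algorithm 1 in Letcher et al.; it is not the library's implementation. The function name `sos_update_sketch` is hypothetical, plain lists of floats stand in for flattened tensors, and the sketch assumes that `h_xi` and `chi` have already been multiplied by the learning rate, since the documented signature takes no learning rate argument. It also returns the coefficient p alongside the update, purely for inspection.

```python
import math


def sos_update_sketch(xi, h_xi, chi, a, b):
    """Hypothetical SOS update on dicts mapping parameter names to float lists.

    Assumption: h_xi and chi are already scaled by the learning rate.
    """
    # LookAhead direction xi_0 = xi - H_o xi, computed per parameter name.
    xi_0 = {k: [g - h for g, h in zip(xi[k], h_xi[k])] for k in xi}

    # Inner product <-chi, xi_0> and norms, taken over all parameters jointly.
    dot = sum(-c * x for k in xi for c, x in zip(chi[k], xi_0[k]))
    xi_0_sq_norm = sum(x * x for k in xi for x in xi_0[k])
    xi_norm = math.sqrt(sum(g * g for k in xi for g in xi[k]))

    # p1: use full shaping unless it opposes the LookAhead direction,
    # in which case it is capped using the scaling factor a.
    p1 = 1.0 if dot >= 0 else min(1.0, -a * xi_0_sq_norm / dot)

    # p2: shrink the shaping term once the gradient norm drops below b.
    p2 = xi_norm**2 if xi_norm < b else 1.0

    p = min(p1, p2)

    # Interpolated update: p = 0 is LookAhead, p = 1 is LOLA.
    update = {k: [x - p * c for x, c in zip(xi_0[k], chi[k])] for k in xi}
    return update, p


# Toy example with a single two-element parameter named "w".
update, p = sos_update_sketch(
    xi={"w": [1.0, 0.0]},
    h_xi={"w": [0.0, 0.0]},
    chi={"w": [2.0, 0.0]},
    a=0.5,
    b=0.1,
)
# Here the shaping term opposes xi_0, so p is capped at 0.25
# and update == {"w": [0.5, 0.0]}.
```

In the toy example the inner product is negative, so p falls below 1 and the shaping term is damped. With `chi={"w": [-1.0, 0.0]}` the inner product is positive, p equals 1, and the result coincides with the LOLA update.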