nip.utils.maths.compute_sos_update
- nip.utils.maths.compute_sos_update(simultaneous_grad: dict[str, Tensor], hessian_grad_product: dict[str, Tensor], opponent_shaping: dict[str, Tensor], scaling_factor: float, threshold_factor: float) → dict[str, Tensor]
Compute the update for the Stable Opponent Shaping (SOS) algorithm.
See Algorithm 1 in Letcher et al. [LFB+19]. This update interpolates between the LookAhead update (Zhang and Lesser [ZL10]) and the LOLA update (Foerster et al. [FCAS+18]) by computing a coefficient \(p\) between 0 and 1, where \(p = 0\) recovers LookAhead and \(p = 1\) recovers LOLA.
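In the notation of Letcher et al., with step size \(\alpha\) and \(H_o\) denoting the off-diagonal blocks of the game Hessian, the interpolation can be written as below. Whether \(\alpha\) is already folded into the arguments of this function is not stated here, so the placement of \(\alpha\) should be read as an assumption taken from the paper rather than from this implementation.

\[\xi_0 = (I - \alpha H_o)\,\xi, \qquad \xi_p = \xi_0 - p\,\alpha\,\chi\]

Setting \(p = 0\) gives \(\xi_p = \xi_0\) (LookAhead), and setting \(p = 1\) gives \(\xi_p = \xi_0 - \alpha\chi\) (LOLA).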
- Parameters:
simultaneous_grad (dict[str, Tensor]) – The vanilla individual updates. Named \(\xi\) in Letcher et al.
hessian_grad_product (dict[str, Tensor]) – The product of the anti-diagonal of the Hessian matrix with the vector \(\xi\).
opponent_shaping (dict[str, Tensor]) – The opponent shaping term for each parameter. Named \(\chi\) in Letcher et al.
scaling_factor (float) – A scaling factor (between 0 and 1). Named \(a\) in Letcher et al.
threshold_factor (float) – A threshold value (between 0 and 1). Named \(b\) in Letcher et al.
- Returns:
update (dict[str, Tensor]) – The update to be made to the parameters.
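To make the role of each argument concrete, here is a minimal sketch of the SOS rule as described in Algorithm 1 of Letcher et al. It is not the library's implementation. It assumes that the step size is already folded into `h_xi` and `chi`, that norms and inner products are taken over all parameters jointly, and it uses plain lists of floats in place of tensors so that it runs without extra dependencies. The names `sos_coefficient` and `sos_update_sketch` are hypothetical.

```python
import math

# Stand-in for dict[str, Tensor]: each parameter is a flat list of floats.
ParamDict = dict[str, list[float]]


def _dot(u: ParamDict, v: ParamDict) -> float:
    """Inner product over all parameters, treating each dict as one flat vector."""
    return sum(x * y for k in u for x, y in zip(u[k], v[k]))


def sos_coefficient(
    xi: ParamDict, xi_0: ParamDict, chi: ParamDict, a: float, b: float
) -> float:
    """Compute the interpolation coefficient p in [0, 1]."""
    # <-chi, xi_0>: positive when the shaping term agrees with LookAhead.
    inner = -_dot(chi, xi_0)
    if inner >= 0:
        # The paper sets p1 = 1 for inner > 0; inner == 0 is treated the same
        # way here to avoid dividing by zero (an assumption of this sketch).
        p1 = 1.0
    else:
        p1 = min(1.0, -a * _dot(xi_0, xi_0) / inner)
    # Shrink p near fixed points, where the plain gradient xi is small.
    xi_norm = math.sqrt(_dot(xi, xi))
    p2 = xi_norm**2 if xi_norm < b else 1.0
    return min(p1, p2)


def sos_update_sketch(
    xi: ParamDict, h_xi: ParamDict, chi: ParamDict, a: float, b: float
) -> ParamDict:
    """Return xi_0 - p * chi, where xi_0 = xi - h_xi is the LookAhead direction."""
    xi_0 = {k: [x - h for x, h in zip(xi[k], h_xi[k])] for k in xi}
    p = sos_coefficient(xi, xi_0, chi, a, b)
    return {k: [x - p * c for x, c in zip(xi_0[k], chi[k])] for k in xi}


# Example: the shaping term opposes LookAhead, so p is capped at 0.25.
xi = {"w": [1.0, 0.0]}
h_xi = {"w": [0.0, 0.0]}
chi = {"w": [2.0, 0.0]}
print(sos_update_sketch(xi, h_xi, chi, a=0.5, b=0.1))  # {'w': [0.5, 0.0]}
```

In the example, \(\langle -\chi, \xi_0 \rangle = -2\) is negative, so the first candidate is \(0.5 \cdot 1 / 2 = 0.25\). The gradient norm is above the threshold, so the second candidate is 1, giving \(p = 0.25\).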