nip.parameters.trainers.SpgParameters

class nip.parameters.trainers.SpgParameters(variant: Literal['spg', 'pspg', 'lola', 'pola', 'sos', 'psos'] = 'psos', stackelberg_sequence: tuple[tuple[str, ...]] | None = None, additional_lola_term: bool = True, sos_scaling_factor: float = 0.5, sos_threshold_factor: float = 0.1, ihvp_variant: Literal['conj_grad', 'neumann', 'nystrom'] = 'nystrom', ihvp_num_iterations: int = 5, ihvp_rank: int = 5, ihvp_rho: float = 0.1)

Additional parameters for SPG [FCR20] and its variants.

Parameters:
  • variant (SpgVariantType, default='psos') – The variant of SPG to use: one of 'spg', 'pspg', 'lola', 'pola', 'sos' or 'psos'.

  • stackelberg_sequence (tuple[tuple[str, ...]], optional) – The sequence of agent groups in the Stackelberg game: the leaders first, then their respective followers, and so on. If None, the sequence is determined automatically from the protocol.

  • additional_lola_term (bool, default=True) – Whether to add an extra term to the SPG loss so that it is equivalent to the later version of LOLA (first introduced implicitly in LOLA-DICE) rather than the original version.

  • sos_scaling_factor (float, default=0.5) – The SOS scaling factor, between 0 and 1, used with Stable Opponent Shaping (SOS).

  • sos_threshold_factor (float, default=0.1) – The SOS threshold factor, between 0 and 1, used with Stable Opponent Shaping (SOS).

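  • ihvp_variant (IhvpVariantType, default='nystrom') – The method used to approximate the inverse Hessian-vector product (IHVP): 'conj_grad' (conjugate gradient), 'neumann' (Neumann series) or 'nystrom' (Nyström approximation).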

  • ihvp_num_iterations (int, default=5) – The number of iterations used in the IHVP approximation.

  • ihvp_rank (int, default=5) – The rank of the IHVP approximation.

  • ihvp_rho (float, default=0.1) – The damping factor used in the IHVP approximation.
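The two sos_* factors play the roles of the scaling constant a and threshold b in Stable Opponent Shaping, which interpolates between a look-ahead update and the full LOLA update. The sketch below illustrates that interpolation rule on plain NumPy vectors, assuming nip follows the rule as usually stated; sos_update and its argument names are hypothetical. Here xi is the plain gradient vector of all agents, xi_0 the look-ahead update, and shaping the opponent-shaping correction that LOLA adds to it.

```python
import numpy as np


def sos_update(xi: np.ndarray, xi_0: np.ndarray, shaping: np.ndarray,
               scaling_factor: float = 0.5, threshold_factor: float = 0.1) -> np.ndarray:
    """Return xi_0 + p * shaping, with the weight p in [0, 1] chosen by the SOS rule.

    Hypothetical helper: scaling_factor ~ sos_scaling_factor (a),
    threshold_factor ~ sos_threshold_factor (b).
    """
    # p1: shrink the shaping term when it points against xi_0, so that
    # <result, xi_0> >= (1 - scaling_factor) * ||xi_0||^2 still holds.
    alignment = shaping @ xi_0
    if alignment < 0:
        p1 = min(1.0, -scaling_factor * (xi_0 @ xi_0) / alignment)
    else:
        p1 = 1.0
    # p2: fade the shaping term out once the gradient norm falls below the
    # threshold, i.e. close to a fixed point.
    xi_norm = float(np.linalg.norm(xi))
    p2 = xi_norm ** 2 if xi_norm < threshold_factor else 1.0
    p = min(p1, p2)
    return xi_0 + p * shaping


xi = np.array([1.0, 0.0])
xi_0 = np.array([1.0, 0.0])
shaping = np.array([-4.0, 0.0])
update = sos_update(xi, xi_0, shaping)  # p = 0.125, so update is [0.5, 0.0]
```

With p = 1 this reduces to the LOLA update and with p = 0 to the look-ahead update; the scaling factor bounds how far shaping may turn the update away from xi_0.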
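The ihvp_* settings control how an inverse Hessian-vector product, (H + rho*I)^-1 v, is approximated without inverting H. The following is a minimal NumPy sketch, not nip's implementation. It assumes ihvp_rho is applied as additive damping rho*I, ihvp_num_iterations is the iteration budget of the conjugate-gradient and Neumann-series variants, and ihvp_rank is the number of sampled columns in a column-sampling Nyström approximation (nip may build its Nyström approximation differently). The function names, the Neumann step size eta and the sampling seed are hypothetical.

```python
from typing import Callable

import numpy as np

# A Hessian-vector product: maps v to H @ v without forming H explicitly.
Hvp = Callable[[np.ndarray], np.ndarray]


def ihvp_conj_grad(hvp: Hvp, v: np.ndarray, num_iterations: int = 5, rho: float = 0.1) -> np.ndarray:
    """Approximate (H + rho*I)^-1 v with a fixed budget of conjugate-gradient steps."""
    x = np.zeros_like(v)
    r = v.copy()  # residual v - (H + rho*I) @ x, starting from x = 0
    p = r.copy()  # search direction
    rs = r @ r
    for _ in range(num_iterations):
        if rs < 1e-20:  # residual is already negligible
            break
        ap = hvp(p) + rho * p  # damped Hessian-vector product
        alpha = rs / (p @ ap)
        x = x + alpha * p
        r = r - alpha * ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def ihvp_neumann(hvp: Hvp, v: np.ndarray, num_iterations: int = 5, rho: float = 0.1,
                 eta: float = 0.1) -> np.ndarray:
    """Truncated Neumann series: eta * sum_j (I - eta*(H + rho*I))^j v.

    The series converges only if eta < 2 / lambda_max(H + rho*I).
    """
    term = v.copy()
    total = v.copy()
    for _ in range(num_iterations):
        term = term - eta * (hvp(term) + rho * term)
        total = total + term
    return eta * total


def ihvp_nystrom(hvp: Hvp, v: np.ndarray, rank: int = 5, rho: float = 0.1,
                 seed: int = 0) -> np.ndarray:
    """Invert rho*I plus a column-sampling Nystrom approximation of a symmetric PSD H."""
    n = v.shape[0]
    rng = np.random.default_rng(seed)
    idx = rng.choice(n, size=min(rank, n), replace=False)  # sampled column indices
    basis = np.eye(n)
    c = np.stack([hvp(basis[i]) for i in idx], axis=1)  # H[:, idx], shape (n, k)
    w = c[idx, :]  # H[idx, idx], using the symmetry of H
    # Woodbury identity: (rho*I + C W^-1 C^T)^-1 v = (v - C (rho*W + C^T C)^-1 C^T v) / rho
    inner = rho * w + c.T @ c
    return (v - c @ np.linalg.solve(inner, c.T @ v)) / rho


# Compare against the exact damped solve on a small symmetric positive definite H.
H = np.array([[4.0, 1.0, 0.0], [1.0, 3.0, 0.0], [0.0, 0.0, 2.0]])


def hvp(u: np.ndarray) -> np.ndarray:
    return H @ u


v = np.array([1.0, -2.0, 0.5])
exact = np.linalg.solve(H + 0.1 * np.eye(3), v)
cg = ihvp_conj_grad(hvp, v, num_iterations=3)
neumann = ihvp_neumann(hvp, v, num_iterations=200, eta=0.2)
nystrom = ihvp_nystrom(hvp, v, rank=3)
```

All three variants touch H only through hvp, which is what makes them usable when H is too large to form. On this 3-by-3 example, 3 conjugate-gradient iterations or rank 3 recover the exact damped solve up to rounding, while the truncated Neumann series needs many more iterations than the default of 5.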

Methods Summary

__eq__(other)

Return self==value.

__init__([variant, stackelberg_sequence, ...])

__post_init__()

__repr__()

Return repr(self).

_get_param_class_from_dict(param_dict)

Try to get the parameter class from a dictionary of serialised parameters.

construct_test_params()

Construct a set of basic parameters for testing.

from_dict(params_dict[, ignore_extra_keys])

Create a parameters object from a dictionary.

get(address)

Get a value from the parameters object using a dot-separated address.

to_dict()

Convert the parameters object to a dictionary.

Attributes

additional_lola_term

ihvp_num_iterations

ihvp_rank

ihvp_rho

ihvp_variant

sos_scaling_factor

sos_threshold_factor

stackelberg_sequence

variant

Methods

__eq__(other)

Return self==value.

__init__(variant: Literal['spg', 'pspg', 'lola', 'pola', 'sos', 'psos'] = 'psos', stackelberg_sequence: tuple[tuple[str, ...]] | None = None, additional_lola_term: bool = True, sos_scaling_factor: float = 0.5, sos_threshold_factor: float = 0.1, ihvp_variant: Literal['conj_grad', 'neumann', 'nystrom'] = 'nystrom', ihvp_num_iterations: int = 5, ihvp_rank: int = 5, ihvp_rho: float = 0.1) → None

__post_init__()

__repr__()

Return repr(self).

classmethod _get_param_class_from_dict(param_dict: dict) → type[ParameterValue] | None

Try to get the parameter class from a dictionary of serialised parameters.

Parameters:

param_dict (dict) – A dictionary of parameters, which may have come from a to_dict method. This dictionary may contain a _type key, which is used to determine the class of the parameter.

Returns:

param_class (type[ParameterValue] | None) – The class of the parameter, if it can be determined.

Raises:

ValueError – If the class specified in the dictionary is not a valid parameter class.
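One way the documented behaviour could be realised is a lookup in a name-to-class registry keyed by the _type entry. This is a minimal sketch under that assumption; PARAMETER_REGISTRY, ExampleParams and get_param_class_from_dict are hypothetical names, not part of nip.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExampleParams:  # hypothetical parameter class
    rank: int = 5


# Hypothetical registry from serialised class names to parameter classes
PARAMETER_REGISTRY: dict[str, type] = {"ExampleParams": ExampleParams}


def get_param_class_from_dict(param_dict: dict) -> Optional[type]:
    """Return the class named by the optional "_type" key, or None if it is absent."""
    type_name = param_dict.get("_type")
    if type_name is None:
        return None  # the class cannot be determined from this dictionary
    if type_name not in PARAMETER_REGISTRY:
        raise ValueError(f"{type_name!r} is not a valid parameter class")
    return PARAMETER_REGISTRY[type_name]


param_class = get_param_class_from_dict({"_type": "ExampleParams", "rank": 3})
```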

classmethod construct_test_params() → BaseHyperParameters

Construct a set of basic parameters for testing.

classmethod from_dict(params_dict: dict, ignore_extra_keys: bool = False) → BaseHyperParameters

Create a parameters object from a dictionary.

Parameters:
  • params_dict (dict) – A dictionary of the parameters.

  • ignore_extra_keys (bool, default=False) – If True, ignore keys in the dictionary that do not correspond to fields in the parameters object.

Returns:

hyper_params (BaseHyperParameters) – The parameters object.
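A minimal sketch of from_dict on a flat dataclass. ExampleSpgParams is a hypothetical stand-in carrying only three of SpgParameters' fields; raising ValueError for unknown keys when ignore_extra_keys is False is an assumption about nip's behaviour, and nested sub-parameters are not handled.

```python
from dataclasses import dataclass, fields


@dataclass
class ExampleSpgParams:  # hypothetical stand-in with three of SpgParameters' fields
    variant: str = "psos"
    sos_scaling_factor: float = 0.5
    ihvp_rank: int = 5

    @classmethod
    def from_dict(cls, params_dict: dict, ignore_extra_keys: bool = False) -> "ExampleSpgParams":
        """Build an instance from a flat dictionary of field values."""
        field_names = {f.name for f in fields(cls)}
        extra_keys = set(params_dict) - field_names
        if extra_keys and not ignore_extra_keys:
            # Assumption: unknown keys are an error unless explicitly ignored
            raise ValueError(f"Unknown parameter keys: {sorted(extra_keys)}")
        # Keys that are missing from the dictionary fall back to the dataclass defaults
        return cls(**{k: v for k, v in params_dict.items() if k in field_names})


params = ExampleSpgParams.from_dict({"variant": "lola", "ihvp_rank": 3})
lenient = ExampleSpgParams.from_dict({"variant": "sos", "note": "unused"}, ignore_extra_keys=True)
```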

get(address: str) → Any

Get a value from the parameters object using a dot-separated address.

Parameters:

address (str) – The path to the value in the parameters object, separated by dots.

Returns:

value (Any) – The value at the address.

Raises:

KeyError – If the address does not exist.
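The dot-separated lookup can be sketched as a walk through nested dataclass fields that raises KeyError as soon as a segment is missing. TrainerSettings and IhvpSettings are hypothetical nested classes used only to give the address more than one segment, and get is written here as a free function rather than a method. SpgParameters itself has no sub-parameters, so an address such as "variant" would be a single segment there.

```python
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any


@dataclass
class IhvpSettings:  # hypothetical nested parameters object
    variant: str = "nystrom"
    rank: int = 5


@dataclass
class TrainerSettings:  # hypothetical top-level parameters object
    spg_variant: str = "psos"
    ihvp: IhvpSettings = field(default_factory=IhvpSettings)


def get(params: Any, address: str) -> Any:
    """Follow a dot-separated address through nested dataclass fields."""
    value = params
    for part in address.split("."):
        # Stop if the current value has no field with this name
        if not is_dataclass(value) or part not in {f.name for f in fields(value)}:
            raise KeyError(address)
        value = getattr(value, part)
    return value


settings = TrainerSettings()
rank = get(settings, "ihvp.rank")  # 5
```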

to_dict() → dict

Convert the parameters object to a dictionary.

Turns enums into strings, and sub-parameters into dictionaries. Includes the is_random parameter if it exists.

Returns:

params_dict (dict) – A dictionary of the parameters.
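A recursive sketch of the conversion described above, assuming enums are stored by their value and sub-parameters are dataclass instances. The IhvpVariant enum and both settings classes are hypothetical, and the is_random handling is omitted.

```python
from dataclasses import dataclass, field, fields, is_dataclass
from enum import Enum
from typing import Any


class IhvpVariant(Enum):  # hypothetical enum for illustration
    CONJ_GRAD = "conj_grad"
    NEUMANN = "neumann"
    NYSTROM = "nystrom"


@dataclass
class IhvpSettings:  # hypothetical nested parameters object
    variant: IhvpVariant = IhvpVariant.NYSTROM
    rank: int = 5


@dataclass
class TrainerSettings:  # hypothetical top-level parameters object
    sos_scaling_factor: float = 0.5
    ihvp: IhvpSettings = field(default_factory=IhvpSettings)


def to_dict(params: Any) -> dict:
    """Recursively convert a dataclass instance to a plain dictionary."""
    result = {}
    for f in fields(params):
        value = getattr(params, f.name)
        if isinstance(value, Enum):
            value = value.value  # enums become their string values
        elif is_dataclass(value):
            value = to_dict(value)  # sub-parameters become nested dictionaries
        result[f.name] = value
    return result


params_dict = to_dict(TrainerSettings())
# {'sos_scaling_factor': 0.5, 'ihvp': {'variant': 'nystrom', 'rank': 5}}
```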