Experiment Settings (nip.experiment_settings)#
The ExperimentSettings object contains various settings for the experiment that are not relevant to reproducibility.
- class nip.experiment_settings.ExperimentSettings(device: ~torch.device | str | int = device(type='cpu'), run_id: str | None = None, wandb_run: ~wandb.sdk.wandb_run.Run | None = None, silence_wandb: bool = True, base_wandb_run: ~typing.Annotated[~wandb.apis.public.runs.Run | None, <nip.experiment_settings._MarkUnpicklable object at 0x7f21b47b5390>] = None, stat_logger: ~nip.stat_logger.StatLogger | None = <factory>, tqdm_func: callable = <class 'tqdm.std.tqdm'>, logger: ~logging.Logger | ~logging.LoggerAdapter | None = None, profiler: ~torch.profiler.profiler.profile | None = None, ignore_cache: bool = False, num_rollout_samples: int = 10, rollout_sample_period: int = 1000, checkpoint_period: int = 1000, num_dataset_threads: int = 8, num_rollout_workers: int = 4, pin_memory: bool = True, dataset_on_device: bool = False, enable_efficient_attention: bool = False, global_tqdm_step_fn: ~typing.Annotated[callable, <nip.experiment_settings._MarkUnpicklable object at 0x7f21b47b5390>] = <function _default_global_tqdm_step_fn>, pretrained_embeddings_batch_size: int = 256, num_api_generation_timeouts: int = 100, num_api_connection_errors: int = 10, do_not_load_checkpoint: bool = False, test_run: bool = False)[source]#
Instance-specific settings for the experiment.
- Parameters:
device (TorchDevice, default="cpu") – The device to use for training.
run_id (str, optional) – The ID of the current run. This can be used to save and restore the state of the experiment.
wandb_run (wandb.wandb_sdk.wandb_run.Run, optional) – The W&B run to log to, if any.
silence_wandb (bool, default=True) – Whether to suppress W&B output.
base_wandb_run (wandb.apis.public.Run, optional) – The base W&B run, if any. This is an already completed run loaded using the W&B API.
stat_logger (StatLogger, optional) – The logger to use for logging statistics. If not provided, a dummy logger is used, which does nothing.
tqdm_func (Callable, optional) – The tqdm function to use. Defaults to tqdm.
logger (logging.Logger | logging.LoggerAdapter, optional) – The logger to log to. If None, the trainer will create a logger.
profiler (torch.profiler.profile, optional) – The PyTorch profiler being used to profile the training, if any.
ignore_cache (bool, default=False) – If True, the dataset and model cache are ignored and rebuilt.
num_rollout_samples (int, default=10) – The number of rollout samples to collect and save per iteration of RL training. These are useful to visualize the progress of the training.
rollout_sample_period (int, default=1000) – The frequency with which to collect rollout samples. This is the number of iterations of RL training between each collection of rollout samples.
checkpoint_period (int, default=1000) – The frequency with which to save the models. This is the number of iterations of RL training between each save of the models.
num_dataset_threads (int, default=8) – The number of threads to use for saving the memory-mapped tensordict.
num_rollout_workers (int, default=4) – The number of workers to use for collecting rollout samples, when this is done in parallel. If this is 0, the rollouts are collected in the main process.
pin_memory (bool, default=True) – Whether to pin the memory of the tensors in the dataloader, and move them to the GPU with non_blocking=True. This can speed up training. When the device is the CPU this setting doesn’t do anything and is set to False.
dataset_on_device (bool, default=False) – Whether to store the whole dataset on the device. This can speed up training but requires that the dataset fits on the device. This makes pin_memory redundant.
enable_efficient_attention (bool, default=False) – Whether to enable the ‘Memory-Efficient Attention’ backend for scaled dot-product attention. There may be a bug in this implementation which causes NaNs to appear in the backward pass. See pytorch/pytorch#119320 for more information.
global_tqdm_step_fn (Callable, default=lambda: ...) – A function to step the global tqdm progress bar. This is used when there are multiple processes running in parallel and each process needs to update the global progress bar.
pretrained_embeddings_batch_size (int, default=256) – The batch size to use when generating embeddings for the pretrained models.
num_api_generation_timeouts (int, default=100) – The number of timeouts to allow when generating API outputs. If the number of timeouts exceeds this value, the experiment will be stopped.
num_api_connection_errors (int, default=10) – The number of connection errors to allow when generating API outputs. The generation request is retried with exponential back-off using the formula 0.01 * 2 ** num_attempts, so this value should not be higher than around 12 (see the back-off sketch after this parameter list). This error type is more general than timeouts, which have their own counter. If the number of connection errors exceeds this value, the experiment will be stopped.
do_not_load_checkpoint (bool, default=False) – If True, the experiment will not load a checkpoint if one exists.
test_run (bool, default=False) – If True, the experiment is run in test mode. This means we do the smallest number of iterations possible and then exit. This is useful for testing that the experiment runs without errors. It doesn’t make sense to use this with wandb_run.
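The back-off formula for num_api_connection_errors grows quickly, which is why the value should stay around 12 or below. A minimal illustrative sketch (the formula is taken from the description above; the loop itself is purely illustrative):

```python
# Wait time before retry attempt `num_attempts`, per the formula above:
# 0.01 * 2 ** num_attempts seconds.
for num_attempts in (1, 5, 10, 12, 15):
    wait_seconds = 0.01 * 2**num_attempts
    print(f"attempt {num_attempts:2d}: wait {wait_seconds:.2f} s")
# attempt  1: wait 0.02 s
# attempt  5: wait 0.32 s
# attempt 10: wait 10.24 s
# attempt 12: wait 40.96 s
# attempt 15: wait 327.68 s
```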
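As a usage illustration, here is a minimal sketch of constructing the settings object directly. The field names come from the signature above; the specific values, including the run ID, are hypothetical and depend on your setup.

```python
import torch

from nip.experiment_settings import ExperimentSettings

# All fields have defaults, so only the settings that differ need to be given.
settings = ExperimentSettings(
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    run_id="example-run",    # hypothetical run ID
    num_rollout_workers=0,   # collect rollouts in the main process
    test_run=True,           # do the smallest possible run, then exit
)
```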