Creating a New Trainer#
In this guide, we will walk through the process of creating a new RL trainer.
A trainer takes a scenario instance (consisting of the environment, agents, interaction protocol handler, and other components) and is responsible for the whole training process, including logging, checkpointing, and evaluation.
A number of base classes are provided, which are designed to allow you to specify just the parts of the training process that are specific to your trainer.
The first decision is whether your trainer will work with TensorDict-based scenarios or pure text scenarios.
TensorDict-based trainers work with scenarios where the agents are locally run neural networks, so we need to pass around PyTorch tensors, which are stored together in TensorDict data structures. RL training is done using the TorchRL library.
Pure text trainers work with scenarios where the agents are text-based models accessed by an API. In this case we need to pass around strings, rather than tensors. These trainers typically call an agent method which performs some training using an API.
See TensorDict or Pure Text Trainer? and TensorDict or Pure Text Scenario? for more information.
It is recommended that you read How an Experiment is Built to understand how the trainer fits into the overall experiment.
Which Parts of this Guide to Read#
Read the Main Steps section to get an overview of the process of creating a new trainer.
Look at the flow chart in the Trainer Base Classes section to decide which base class to subclass.
Read the description of the base class you’ve chosen in the Trainer Base Classes section.
Read subsections of the Available Experiment Components section that are relevant to your base class.
Main Steps#
Here are the main steps to create a new trainer:

1. Add the name of the trainer to TrainerType.
2. (Optional) Create a SubParameters subclass in nip/parameters/trainers.py to hold the trainer-specific parameters (see Creating New Parameters).
3. Implement the trainer by subclassing one of the base classes. See Trainer Base Classes below.
4. Register the trainer with the register_trainer decorator (a sketch is given below).
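As a rough sketch, registering a new trainer might look like the following. The import path, and whether register_trainer takes the trainer name as an argument, are assumptions; check the existing trainers for the exact usage.

from nip.trainers import Trainer, register_trainer  # import paths are assumptions


@register_trainer("my_trainer")  # a name you have added to ``TrainerType``
class MyTrainer(Trainer):
    def train(self):
        ...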
Trainer Base Classes#
To choose which base class to subclass, either follow the flowchart below or read the descriptions under each heading.
flowchart TD
    data_structure_type{{"`What type of data structure will you use?`"}}
    data_structure_type -->|"`TensorDict`"| tensordict_novelty{{"`Is it enough to specify a new loss function and use the default train loop?`"}}
    data_structure_type -->|pure text| pure_text_class[PureTextRlTrainer]
    data_structure_type -->|other| trainer_class[Trainer]
    tensordict_novelty --->|yes| rl_trainer[TensorDictRlTrainer]
    tensordict_novelty --->|no| tensordict_trainer[TensorDictTrainer]
TensorDictRlTrainer
#
Use this class if your trainer works with TensorDict data structures and you’re happy to use a standard TorchRL training loop, but you need to specify a new loss function.
To implement a new trainer, subclass this class and define the _get_loss_module_and_gae method. This method should return a loss module (an instance of a subclass of Objective) and, optionally, a Generalised Advantage Estimation (GAE) module (if you're using GAE). The GAE is typically constructed from the loss module as follows.
from torchrl.objectives import ValueEstimators
...
loss_module.make_value_estimator(ValueEstimators.GAE, **additional_parameters)
gae = loss_module.value_estimator
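For example, a minimal subclass might look like the following sketch. The MyObjective loss module and the hyper-parameter names passed to the value estimator are placeholders and assumptions; substitute your own.

from torchrl.objectives import ValueEstimators


class MyRlTrainer(TensorDictRlTrainer):

    def _get_loss_module_and_gae(self):
        # Construct the loss module (an instance of an Objective subclass);
        # ``MyObjective`` is a placeholder for your own loss
        loss_module = MyObjective(...)

        # Build the GAE value estimator from the loss module (the parameter
        # names passed here are assumptions)
        loss_module.make_value_estimator(
            ValueEstimators.GAE,
            gamma=self.hyper_params.rl.gamma,
            lmbda=self.hyper_params.rl.lmbda,
        )
        gae = loss_module.value_estimator

        return loss_module, gae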
The train and test loops are implemented in the _run_train_loop
and
_run_test_loop
methods,
respectively. You can customise these methods as needed. Look at the source code for
each method to see how to do this.
TensorDictTrainer
#
Use this class if your trainer works with TensorDict data structures and you need to implement a custom training loop.
You’ll need to implement the train
method, which should perform
the following steps, as appropriate:
- Set the seed
- Run the training loop, logging the results
- Run the test loop, logging the results
- Save the models
It is recommended that you define separate _run_train_loop and _run_test_loop methods, decorated with attach_progress_bar as follows. This allows the run script to compute the total number of training steps in the whole process, and also allows you to customise the progress bar.
from nip.trainers.trainer_base import attach_progress_bar
...
class MyTrainer(TensorDictTrainer):
...
# The ``attach_progress_bar`` takes a function that returns the total number of
# iterations for this phase of training.
@attach_progress_bar(lambda self: self.hyper_params.rl.num_iterations)
def _run_train_loop(self, iteration_context: IterationContext):
# Add a description to the progress bar
iteration_context.set_description("Training")
...
@attach_progress_bar(lambda self: self.hyper_params.rl.num_test_iterations)
def _run_test_loop(self, iteration_context: IterationContext):
# Add a description to the progress bar
iteration_context.set_description("Testing")
...
It is also recommended that you call these methods within an ExitStack context manager built using _build_test_context and _build_train_context, as follows. This ensures that the appropriate PyTorch configuration is applied.
from contextlib import ExitStack
...
class MyTrainer(TensorDictTrainer):
...
def train(self):
...
# Run the training loop with the appropriate context managers
with ExitStack() as stack:
self._build_train_context(stack)
self._run_train_loop()
# Run the test loop with the appropriate context managers
with ExitStack() as stack:
self._build_test_context(stack)
self._run_test_loop()
...
Most trainers will be reinforcement learning trainers, but if you’re using this class it
may be because you’re doing something other than reinforcement learning. So that the
factory
knows which parts of the agents it should build for the
trainer, you should define the trainer_type
class attribute of your
TensorDictTrainer
subclass.
Currently, this can take the following values.
"rl"
(default): A reinforcement learning trainer. This means that the factory will build policy and value heads for the agents."solo_agent"
: A trainer that trains a single agent to solve the task using supervised learning. The factory will build only the solo agent head for the agents.
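For example, a trainer that trains the agents in isolation using supervised learning would declare:

class MySupervisedTrainer(TensorDictTrainer):

    # Tell the factory to build only the solo agent heads
    trainer_type = "solo_agent"

    def train(self):
        ...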
If you want to do something different, you can define a new value for this attribute, and you may need to modify the factory to handle this. See How an Experiment is Built.
PureTextRlTrainer
#
Use this class if your trainer works with pure text data structures.
All subclasses must define at least the _stage_create_fine_tune_jobs
method
(see below).
Rather than using TensorDict objects, these trainers use the custom
NestedArrayDict
data structure.
This is similar to TensorDict, in that it is a nested dictionary, but it
contains Numpy string arrays rather than PyTorch tensors.
The PureTextRlTrainer
class
implements a training loop consisting of multiple stages. The experiment state is saved
after each stage, and the experiment can be resumed from any stage. The stages are as
follows.
1. Sample rollouts from the environment. You can customise this stage by overriding the _sample_rollouts method.
2. Log statistics for the sampled rollouts. Logging can be customised by overriding the _get_log_stats method. This method returns a dictionary of statistics to log, and when overriding, it is recommended to call the superclass method and update the dictionary.
3. Test the agents. This stage runs the test loop if specified by the hyper_params.text_rl.test_scheme hyper-parameter. Any customisation to the _sample_rollouts method will also affect the test loop.
4. Create fine-tune jobs for each agent. This stage creates API jobs for each group of agents which share a model (see Agents (Pure-Text-Based Trainers)). This stage must be implemented by the subclass, which is done by defining the _stage_create_fine_tune_jobs method. This method takes as input the rollouts sampled in the first stage and calls an appropriate method of each PureTextSharedModelGroup agent group (e.g. create_supervised_fine_tune_job). This is the only method that must be implemented by the subclass (see the sketch below).
5. Wait for all fine-tune jobs to complete.
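A minimal sketch of the required stage is given below. The attribute holding the shared model groups and the exact method signature are assumptions; check PureTextRlTrainer for the details.

class MyPureTextTrainer(PureTextRlTrainer):

    def _stage_create_fine_tune_jobs(self, rollouts: NestedArrayDict):
        # Create one fine-tune job per group of agents which share an
        # underlying model (``self.shared_model_groups`` is an assumed
        # attribute name)
        for shared_model_group in self.shared_model_groups.values():
            shared_model_group.create_supervised_fine_tune_job(rollouts)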
Trainer
#
This is the base class for all trainers, and can be subclassed if you want to do something more specialised that doesn't fit into the other categories. This probably means you're using a custom data structure. You'll need to implement the train method, which should perform the following steps, as appropriate (a skeleton sketch is given after the list):
- Set the seed
- Run the training loop, logging the results
- Run the test loop, logging the results
- Save the models
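A skeleton of such a train method might look like this. The seed hyper-parameter and the helper method names are assumptions; structure the loops however suits your trainer.

import torch


class MyCustomTrainer(Trainer):

    def train(self):
        # Set the seed (the attribute holding the seed is an assumption)
        torch.manual_seed(self.hyper_params.seed)

        # Run the training loop, logging the results
        self._run_train_loop()

        # Run the test loop, logging the results
        self._run_test_loop()

        # Save the models
        self._save_models()

    def _run_train_loop(self):
        ...

    def _run_test_loop(self):
        ...

    def _save_models(self):
        ...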
Available Experiment Components#
This section details the components which are available to trainers. All trainers are initialised with the following objects:

- A HyperParameters instance, which contains the hyper-parameters specifying the experiment.
- A ScenarioInstance instance, which is a dataclass holding all the components of the experiment.
- An ExperimentSettings instance, which contains various experiment settings not relevant to reproducibility (e.g. the GPU device number and whether to use Weights & Biases).
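Inside a trainer method, these are available as attributes, for example (the attribute name for the ExperimentSettings instance is an assumption; the other two appear in examples elsewhere in this guide):

hyper_params = self.hyper_params            # the HyperParameters instance
scenario_instance = self.scenario_instance  # the ScenarioInstance instance
settings = self.settings                    # assumed attribute name for ExperimentSettings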
The following components are derived from these objects. For more information on these components and how they are built, see How an Experiment is Built.
Some components are only available to specific base classes, which is indicated in the description.
The following assumes we are working in a method of a trainer class, so self
refers to the trainer instance.
Datasets#
The train and test datasets are instances of Dataset
(or a subclass specific to the current
scenario), and are available as self.scenario_instance.train_dataset
and
self.scenario_instance.test_dataset
.
If your trainer is TensorDict-based (i.e. subclasses TensorDictRlTrainer
or TensorDictTrainer
), the datasets are instances of
TensorDictDataset
. If your trainer
is pure-text-based (i.e. subclasses PureTextRlTrainer
), the datasets are instances of
NestedArrayDictDataset
.
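A trivial sketch of accessing the datasets inside a trainer method (this assumes the datasets implement __len__):

train_dataset = self.scenario_instance.train_dataset
test_dataset = self.scenario_instance.test_dataset

# The number of examples in each dataset (assumes ``__len__`` is implemented)
num_train_examples = len(train_dataset)
num_test_examples = len(test_dataset)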
Environments#
Reinforcement learning environments are instances of the Environment
class. They specify things like the action
and state specs, and handle updating the environment state and rewards based on actions
made by the agents.
There is an environment instance for each dataset (because observations come by sampling
from the dataset), and they are available as self.scenario_instance.train_enviroment
and self.scenario_instance.test_enviroment
.
The key methods and properties of the environment are:

- A method to reset the environment.
- A method to perform a step in the environment.
- The specification for the observation keys.
- The specification for the action keys.
- The specification for the state keys.
- The specification for the agent reward keys.
- The specification for the done keys (done and terminated).
- The number of steps per batched environment in each iteration.
- The number of frames to sample per training iteration.
- The number of batched environments.
- The batch size of the environment.
Depending on the type of trainer, the environments may be instances of
TensorDictEnvironment
or
PureTextEnvironment
.
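As a rough illustration for a TensorDict-based environment, which follows the standard TorchRL EnvBase interface (reset and rand_step are TorchRL methods; in practice the actions come from the agents rather than being sampled randomly):

train_environment = self.scenario_instance.train_enviroment  # attribute name as given above

# Resetting returns a TensorDict containing the initial observations
rollout = train_environment.reset()

# Sample a random action and advance the environment by one step
rollout = train_environment.rand_step(rollout)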
Agents (TensorDict-Based Trainers)#
In TensorDict-based trainers, each agent is composed of one or more bodies and one or more heads.
- Bodies (AgentBody) are responsible for processing the environment observations and producing the agent's internal state. Agents typically have one body, shared between all heads. However, if hyper_params.rl.shared_body is False, we use a separate body for each head.
- Heads (AgentHead) are responsible for producing the agent's actions and values.
  - Policy heads (AgentPolicyHead) produce probability distributions over actions.
  - Value heads (AgentValueHead) produce value estimates.
  - Solo agent heads (SoloAgentHead) output predictions for the true labels of the data, and are used when the agents are trained in isolation using supervised learning (rather than reinforcement learning).

All agent parts are subclasses of AgentPart. The parts of TensorDict-based agents also subclass the TensorDictAgentPartMixin mixin class.
All agent parts are collected together in an Agent
dataclass. These are available in the
self.scenario_instance.agents
dictionary, indexed by agent name.
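For example, the agents and their parts can be iterated over as follows (a trivial sketch):

for agent_name, agent in self.scenario_instance.agents.items():
    # ``agent`` is an Agent dataclass holding all parts of this agent
    print(agent_name, agent)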
Agent parts of the same type are combined together across agents. This allows dealing with 'combined agents' which can be treated as a single actor in reinforcement learning, and makes it easy to work with the TorchRL library. These combined parts are also stored in self.scenario_instance.
Agents (Pure-Text-Based Trainers)#
In pure-text-based trainers, agents are not composed of parts, but are whole entities
which can be sampled and trained using an API. The agents are instances of the
PureTextWholeAgent
class. To
maintain compatibility with TensorDict-based agents, the PureTextWholeAgent
is also stored in an Agent
dataclass, in the whole
attribute. The dataclass
instances are stored in the self.scenario_instance.agents
dictionary, indexed by
agent name.
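For example (a trivial sketch using the attribute names given above):

for agent_name, agent in self.scenario_instance.agents.items():
    # Each whole agent is a PureTextWholeAgent instance
    whole_agent = agent.whole
    print(agent_name, whole_agent)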
Some agents may share an underlying model, where the system prompt is used to condition
the underlying model to act as each distinct agent. To facilitate training these agents,
we use the ‘shared model group’ abstraction. A PureTextSharedModelGroup
contains a collection of
PureTextWholeAgent
instances which
share an underlying model. Trainers therefore only need to call each shared model group's fine-tune methods, rather than having to worry about which agents share the same model.
The following methods allow fine-tuning the underlying model of a shared model group:

- Create a supervised fine-tune job for the agent group given sampled rollouts (create_supervised_fine_tune_job).
- Create a DPO fine-tune job for the agent group given sampled timesteps.
- Get the status of the fine-tune job.
- Get a string representation of the error for the fine-tune job.
- Switch to the next model after fine-tuning.
PureTextSharedModelGroup
instances maintain a state in the state
attribute, which can be used to store and
retrieve checkpoints. This state is a PureTextSharedModelGroupState
dataclass. The state of each
agent is stored in the trainer's experiment state. See Experiment State below
for more information.
Protocol Handler#
The protocol handler implements the interaction protocol for the experiment. It
is an instance of the ProtocolHandler
class, and is available as
self.protocol_handler
.
The protocol handler is mainly used by the environment to implement the step function,
but the following properties can be useful. For a full list of properties and methods,
see the ProtocolHandler
class.
- The names of the agents in the protocol.
- The names of the provers in the protocol.
- The names of the verifiers in the protocol.
- The maximum number of rounds in the protocol.
- The minimum number of rounds in the protocol.
- The names of the message channels in the protocol.
- A specification of which agents can see which message channels.
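As a small sketch, iterating over the agents defined by the protocol might look like this (the agent_names property name is an assumption; check the ProtocolHandler class for the exact name):

for agent_name in self.protocol_handler.agent_names:
    print(agent_name)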
Experiment State#
The experiment state is currently only used by pure text trainers (which subclass
PureTextRlTrainer
). It is a
dataclass specified by the Trainer.State
child class (or a subclass specific to the
current scenario), and is available as self.state
.
For pure-text trainers, the PureTextRlTrainer.State
dataclass is specified as follows.
class nip.trainers.rl_pure_text_base.PureTextRlTrainer.State(iteration: int = 0, agents: dict[str, AgentState] = &lt;factory&gt;, train_loop_stage: Literal['sample_rollouts', 'log_stats', 'create_fine_tune_jobs', 'await_fine_tune_jobs', 'test_during_training', 'test', 'done'] = 'sample_rollouts', shared_model_groups: dict[str, PureTextSharedModelGroupState] = &lt;factory&gt;, base_run_state_artifact_version: int = 0)
The state of the experiment.
Parameters:

- iteration (int) – The current iteration number.
- agents (dict[str, AgentCheckpoint]) – The checkpoints of the agents.
- train_loop_stage (str) – The current stage of the training loop. One of:
  - "sample_rollouts": Sample rollouts from the training environment.
  - "log_stats": Log the statistics of the sampled rollouts.
  - "create_fine_tune_jobs": Create fine-tune jobs for each shared agent group.
  - "await_fine_tune_jobs": Await the completion of the fine-tune jobs.
  - "test_during_training": Run the test loop during training.
  - "test": Run the test loop after training.
  - "done": The training is complete.
- shared_model_groups (dict[str, PureTextSharedModelGroupState]) – The state of each shared model group.
- base_run_state_artifact_version (int) – When rerunning tests, we step through the states of the base run in order. This is the version of the base run state artifact that we're on.
The experiment state is saved after each stage of the training loop. If the experiment is stopped partway through training, it can be resumed later by setting the run_id argument of the run function to the ID of the previous run.
This will automatically load the experiment state from the previous run, either locally
or from Weights & Biases.
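For example, resuming a previous run might look roughly like this. The import path and the other arguments to run are assumptions; hyper_params stands for the HyperParameters instance of the original experiment.

from nip import run  # import path is an assumption

# Resume the experiment with the given run ID; the experiment state is loaded
# automatically, either locally or from Weights & Biases
run(hyper_params, run_id="id-of-previous-run")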