A better training method for reinforcement learning with human feedback


Contrasting training pairs with large reward differences mitigate spurious correlations and improve performance of direct-alignment algorithms by as much as 20% to 40%.


May 02, 09:00 AM (updated May 02, 10:23 AM)

Reinforcement learning with human feedback (RLHF) is the standard method for aligning large language models (LLMs) with human preferences, such as a preference for nontoxic language and factually accurate responses. Recently, one of the most popular RLHF methods has been direct preference optimization, in which the LLM chooses between two output options, one of which human annotators have labeled as preferred.

However, with direct preference optimization (DPO) and other, similar direct-alignment algorithms, LLMs run the risk of learning spurious correlations from the data. In toxicity datasets, for instance, it's common for the serious, thoughtful responses to be longer than the offensive ones. During RLHF, an LLM could thus learn to prefer longer responses to shorter ones, a preference that may not be desirable in general.

At this year's International Conference on Learning Representations (ICLR), we presented a method for limiting such spurious correlations, which we call SeRA, for self-reviewing and alignment. After a first round of RLHF on human-annotated data, we use the LLM itself to generate additional training examples. We then use the LLM's output probabilities to assess the strength of preference for each training pair, keeping only those pairs where the preferred response is strongly preferred.

To evaluate our approach, we compare a model trained using SeRA to three baseline models on four benchmark datasets. For each test input, we compare our model's output to that of each of the baselines, and we use an off-the-shelf LLM to choose the better response. The SeRA-trained model's win rate in these pairwise comparisons is higher than that of each of the three baselines across the board, sometimes by as much as 20% to 40%.

Direct preference optimization

Reinforcement learning is a trial-and-error method in which an agent interacts with the world and, depending on the actions it takes, receives greater or lesser rewards. Over time, the agent attempts to learn a policy that maximizes its cumulative reward.

In classical reinforcement learning, the interaction with the world can be literal: a robot, for instance, might receive a large reward for successfully navigating to a prescribed location and a negative reward for bumping into a wall. In RLHF, however, the reward depends on how well an LLM's output aligns with a paradigm case specified by a human.

With traditional RLHF, the reward is calculated by a separate model, which is also trained on human-annotated data. But this approach is time-consuming and doesn't scale well. With DPO, there's no need for a second model: the LLM receives the reward if it picks the human-preferred output and not if it doesn't.
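For readers who want something more concrete, here is a minimal Python sketch of the DPO loss as it is commonly written. The "reward" is implicit: the loss pushes the model being trained to favor the preferred response y_w over the non-preferred response y_l by more than a frozen reference model (usually the model before DPO training) does. The function names, the default beta of 0.1, and the assumption that each input is a summed per-sequence log-probability computed elsewhere are illustrative choices, not details from this article.

```python
import math
from typing import Sequence


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_loss(
    policy_logp_w: Sequence[float],
    policy_logp_l: Sequence[float],
    ref_logp_w: Sequence[float],
    ref_logp_l: Sequence[float],
    beta: float = 0.1,
) -> float:
    """Mean DPO loss over a batch of (preferred, non-preferred) pairs.

    Each argument holds one summed sequence log-probability per pair:
    log pi(y_w | x) and log pi(y_l | x) under the model being trained,
    and the same quantities under the frozen reference model.
    """
    losses = []
    for pw, pl, rw, rl in zip(policy_logp_w, policy_logp_l, ref_logp_w, ref_logp_l):
        # How much more the trained model favors y_w over y_l than the reference does.
        margin = (pw - rw) - (pl - rl)
        losses.append(-log_sigmoid(beta * margin))
    # Every pair enters the average with the same weight, however strong
    # the underlying human preference was.
    return sum(losses) / len(losses)
```

The uniform average in the last line is where every labeled pair counts the same, which is the property the next paragraph identifies as DPO's drawback.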

The drawback of DPO is that it treats all training pairs equally: the reward is the same whether the preferred output is strongly preferred or only mildly preferred. This increases the chances that the model will learn spurious correlations.

If, for instance, choosing strongly toxic responses incurred a greater penalty than choosing mildly toxic responses, the model could infer that toxicity, and not response length, was the relevant feature of the training examples. DPO irons out those differences; SeRA reintroduces them.

With SeRA, we first perform conventional DPO, using a dataset of human-annotated example pairs. After this first pass through the data, the LLM has learned something about the types of outputs that humans prefer.

We then use the updated model to generate a new set of training examples. For every generated response pair, we assign each response a preference score, which is based on the updated model's probability of generating that response. We then keep only those pairs in which the preferred response scores significantly higher than the non-preferred response.
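The scoring-and-filtering step can be sketched in Python as follows. The article says only that the score is based on the updated model's probability of generating a response; this sketch assumes a DPO-style implicit reward (beta times the log-ratio of the updated model's probability to a frozen reference model's), and the threshold is a hypothetical hyperparameter. For model-generated pairs, it also assumes the higher-scoring response is the one treated as preferred. `ScoredPair`, `preference_score`, `reward_margin`, `orient`, and `keep_strong_pairs` are illustrative names.

```python
from dataclasses import dataclass, replace
from typing import List


@dataclass
class ScoredPair:
    prompt: str
    chosen: str               # response treated as preferred (y_w)
    rejected: str             # response treated as non-preferred (y_l)
    logp_chosen: float        # log-prob of `chosen` under the updated model
    logp_rejected: float      # log-prob of `rejected` under the updated model
    ref_logp_chosen: float    # log-prob of `chosen` under a frozen reference model
    ref_logp_rejected: float  # log-prob of `rejected` under the reference model


def preference_score(logp: float, ref_logp: float, beta: float = 0.1) -> float:
    """Assumed score: DPO-style implicit reward, beta * log(pi / pi_ref)."""
    return beta * (logp - ref_logp)


def reward_margin(p: ScoredPair, beta: float = 0.1) -> float:
    """Score of the preferred response minus score of the non-preferred one."""
    return (preference_score(p.logp_chosen, p.ref_logp_chosen, beta)
            - preference_score(p.logp_rejected, p.ref_logp_rejected, beta))


def orient(p: ScoredPair, beta: float = 0.1) -> ScoredPair:
    """For generated pairs: make the higher-scoring response the preferred one."""
    if reward_margin(p, beta) >= 0:
        return p
    return replace(
        p,
        chosen=p.rejected, rejected=p.chosen,
        logp_chosen=p.logp_rejected, logp_rejected=p.logp_chosen,
        ref_logp_chosen=p.ref_logp_rejected, ref_logp_rejected=p.ref_logp_chosen,
    )


def keep_strong_pairs(pairs: List[ScoredPair], threshold: float,
                      beta: float = 0.1) -> List[ScoredPair]:
    """Keep only pairs whose preferred response clearly outscores the other."""
    return [p for p in pairs if reward_margin(p, beta) > threshold]
```

Because `keep_strong_pairs` accepts any list of pairs, the same criterion can be applied to the original human-annotated pairs as well as to the generated ones.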

With SeRA (self-reviewing and alignment), the updated model generates a new response pair (a winner, or y_w, and a loser, or y_l) for each sample input (x). Each response receives a preference score, which is based on the updated model's probability of generating it. Pairs in which the score of the preferred response is significantly higher than that of the non-preferred response (green) are kept; the others (red) are discarded.

Using the same metric, we next filter the data in the original, human-annotated dataset. Then we combine filtered samples from the original dataset with filtered samples from our new, generated dataset and perform DPO once again. This process repeats, with the generated samples constituting a larger and larger fraction of the dataset, until model performance converges.
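The overall procedure amounts to a loop, which can be sketched in Python as below. This is one hypothetical way to write it: `train_dpo`, `generate_pairs`, `filter_pairs`, and `evaluate` are placeholder callables you would supply (for example, `filter_pairs` could wrap a margin filter). The convergence test, which stops when an evaluation score changes by less than `tol`, is an assumption, as is letting generated pairs accumulate across rounds so that their share of the training data grows.

```python
from typing import Callable, List, TypeVar

Model = TypeVar("Model")
Pair = TypeVar("Pair")


def sera_loop(
    model: Model,
    human_pairs: List[Pair],
    train_dpo: Callable[[Model, List[Pair]], Model],
    generate_pairs: Callable[[Model], List[Pair]],
    filter_pairs: Callable[[Model, List[Pair]], List[Pair]],
    evaluate: Callable[[Model], float],
    max_rounds: int = 10,
    tol: float = 1e-3,
) -> Model:
    # First pass: conventional DPO on the human-annotated data.
    model = train_dpo(model, human_pairs)
    prev_score = evaluate(model)
    generated: List[Pair] = []
    for _ in range(max_rounds):
        # Self-review: the current model generates new pairs, which accumulate.
        generated.extend(generate_pairs(model))
        # Filter both pools with the same criterion, then combine them.
        data = filter_pairs(model, human_pairs) + filter_pairs(model, generated)
        model = train_dpo(model, data)
        score = evaluate(model)
        # Stop once the evaluation score stops moving (treated as convergence).
        if abs(score - prev_score) < tol:
            break
        prev_score = score
    return model
```

Monitoring the sequence of `evaluate` scores inside this loop is also a natural place to watch the convergence behavior discussed below.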

The intuition here is that if a dataset is designed to represent some contrast, but it also contains spurious correlations, then the intended contrast between, say, toxic and nontoxic data will be significantly greater than the unintended contrast between, say, long and short responses.

This assumption held for the four benchmark datasets we used to evaluate our method, and we think that it's a plausible assumption for other spurious correlations. But there could be instances in which it doesn't hold, so in applications of the SeRA method, the model's convergence behavior should be monitored.

Although we used DPO in our experiments, our paper also demonstrates how to generalize the method to other direct-alignment algorithms. Finally, there's some risk that, when a model is trained on model-generated data, it could get into a feedback loop in which it overamplifies some aspect of the initial dataset. To guard against this, in each pass through the data, the model's reward is based not only on the current iteration but on past iterations as well, which helps preserve the characteristic features of the training data from one iteration to the next.
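The article does not give the exact rule for combining rewards across iterations. One simple possibility, shown here as an assumption and not necessarily the paper's scheme, is an exponential moving average of each pair's reward margin: the current iteration's margin is blended with the blended value carried over from earlier iterations. The function name `blend_margins`, the pair-id keys, and the default `alpha` are hypothetical.

```python
from typing import Dict, Hashable


def blend_margins(
    history: Dict[Hashable, float],
    current: Dict[Hashable, float],
    alpha: float = 0.5,
) -> Dict[Hashable, float]:
    """Exponential moving average of per-pair reward margins across iterations.

    history: blended margin for each pair id from earlier iterations.
    current: margin for each pair id under the current model.
    alpha:   weight on the current iteration (hypothetical default).
    Only pairs present in `current` are returned.
    """
    blended = {}
    for pair_id, margin in current.items():
        if pair_id in history:
            blended[pair_id] = alpha * margin + (1 - alpha) * history[pair_id]
        else:
            # A pair seen for the first time has no history to blend with.
            blended[pair_id] = margin
    return blended
```

Using the blended margins in place of the single-iteration margin during filtering damps the influence of any one round, so a quirk introduced by a single model update is less likely to be amplified in the next.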

Acknowledgments: Sravan Bodapati


Research areas: Machine learning, Conversational AI

Tags: Large language models (LLMs), Reinforcement learning, Contrastive learning
