Vanishing Gradients in Reinforcement Finetuning of Language Models

Pretrained language models are commonly adapted to align with human intent and downstream tasks via finetuning. The finetuning process involves supervised finetuning (SFT), which uses labeled samples, and/or reinforcement finetuning (RFT) via policy gradient methods, which uses a (possibly learned) reward function. This work highlights an overlooked optimization hurdle in RFT: we prove that the expected gradient for an input sample (i.e., a prompt) vanishes if its reward standard deviation under the model is low, regardless of whether the reward mean is near-optimal or not. We then…

Apple Machine Learning Research
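
The vanishing-gradient claim can be illustrated numerically with a toy example. The sketch below is not the paper's setting or proof: it assumes a hypothetical tabular softmax policy over four candidate outputs for a single prompt, with the logits themselves as the trainable parameters. Under that assumption the exact gradient of the expected reward with respect to logit y is p(y) * (r(y) - mean reward), and its Euclidean norm is at most the reward standard deviation under the policy. The function names and the reward values are made up for illustration.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over a 1-D array of logits."""
    shifted = logits - np.max(logits)
    exp = np.exp(shifted)
    return exp / exp.sum()


def expected_reward_gradient(logits, rewards):
    """Exact gradient of E_{y~p}[r(y)] w.r.t. the logits of a softmax policy.

    Returns the gradient together with the reward mean and reward
    standard deviation under the policy.
    """
    probs = softmax(logits)
    mean = float(probs @ rewards)
    centered = rewards - mean
    grad = probs * centered  # dV/dlogit_y = p(y) * (r(y) - V)
    std = float(np.sqrt(probs @ centered**2))
    return grad, mean, std


# Hypothetical rewards: only output 0 is correct (optimal mean reward is 1).
rewards = np.array([1.0, 0.0, 0.0, 0.0])

# Policy concentrated on a wrong output: low reward mean AND low reward std.
peaked_logits = np.array([-8.0, 8.0, 0.0, 0.0])
# Uniform policy: the reward std is much larger.
uniform_logits = np.zeros(4)

peaked_grad, peaked_mean, peaked_std = expected_reward_gradient(peaked_logits, rewards)
uniform_grad, uniform_mean, uniform_std = expected_reward_gradient(uniform_logits, rewards)

peaked_norm = float(np.linalg.norm(peaked_grad))
uniform_norm = float(np.linalg.norm(uniform_grad))

print(f"peaked : mean={peaked_mean:.2e} std={peaked_std:.2e} grad_norm={peaked_norm:.2e}")
print(f"uniform: mean={uniform_mean:.2e} std={uniform_std:.2e} grad_norm={uniform_norm:.2e}")
```

In this toy case the peaked policy has a reward mean far from the optimum, yet its expected gradient is close to zero because almost every sampled output receives the same reward. The uniform policy, with a larger reward standard deviation, has a much larger gradient.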