PPO Basics for LLM Fine-tuning
Once you have a reward model, you need an algorithm to update the LLM using its scores. Proximal Policy Optimization (PPO) is the standard choice. It maximizes reward while preventing updates large enough to destabilize the model or collapse its language generation ability.
The Core Idea
In PPO, the LLM is the policy — it maps a prompt to a response. At each step:
- The policy generates a response;
- The reward model scores it;
- PPO updates the policy to increase the probability of high-reward responses.
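The loop above can be sketched with a toy "policy" over two canned responses. All names here are hypothetical, and a real PPO update adjusts model parameters via gradients, not a probability table — but the generate → score → update cycle is the same:

```python
import random

random.seed(0)

# Toy sketch of the generate -> score -> update loop. The "policy" is a
# probability table over two canned responses; the update shifts probability
# mass toward whichever response the "reward model" scores higher.
policy = {"helpful answer": 0.5, "rambling answer": 0.5}
reward_model = {"helpful answer": 1.0, "rambling answer": -1.0}

def policy_update(policy, response, reward, lr=0.1):
    # Increase (or decrease) the sampled response's probability in
    # proportion to its reward, then renormalize.
    policy[response] = max(policy[response] + lr * reward, 1e-6)
    total = sum(policy.values())
    for r in policy:
        policy[r] /= total

for _ in range(10):
    response = random.choices(list(policy), weights=policy.values())[0]
    score = reward_model[response]          # the reward model scores it
    policy_update(policy, response, score)  # policy shifts toward high reward

print(policy)
```

After a few iterations the high-reward response dominates; real PPO does the analogous thing in parameter space, with clipping keeping each shift small.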
The key constraint is that updates are clipped — the ratio between the new policy's probability and the old policy's probability is bounded to the range [1−ϵ,1+ϵ], typically with ϵ=0.2. This prevents any single update from changing the model's behavior too dramatically.
$$L^{\text{CLIP}}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \text{clip}\big(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\,\hat{A}_t\right)\right]$$

where $r_t(\theta) = \dfrac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$ is the probability ratio and $\hat{A}_t$ is the advantage: how much better the response was than expected.
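Plugging numbers into the clipped objective shows how the bound works. A minimal sketch for a single token, with ϵ = 0.2:

```python
def l_clip(ratio, advantage, eps=0.2):
    # Clip the probability ratio to [1 - eps, 1 + eps], then take the
    # pessimistic (minimum) of the clipped and unclipped terms.
    clipped_ratio = min(max(ratio, 1 - eps), 1 + eps)
    return min(ratio * advantage, clipped_ratio * advantage)

# Positive advantage: the gain from pushing the ratio past 1 + eps is capped.
print(round(l_clip(1.5, 2.0), 3))   # 2.4, i.e. 1.2 * 2.0, not 1.5 * 2.0 = 3.0
# Negative advantage: the min keeps the more pessimistic unclipped term.
print(round(l_clip(1.5, -2.0), 3))  # -3.0
```

Because the objective flattens once the ratio leaves the clip range, the gradient there is zero — the update simply stops pushing in that direction.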
KL Penalty in Practice
In addition to clipping, LLM fine-tuning with PPO adds a KL divergence penalty between the current policy and the SFT model:
$$r_{\text{final}} = r_{\text{reward}} - \beta \cdot \mathrm{KL}\left(\pi_\theta \,\|\, \pi_{\text{SFT}}\right)$$

This prevents the model from drifting so far toward high reward that it loses fluency or generates degenerate outputs, a failure mode known as reward hacking.
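For a single response the arithmetic looks like this. A minimal sketch assuming we already have per-token log-probs of the generated tokens under the current policy and the frozen SFT model (all numbers here are made up for illustration):

```python
beta = 0.2          # KL penalty coefficient
reward_score = 1.8  # scalar score from the reward model

# Hypothetical per-token log-probs of the generated tokens under each model.
policy_logprobs = [-0.8, -0.9, -1.5]  # log pi_theta
sft_logprobs = [-1.2, -0.8, -2.0]     # log pi_SFT

# A common RLHF approximation: per-token KL ~ log pi_theta - log pi_SFT,
# summed over the response.
kl_estimate = sum(p - s for p, s in zip(policy_logprobs, sft_logprobs))
final_reward = reward_score - beta * kl_estimate
print(round(final_reward, 2))  # 1.64: the KL term shaves 0.16 off the raw score
```

The more the policy's token probabilities drift from the SFT model's, the larger the KL estimate and the bigger the bite taken out of the reward.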
Using TRL for PPO
The trl library implements PPO for LLMs out of the box:
import torch
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer

model = AutoModelForCausalLMWithValueHead.from_pretrained("bigscience/bloom-560m")
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

ppo_config = PPOConfig(
    model_name="bloom-560m",
    learning_rate=1.5e-5,
    batch_size=1,       # single-example demo; step() expects exactly batch_size samples
    kl_penalty="kl",
    init_kl_coef=0.2,   # beta: controls how strongly the KL penalty is applied
)
ppo_trainer = PPOTrainer(ppo_config, model, tokenizer=tokenizer)

# A single PPO step (reward tensors come from the reward model)
query_tensors = [tokenizer.encode("How do I reset my password?", return_tensors="pt")[0]]
response_tensors = ppo_trainer.generate(query_tensors)
reward_tensors = [torch.tensor(1.8)]  # score from the reward model
stats = ppo_trainer.step(query_tensors, response_tensors, reward_tensors)
print(stats)
Run this locally to see the PPO training statistics; watch the reported KL (e.g. objective/kl) and mean reward to monitor whether the policy is staying close to the SFT baseline.
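Rather than fixing β for the whole run, trl can also adapt it toward a target KL. A simplified sketch of that proportional-controller idea (the function name and constants here are illustrative, not trl's exact implementation):

```python
def kl_controller(current_coef, observed_kl, target_kl=6.0, rate=0.1):
    # Proportional control of the KL coefficient: raise beta when the
    # observed KL overshoots the target, lower it when it undershoots.
    # The error is clamped to [-1, 1] so a single noisy batch cannot
    # swing the coefficient wildly.
    error = max(-1.0, min(1.0, (observed_kl - target_kl) / target_kl))
    return current_coef * (1 + rate * error)

print(round(kl_controller(0.2, 12.0), 4))  # 0.22: KL too high, penalty strengthens
print(round(kl_controller(0.2, 3.0), 4))   # 0.19: KL below target, penalty relaxes
```

You would call something like this once per PPO step, feeding the KL value from the returned stats back into the next step's penalty coefficient.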