import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
import numpy as np
from typing import Dict, List, Tuple
import warnings
warnings.filterwarnings("ignore")
class PPOTrainer:
"""
基于真实GPT2模型的PPO训练器
"""
def __init__(self, model_name="gpt2", clip_ratio=0.2, vf_coef=0.5, ent_coef=0.01):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {self.device}")
# 加载模型和分词器
self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
self.tokenizer.pad_token = self.tokenizer.eos_token
# 策略网络 (生成模型)
self.policy_model = GPT2LMHeadModel.from_pretrained(model_name).to(self.device)
# 参考模型 (固定的旧策略)
self.ref_model = GPT2LMHeadModel.from_pretrained(model_name).to(self.device)
for param in self.ref_model.parameters():
param.requires_grad = False
# 价值网络 (简单的线性层加在GPT2上)
self.value_head = nn.Linear(self.policy_model.config.hidden_size, 1).to(self.device)
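        # Note: the value head reads hidden states from the policy model itself
        # (see get_value_estimates), so the critic shares the GPT-2 backbone with the policy.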
        # PPO hyperparameters
        self.clip_ratio = clip_ratio
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef
        print(f"Model loaded: {model_name}")
        print(f"Parameter count: {sum(p.numel() for p in self.policy_model.parameters()):,}")
    def tokenize_texts(self, texts: List[str], max_length: int = 64) -> Dict[str, torch.Tensor]:
        """Tokenize a batch of texts and move the resulting tensors to the training device."""
        encoded = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        )
        return {k: v.to(self.device) for k, v in encoded.items()}
    def get_model_outputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                          model: GPT2LMHeadModel) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run a forward pass and return the logits, per-token log-probabilities, and entropy.
        """
        # Only track gradients when evaluating the trainable policy model
        with torch.no_grad() if model is self.ref_model else torch.enable_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits[:, :-1]  # drop the last position; it has no next-token target
            # Log-probabilities over the vocabulary
            log_probs = F.log_softmax(logits, dim=-1)
            # Log-probability of the token that actually follows each position
            target_tokens = input_ids[:, 1:]  # shift left: logits at position t predict token t+1
            selected_log_probs = torch.gather(log_probs, -1, target_tokens.unsqueeze(-1)).squeeze(-1)
            # Per-position entropy of the predictive distribution
            probs = F.softmax(logits, dim=-1)
            entropy = -(probs * log_probs).sum(dim=-1)
        return logits, selected_log_probs, entropy
    def get_value_estimates(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Estimate per-token values from the policy model's last hidden states."""
        outputs = self.policy_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]  # last-layer hidden states, shape (batch, seq_len, hidden)
        values = self.value_head(hidden_states)[:, :-1].squeeze(-1)  # drop the last position to align with the shifted logits
        return values
    def compute_advantages_and_returns(self, rewards: torch.Tensor, values: torch.Tensor,
                                       attention_mask: torch.Tensor, gamma: float = 0.99,
                                       lam: float = 0.95) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute advantages and returns with Generalized Advantage Estimation (GAE).
        """
        batch_size, seq_len = rewards.shape
        advantages = torch.zeros_like(rewards)
        returns = torch.zeros_like(rewards)
        # Sweep backwards over the sequence
        gae = 0
        for t in reversed(range(seq_len)):
            mask = attention_mask[:, t + 1] if t + 1 < seq_len else torch.zeros(batch_size).to(self.device)
            next_value = values[:, t + 1] if t + 1 < seq_len else torch.zeros(batch_size).to(self.device)
            delta = rewards[:, t] + gamma * next_value * mask - values[:, t]
            gae = delta + gamma * lam * mask * gae
            advantages[:, t] = gae
            returns[:, t] = advantages[:, t] + values[:, t]
        return advantages, returns
    def compute_ppo_loss(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                         old_log_probs: torch.Tensor, rewards: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute the PPO loss and its diagnostics.
        """
        # Outputs of the current policy
        policy_logits, new_log_probs, entropy = self.get_model_outputs(
            input_ids, attention_mask, self.policy_model
        )
        # Value estimates
        values = self.get_value_estimates(input_ids, attention_mask)
        # Advantages and returns (values are detached so they act as fixed targets here)
        advantages, returns = self.compute_advantages_and_returns(
            rewards, values.detach(), attention_mask[:, :-1]  # align the attention mask with the shifted sequence
        )
        # Normalize the advantages over the valid (non-padding) positions
        mask = attention_mask[:, 1:].bool()  # mask of valid target positions
        if mask.sum() > 0:
            valid_advantages = advantages[mask]
            advantages[mask] = (valid_advantages - valid_advantages.mean()) / (valid_advantages.std() + 1e-8)
        # Probability ratio between the new and the old policy
        ratio = torch.exp(new_log_probs - old_log_probs)
        # Restrict all loss terms to valid positions
        valid_ratio = ratio[mask]
        valid_advantages = advantages[mask]
        valid_returns = returns[mask]
        valid_values = values[mask]
        valid_entropy = entropy[mask]
        # Policy loss (PPO clipped surrogate)
        surr1 = valid_ratio * valid_advantages
        surr2 = torch.clamp(valid_ratio, 1.0 - self.clip_ratio, 1.0 + self.clip_ratio) * valid_advantages
        policy_loss = -torch.min(surr1, surr2).mean()
        # Value-function loss
        value_loss = F.mse_loss(valid_values, valid_returns)
        # Entropy loss (negative mean entropy: higher entropy lowers the total loss)
        entropy_loss = -valid_entropy.mean()
        # Total loss
        total_loss = policy_loss + self.vf_coef * value_loss + self.ent_coef * entropy_loss
        # Diagnostics
        clipped_fraction = (torch.abs(valid_ratio - 1.0) > self.clip_ratio).float().mean()
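        # clipped_fraction is the share of tokens whose ratio fell outside the
        # [1 - clip_ratio, 1 + clip_ratio] band, a standard PPO health indicator.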
        return {
            'total_loss': total_loss,
            'policy_loss': policy_loss,
            'value_loss': value_loss,
            'entropy_loss': entropy_loss,
            'ratio_mean': valid_ratio.mean(),
            'ratio_std': valid_ratio.std(),
            'clipped_fraction': clipped_fraction,
            'advantages_mean': valid_advantages.mean(),
            'advantages_std': valid_advantages.std(),
            'returns_mean': valid_returns.mean(),
            'entropy_mean': valid_entropy.mean()
        }
def create_reward_function():
    """
    Create a reward function. A simple heuristic is used here as an example;
    in real applications this could come from human feedback, automatic
    evaluation metrics, a learned reward model, and so on.
    """
    def reward_fn(texts: List[str], responses: List[str]) -> List[List[float]]:
        rewards = []
        for prompt, response in zip(texts, responses):
            # Simple reward logic: encourage longer answers, penalize repetition
            tokens = response.split()
            reward_per_token = []
            for i, token in enumerate(tokens):
                # Base reward
                r = 0.1
                # Length bonus
                if len(tokens) > 10:
                    r += 0.1
                # Repetition penalty
                if i > 0 and token in tokens[:i]:
                    r -= 0.3
                # Diversity bonus
                if len(set(tokens[:i + 1])) / (i + 1) > 0.8:
                    r += 0.2
                reward_per_token.append(r)
            rewards.append(reward_per_token)
        return rewards
    return reward_fn
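# For reference, the sketch below shows how rewards could instead come from a learned
# scoring model, as the docstring above hints. It is illustrative only and is not used by
# the demo; the checkpoint name is an assumption, and any classifier that yields a scalar
# quality score per response would work the same way.
def create_model_based_reward_function(model_name: str = "distilbert-base-uncased-finetuned-sst-2-english"):
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    reward_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    reward_model.eval()
    def reward_fn(texts: List[str], responses: List[str]) -> List[List[float]]:
        rewards = []
        with torch.no_grad():
            for response in responses:
                inputs = tokenizer(response, return_tensors="pt", truncation=True)
                logits = reward_model(**inputs).logits
                # Use the positive-class probability as one sequence-level score,
                # then spread it uniformly over the whitespace-split tokens.
                score = torch.softmax(logits, dim=-1)[0, -1].item()
                n_tokens = max(len(response.split()), 1)
                rewards.append([score / n_tokens] * n_tokens)
        return rewards
    return reward_fn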
def run_ppo_demo():
    """
    Run the complete PPO training demo.
    """
    print("=" * 80)
    print("PPO training demo with a real GPT-2 model")
    print("=" * 80)
    # Initialize the trainer and the reward function
    ppo_trainer = PPOTrainer(model_name="gpt2", clip_ratio=0.2)
    reward_fn = create_reward_function()
    # Prepare the training data
    prompts = [
        "The future of artificial intelligence is",
        "In a world where technology advances rapidly,",
        "The most important skill in the 21st century is",
        "Climate change requires us to"
    ]
    print(f"\nTraining data:")
    for i, prompt in enumerate(prompts):
        print(f"{i+1}. {prompt}")
    # Step 1: obtain the responses that will be scored
    print(f"\nStep 1: preparing initial responses...")
    # Tokenize
    inputs = ppo_trainer.tokenize_texts(prompts, max_length=32)
    print(f"Input shape: {inputs['input_ids'].shape}")
    # Forward pass through the reference model (the old policy) to get the old log-probs
    with torch.no_grad():
        ref_logits, ref_log_probs, ref_entropy = ppo_trainer.get_model_outputs(
            inputs['input_ids'], inputs['attention_mask'], ppo_trainer.ref_model
        )
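    # Note: in a full PPO loop you would sample fresh continuations here (e.g. with
    # policy_model.generate) and score those. To keep the demo compact, the tokenized
    # prompts themselves are treated as the "responses" from this point on.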
    # Decode the tokenized sequences back to text
    generated_texts = []
    for i in range(len(prompts)):
        tokens = inputs['input_ids'][i].cpu().numpy()
        text = ppo_trainer.tokenizer.decode(tokens, skip_special_tokens=True)
        generated_texts.append(text)
    print(f"\nResponses:")
    for i, (prompt, response) in enumerate(zip(prompts, generated_texts)):
        print(f"{i+1}. Prompt: {prompt}")
        print(f"   Response: {response}")
        print()
    # Step 2: compute rewards
    print("Step 2: computing rewards...")
    rewards_list = reward_fn(prompts, generated_texts)
    # Convert the rewards to a tensor, padding every sequence to the same length
    max_len = ref_log_probs.shape[1]
    rewards_tensor = torch.zeros_like(ref_log_probs)
    for i, reward_seq in enumerate(rewards_list):
        length = min(len(reward_seq), max_len)
        rewards_tensor[i, :length] = torch.tensor(reward_seq[:length])
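    # Caveat: the heuristic rewards are assigned per whitespace-delimited word, while
    # rewards_tensor is indexed by GPT-2 BPE token position, so the alignment is only
    # approximate. That is acceptable for this demo but worth fixing in real training.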
print(f"奖励tensor shape: {rewards_tensor.shape}")
print(f"平均奖励: {rewards_tensor.mean().item():.4f}")
# 第三步:计算PPO损失
print(f"\n第三步:计算PPO损失...")
ppo_trainer.policy_model.train()
loss_info = ppo_trainer.compute_ppo_loss(
inputs['input_ids'],
inputs['attention_mask'],
ref_log_probs,
rewards_tensor
)
# 显示详细结果
print(f"\nPPO损失计算结果:")
print("-" * 50)
print(f"总损失: {loss_info['total_loss']:.6f}")
print(f"策略损失: {loss_info['policy_loss']:.6f}")
print(f"价值损失: {loss_info['value_loss']:.6f}")
print(f"熵损失: {loss_info['entropy_loss']:.6f}")
print()
print(f"训练统计:")
print("-" * 50)
print(f"概率比率(均值): {loss_info['ratio_mean']:.6f}")
print(f"概率比率(标准差): {loss_info['ratio_std']:.6f}")
print(f"被裁剪比例: {loss_info['clipped_fraction']:.6f}")
print(f"优势函数(均值): {loss_info['advantages_mean']:.6f}")
print(f"优势函数(标准差): {loss_info['advantages_std']:.6f}")
print(f"回报(均值): {loss_info['returns_mean']:.6f}")
print(f"熵(均值): {loss_info['entropy_mean']:.6f}")
# 第四步:模拟一步优化
print(f"\n第四步:执行一步梯度更新...")
optimizer = torch.optim.AdamW([
{'params': ppo_trainer.policy_model.parameters()},
{'params': ppo_trainer.value_head.parameters()}
], lr=1e-5)
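    # One AdamW optimizer updates the policy and the value head jointly. In a full PPO
    # implementation this update would be repeated for several epochs over each batch of
    # rollouts; the demo performs a single step to keep the output easy to follow.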
    optimizer.zero_grad()
    loss_info['total_loss'].backward()
    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(ppo_trainer.policy_model.parameters(), 1.0)
    torch.nn.utils.clip_grad_norm_(ppo_trainer.value_head.parameters(), 1.0)
    optimizer.step()
    print("Gradient update complete!")
    # Step 5: recompute the loss after the update
    print(f"\nStep 5: computing the post-update loss...")
    with torch.no_grad():
        new_loss_info = ppo_trainer.compute_ppo_loss(
            inputs['input_ids'],
            inputs['attention_mask'],
            ref_log_probs,
            rewards_tensor
        )
    print(f"\nLoss after the update:")
    print("-" * 50)
    print(f"Total loss: {new_loss_info['total_loss']:.6f} "
          f"(change: {new_loss_info['total_loss'] - loss_info['total_loss']:.6f})")
    print(f"Policy loss: {new_loss_info['policy_loss']:.6f} "
          f"(change: {new_loss_info['policy_loss'] - loss_info['policy_loss']:.6f})")
    print(f"Value loss: {new_loss_info['value_loss']:.6f} "
          f"(change: {new_loss_info['value_loss'] - loss_info['value_loss']:.6f})")
    return ppo_trainer, loss_info, new_loss_info
def analyze_model_behavior(ppo_trainer, prompts):
    """
    Compare how the model's behavior changes after the update.
    """
    print(f"\n" + "=" * 80)
    print("Model behavior analysis")
    print("=" * 80)
    # Generate fresh responses
    ppo_trainer.policy_model.eval()
    inputs = ppo_trainer.tokenize_texts(prompts[:2], max_length=40)  # only use the first two prompts
    with torch.no_grad():
        # Generate with the reference model
        ref_outputs = ppo_trainer.ref_model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=40,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.8,
            pad_token_id=ppo_trainer.tokenizer.eos_token_id
        )
        # Generate with the updated policy model
        policy_outputs = ppo_trainer.policy_model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=40,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.8,
            pad_token_id=ppo_trainer.tokenizer.eos_token_id
        )
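    # Both models sample with temperature 0.8, so outputs vary between runs, and after a
    # single small update the two policies are still nearly identical; differences in the
    # comparison below mostly reflect sampling noise. Note also that batched GPT-2
    # generation generally works best with left padding; with the right-padded inputs used
    # here, shorter prompts may continue past their padding tokens.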
print("生成文本对比:")
print("-" * 50)
for i in range(len(prompts[:2])):
ref_text = ppo_trainer.tokenizer.decode(ref_outputs[i], skip_special_tokens=True)
policy_text = ppo_trainer.tokenizer.decode(policy_outputs[i], skip_special_tokens=True)
print(f"提示 {i+1}: {prompts[i]}")
print(f"参考模型: {ref_text}")
print(f"训练模型: {policy_text}")
print()
if __name__ == "__main__":
# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)
# try:
# 运行主demo
ppo_trainer, loss_info, new_loss_info = run_ppo_demo()
# 分析模型行为
prompts = [
"The future of artificial intelligence is",
"In a world where technology advances rapidly,",
"The most important skill in the 21st century is",
"Climate change requires us to"
]
analyze_model_behavior(ppo_trainer, prompts)
print(f"\n" + "=" * 80)
print("Demo完成!")
print("这个demo展示了:")
print("1. 真实GPT2模型的加载和使用")
print("2. 实际的文本生成和分词")
print("3. PPO损失的详细计算过程")
print("4. 梯度更新和模型优化")
print("5. 训练前后的模型行为对比")
print("=" * 80)
# except Exception as e:
# print(f"运行出错: {e}")
# print("请确保安装了transformers库: pip install transformers torch")