PPO from Scratch


import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
import numpy as np
from typing import Dict, List, Tuple
import warnings
warnings.filterwarnings("ignore")

class PPOTrainer:
    """
    基于真实GPT2模型的PPO训练器
    """
    def __init__(self, model_name="gpt2", clip_ratio=0.2, vf_coef=0.5, ent_coef=0.01):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"使用设备: {self.device}")

        # Load the model and tokenizer
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Policy network (the model being trained)
        self.policy_model = GPT2LMHeadModel.from_pretrained(model_name).to(self.device)

        # Reference model (frozen old policy)
        self.ref_model = GPT2LMHeadModel.from_pretrained(model_name).to(self.device)
        for param in self.ref_model.parameters():
            param.requires_grad = False

        # Value network (a simple linear head on top of GPT-2 hidden states)
        self.value_head = nn.Linear(self.policy_model.config.hidden_size, 1).to(self.device)

        # PPO hyperparameters
        self.clip_ratio = clip_ratio
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef

        print(f"模型加载完成: {model_name}")
        print(f"模型参数量: {sum(p.numel() for p in self.policy_model.parameters()):,}")

    def tokenize_texts(self, texts: List[str], max_length: int = 64) -> Dict[str, torch.Tensor]:
        """文本分词"""
        encoded = self.tokenizer(
            texts, 
            return_tensors="pt", 
            padding=True, 
            truncation=True,
            max_length=max_length
        )
        return {k: v.to(self.device) for k, v in encoded.items()}

    def get_model_outputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, 
                         model: GPT2LMHeadModel) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        获取模型输出: logits, 对数概率, 熵
        """
        with torch.no_grad() if model == self.ref_model else torch.enable_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits[:, :-1]  # 去掉最后一个token的logits

            # Log-probabilities over the vocabulary
            log_probs = F.log_softmax(logits, dim=-1)

            # Log-probability of each token that was actually chosen
            target_tokens = input_ids[:, 1:]  # targets are the inputs shifted left by one position
            selected_log_probs = torch.gather(log_probs, -1, target_tokens.unsqueeze(-1)).squeeze(-1)
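            # Alignment note: logits[:, t] scores input_ids[:, t + 1], so selected_log_probs[:, t]
            # is log pi(x_{t+1} | x_{<=t}); the per-token tensors here all have shape (batch, seq_len - 1).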

            # Entropy of the predictive distribution
            probs = F.softmax(logits, dim=-1)
            entropy = -(probs * log_probs).sum(dim=-1)

            return logits, selected_log_probs, entropy

    def get_value_estimates(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """获取价值函数估计"""
        outputs = self.policy_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]  # last-layer hidden states
        values = self.value_head(hidden_states)[:, :-1].squeeze(-1)  # drop the last position to align with the shifted targets
        return values

    def compute_advantages_and_returns(self, rewards: torch.Tensor, values: torch.Tensor, 
                                     attention_mask: torch.Tensor, gamma: float = 0.99, 
                                     lam: float = 0.95) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        使用GAE计算优势函数和回报
        """
        batch_size, seq_len = rewards.shape
        advantages = torch.zeros_like(rewards)
        returns = torch.zeros_like(rewards)

        # Work backwards through time
        gae = 0
        for t in reversed(range(seq_len)):
            mask = attention_mask[:, t + 1] if t + 1 < seq_len else torch.zeros(batch_size).to(self.device)
            next_value = values[:, t + 1] if t + 1 < seq_len else torch.zeros(batch_size).to(self.device)

            delta = rewards[:, t] + gamma * next_value * mask - values[:, t]
            gae = delta + gamma * lam * mask * gae
            advantages[:, t] = gae
            returns[:, t] = advantages[:, t] + values[:, t]

        return advantages, returns
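
    # Reference for the recursion above (a restatement of the code, not extra functionality):
    #   delta_t = r_t + gamma * V(s_{t+1}) * mask_{t+1} - V(s_t)
    #   A_t     = delta_t + gamma * lam * mask_{t+1} * A_{t+1}    (GAE advantage)
    #   R_t     = A_t + V(s_t)                                    (value-function target)
    # where V(s_t) = values[:, t] and mask_{t+1} zeroes out padded / out-of-range positions.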

    def compute_ppo_loss(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                        old_log_probs: torch.Tensor, rewards: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        计算PPO损失
        """
        # 获取当前策略的输出
        policy_logits, new_log_probs, entropy = self.get_model_outputs(
            input_ids, attention_mask, self.policy_model
        )

        # Value estimates
        values = self.get_value_estimates(input_ids, attention_mask)

        # Advantages and returns
        advantages, returns = self.compute_advantages_and_returns(
            rewards, values, attention_mask[:, :-1]  # align the mask with the shifted sequence
        )

        # Normalize advantages over the valid (non-padding) positions
        mask = attention_mask[:, 1:].bool()  # mask of valid target positions
        if mask.sum() > 0:
            valid_advantages = advantages[mask]
            advantages[mask] = (valid_advantages - valid_advantages.mean()) / (valid_advantages.std() + 1e-8)

        # Probability ratio between the new and old policies
        ratio = torch.exp(new_log_probs - old_log_probs)

        # Compute losses only at valid positions; advantages and returns are treated as
        # fixed targets (detached), as in standard PPO
        valid_ratio = ratio[mask]
        valid_advantages = advantages[mask].detach()
        valid_returns = returns[mask].detach()
        valid_values = values[mask]
        valid_entropy = entropy[mask]

        # Policy loss (PPO clipped surrogate objective)
        surr1 = valid_ratio * valid_advantages
        surr2 = torch.clamp(valid_ratio, 1.0 - self.clip_ratio, 1.0 + self.clip_ratio) * valid_advantages
        policy_loss = -torch.min(surr1, surr2).mean()

        # Value-function loss
        value_loss = F.mse_loss(valid_values, valid_returns)

        # Entropy bonus (negated: minimizing the loss maximizes entropy)
        entropy_loss = -valid_entropy.mean()

        # Total loss
        total_loss = policy_loss + self.vf_coef * value_loss + self.ent_coef * entropy_loss

        # Diagnostics
        clipped_fraction = (torch.abs(valid_ratio - 1.0) > self.clip_ratio).float().mean()

        return {
            'total_loss': total_loss,
            'policy_loss': policy_loss,
            'value_loss': value_loss,
            'entropy_loss': entropy_loss,
            'ratio_mean': valid_ratio.mean(),
            'ratio_std': valid_ratio.std(),
            'clipped_fraction': clipped_fraction,
            'advantages_mean': valid_advantages.mean(),
            'advantages_std': valid_advantages.std(),
            'returns_mean': valid_returns.mean(),
            'entropy_mean': valid_entropy.mean()
        }
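
    # The objective implemented above, restated in equation form for reference:
    #   r_t(theta) = exp(new_log_probs - old_log_probs) = pi_theta(a_t | s_t) / pi_old(a_t | s_t)
    #   L_clip     = E_t[ min( r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t ) ]
    #   total_loss = -L_clip + vf_coef * MSE(V, R) - ent_coef * H(pi_theta)
    # with eps = clip_ratio; minimizing total_loss maximizes the clipped surrogate and the entropy.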

def create_reward_function():
    """
    Create a reward function. This is a simple toy example; in practice this could be
    human feedback, an automatic evaluation metric, a learned reward model, etc.
    """
    def reward_fn(texts: List[str], responses: List[str]) -> List[List[float]]:
        rewards = []
        for prompt, response in zip(texts, responses):
            # Toy reward logic: encourage longer answers, penalize repetition
            tokens = response.split()
            reward_per_token = []

            for i, token in enumerate(tokens):
                # Base reward
                r = 0.1

                # Length bonus
                if len(tokens) > 10:
                    r += 0.1

                # Penalize repeated tokens
                if i > 0 and token in tokens[:i]:
                    r -= 0.3

                # Encourage lexical diversity
                if len(set(tokens[:i+1])) / (i+1) > 0.8:
                    r += 0.2

                reward_per_token.append(r)

            rewards.append(reward_per_token)

        return rewards

    return reward_fn
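
# A minimal usage sketch of the toy reward function above; the inputs are made up for illustration.
def _demo_reward_fn_usage():
    reward_fn = create_reward_function()
    # One list of per-token rewards per (prompt, response) pair; tokens are whitespace-split words.
    rewards = reward_fn(["The future of AI is"], ["bright and full of open research questions"])
    print(rewards[0])  # e.g. small positive values throughout, since the response has no repeated words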

def run_ppo_demo():
    """
    运行完整的PPO训练demo
    """
    print("=" * 80)
    print("真实GPT2模型PPO训练Demo")
    print("=" * 80)

    # Initialize the trainer
    ppo_trainer = PPOTrainer(model_name="gpt2", clip_ratio=0.2)
    reward_fn = create_reward_function()

    # Prepare training prompts
    prompts = [
        "The future of artificial intelligence is",
        "In a world where technology advances rapidly,",
        "The most important skill in the 21st century is",
        "Climate change requires us to"
    ]

    print(f"\n训练数据:")
    for i, prompt in enumerate(prompts):
        print(f"{i+1}. {prompt}")

    # Step 1: score sequences with the reference model
    # (For simplicity this demo scores the tokenized prompts themselves instead of sampling
    #  continuations; a full RLHF pipeline would generate responses here.)
    print("\nStep 1: computing reference-model log-probs...")

    # Tokenize
    inputs = ppo_trainer.tokenize_texts(prompts, max_length=32)
    print(f"输入shape: {inputs['input_ids'].shape}")

    # Forward pass through the frozen reference model
    with torch.no_grad():
        ref_logits, ref_log_probs, ref_entropy = ppo_trainer.get_model_outputs(
            inputs['input_ids'], inputs['attention_mask'], ppo_trainer.ref_model
        )

    # Decode the sequences being scored (here, the prompts themselves)
    generated_texts = []
    for i in range(len(prompts)):
        tokens = inputs['input_ids'][i].cpu().numpy()
        text = ppo_trainer.tokenizer.decode(tokens, skip_special_tokens=True)
        generated_texts.append(text)

    print(f"\n生成的响应:")
    for i, (prompt, response) in enumerate(zip(prompts, generated_texts)):
        print(f"{i+1}. 提示: {prompt}")
        print(f"   响应: {response}")
        print()

    # Step 2: compute rewards
    print("Step 2: computing rewards...")
    rewards_list = reward_fn(prompts, generated_texts)

    # Convert rewards to a tensor padded to a common length
    max_len = ref_log_probs.shape[1]
    rewards_tensor = torch.zeros_like(ref_log_probs)

    for i, reward_seq in enumerate(rewards_list):
        length = min(len(reward_seq), max_len)
        rewards_tensor[i, :length] = torch.tensor(reward_seq[:length])

    print(f"奖励tensor shape: {rewards_tensor.shape}")
    print(f"平均奖励: {rewards_tensor.mean().item():.4f}")

    # Step 3: compute the PPO loss
    print("\nStep 3: computing the PPO loss...")

    ppo_trainer.policy_model.train()
    loss_info = ppo_trainer.compute_ppo_loss(
        inputs['input_ids'],
        inputs['attention_mask'], 
        ref_log_probs,
        rewards_tensor
    )

    # Show detailed results
    print("\nPPO loss results:")
    print("-" * 50)
    print(f"Total loss:        {loss_info['total_loss']:.6f}")
    print(f"Policy loss:       {loss_info['policy_loss']:.6f}")
    print(f"Value loss:        {loss_info['value_loss']:.6f}")
    print(f"Entropy loss:      {loss_info['entropy_loss']:.6f}")
    print()
    print("Training statistics:")
    print("-" * 50)
    print(f"Ratio (mean):      {loss_info['ratio_mean']:.6f}")
    print(f"Ratio (std):       {loss_info['ratio_std']:.6f}")
    print(f"Clipped fraction:  {loss_info['clipped_fraction']:.6f}")
    print(f"Advantage (mean):  {loss_info['advantages_mean']:.6f}")
    print(f"Advantage (std):   {loss_info['advantages_std']:.6f}")
    print(f"Return (mean):     {loss_info['returns_mean']:.6f}")
    print(f"Entropy (mean):    {loss_info['entropy_mean']:.6f}")

    # Step 4: perform one gradient update
    print("\nStep 4: performing one gradient update...")

    optimizer = torch.optim.AdamW([
        {'params': ppo_trainer.policy_model.parameters()},
        {'params': ppo_trainer.value_head.parameters()}
    ], lr=1e-5)

    optimizer.zero_grad()
    loss_info['total_loss'].backward()

    # Gradient clipping
    torch.nn.utils.clip_grad_norm_(ppo_trainer.policy_model.parameters(), 1.0)
    torch.nn.utils.clip_grad_norm_(ppo_trainer.value_head.parameters(), 1.0)

    optimizer.step()

    print("梯度更新完成!")

    # Step 5: recompute the loss after the update
    print("\nStep 5: recomputing the loss after the update...")

    with torch.no_grad():
        new_loss_info = ppo_trainer.compute_ppo_loss(
            inputs['input_ids'],
            inputs['attention_mask'],
            ref_log_probs,
            rewards_tensor
        )

    print(f"\n更新后的损失:")
    print("-" * 50)
    print(f"总损失:      {new_loss_info['total_loss']:.6f} "
          f"(变化: {new_loss_info['total_loss'] - loss_info['total_loss']:.6f})")
    print(f"策略损失:    {new_loss_info['policy_loss']:.6f} "
          f"(变化: {new_loss_info['policy_loss'] - loss_info['policy_loss']:.6f})")
    print(f"价值损失:    {new_loss_info['value_loss']:.6f} "
          f"(变化: {new_loss_info['value_loss'] - loss_info['value_loss']:.6f})")

    return ppo_trainer, loss_info, new_loss_info
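
# A minimal sketch (not part of the original demo) of how the single update above would extend to the
# usual PPO inner loop: reuse one batch of rollouts for several epochs of clipped updates.
# The function name and the epochs/lr defaults are illustrative assumptions.
def ppo_inner_loop_sketch(ppo_trainer, inputs, old_log_probs, rewards_tensor, epochs=4, lr=1e-5):
    optimizer = torch.optim.AdamW(
        list(ppo_trainer.policy_model.parameters()) + list(ppo_trainer.value_head.parameters()),
        lr=lr,
    )
    for _ in range(epochs):
        # Recompute the loss each epoch: new_log_probs change while old_log_probs stay fixed.
        loss_info = ppo_trainer.compute_ppo_loss(
            inputs['input_ids'], inputs['attention_mask'], old_log_probs, rewards_tensor
        )
        optimizer.zero_grad()
        loss_info['total_loss'].backward()
        torch.nn.utils.clip_grad_norm_(ppo_trainer.policy_model.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(ppo_trainer.value_head.parameters(), 1.0)
        optimizer.step()
    return loss_info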

def analyze_model_behavior(ppo_trainer, prompts):
    """
    分析模型行为的变化
    """
    print(f"\n" + "=" * 80)
    print("模型行为分析")
    print("=" * 80)

    # Generate fresh continuations with both models
    ppo_trainer.policy_model.eval()

    # Left-padding so generate() appends tokens right after the real prompt tokens
    ppo_trainer.tokenizer.padding_side = "left"
    inputs = ppo_trainer.tokenize_texts(prompts[:2], max_length=40)  # only the first two prompts

    with torch.no_grad():
        # Generate with the reference model
        ref_outputs = ppo_trainer.ref_model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=40,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.8,
            pad_token_id=ppo_trainer.tokenizer.eos_token_id
        )

        # Generate with the updated policy model
        policy_outputs = ppo_trainer.policy_model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=40,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.8,
            pad_token_id=ppo_trainer.tokenizer.eos_token_id
        )

    print("生成文本对比:")
    print("-" * 50)

    for i in range(len(prompts[:2])):
        ref_text = ppo_trainer.tokenizer.decode(ref_outputs[i], skip_special_tokens=True)
        policy_text = ppo_trainer.tokenizer.decode(policy_outputs[i], skip_special_tokens=True)

        print(f"提示 {i+1}: {prompts[i]}")
        print(f"参考模型: {ref_text}")
        print(f"训练模型: {policy_text}")
        print()

if __name__ == "__main__":
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # Run the main demo
    ppo_trainer, loss_info, new_loss_info = run_ppo_demo()

    # Compare model behavior before and after the update
    prompts = [
        "The future of artificial intelligence is",
        "In a world where technology advances rapidly,",
        "The most important skill in the 21st century is",
        "Climate change requires us to"
    ]
    analyze_model_behavior(ppo_trainer, prompts)

    print(f"\n" + "=" * 80)
    print("Demo完成!")
    print("这个demo展示了:")
    print("1. 真实GPT2模型的加载和使用")
    print("2. 实际的文本生成和分词")
    print("3. PPO损失的详细计算过程")
    print("4. 梯度更新和模型优化")
    print("5. 训练前后的模型行为对比")
    print("=" * 80)

    # Note: requires the transformers library (pip install transformers torch).