import argparse
import asyncio
import json
import math
import re
import time
from collections import defaultdict

import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from openai import AsyncOpenAI
from transformers import AutoTokenizer

# Judge prompt for free-form answers, including multiple-choice.
# Filled via str.format with {question}, {response} (the last non-empty line
# of the policy's answer) and {reference} (the label from --answer_path).
PROMPT_critic_updated = '''
Given a problem, determine whether the final answer in the provided (incomplete) solution process matches the reference answer.
The reference answer may be one single option character (e.g., A, B, C, D), a numerical value, an expression, or a list of answers if multiple questions are involved.
**The reference answer may be in Chinese or another language, but your evaluation should be language-agnostic.**

Your task:
- Compare the final output of the solution process with the reference answer.
- If they **match exactly**, output **YES**.
- If they **do not match**, output **NO**.
- If the solution process is unclear, incomplete, or ambiguous, assume it is incorrect and output **NO**.

Your output must be strictly **'YES'** or **'NO'**, with no additional words, punctuation, or explanation.
---

**Question:**
{question}

**Solution Process (Final Step Only):**
{response}

**Reference Answer:**
{reference}

**Output:**
'''

# Every ChatML section between <|im_start|> and <|im_end|>; compiled once
# because it runs for every scored sample.
_IM_SECTION_RE = re.compile(r"<\|im_start\|>(.*?)<\|im_end\|>", re.DOTALL)


def parse_im_sections(text):
    """Parse ChatML ``<|im_start|>role\\ncontent<|im_end|>`` sections.

    Args:
        text: A chat-formatted string.

    Returns:
        A dict mapping each role name to its stripped content. If a role
        occurs more than once, the last occurrence wins. Sections with no
        newline between role and content are reported and skipped.
    """
    parsed = {}
    for section in _IM_SECTION_RE.findall(text):
        role, sep, content = section.partition("\n")
        if not sep:
            print(f"Skipping malformed section: {section}")
            continue
        parsed[role.strip()] = content.strip()
    return parsed


def extract_last_non_empty_line(text, role="assistant"):
    """Return the last non-empty line of the first ``role`` turn in ``text``.

    The turn ends at the first of: ``<|im_end|>``, the next
    ``<|im_start|>``, ``<|endoftext|>``, ``<|eot_id|>``, or end of text.
    ``<|im_end|>`` is a terminator so the end-of-turn marker never leaks into
    the extracted line.

    Returns:
        The line (trailing whitespace removed), or ``""`` if the role turn is
        missing or empty.
    """
    pattern = (
        rf"<\|im_start\|>{re.escape(role)}(.*?)"
        r"(?:<\|im_end\|>|<\|im_start\|>|<\|endoftext\|>|<\|eot_id\|>|$)"
    )
    match = re.search(pattern, text, re.DOTALL)
    if not match:
        return ""
    lines = [line for line in match.group(1).strip().splitlines() if line.strip()]
    return lines[-1] if lines else ""


def reward_normalization(rewards):
    """Z-score normalize a flat list of rewards across the whole batch.

    Returns a list of floats of the same length. Batches with zero variance
    map to all zeros. An empty batch maps to ``[]`` and a singleton batch to
    ``[0.0]``: the sample std is undefined for fewer than two values.
    """
    if len(rewards) <= 1:
        return [0.0] * len(rewards)
    rewards = torch.tensor(rewards, dtype=torch.float64)
    std = rewards.std()
    if std == 0:
        return torch.zeros_like(rewards).tolist()
    return ((rewards - rewards.mean()) / std).tolist()


def strip_sequence(text, pad_token, eos_token):
    """Remove leading and trailing runs of pad/eos tokens from ``text``.

    Tokens that are ``None`` or empty (some tokenizers define no pad token)
    are ignored instead of crashing ``re.escape``.
    """
    escaped = [re.escape(tok) for tok in (eos_token, pad_token) if tok]
    if not escaped:
        return text
    alternation = "|".join(escaped)
    text = re.sub(f"^({alternation})+", "", text)
    return re.sub(f"({alternation})+$", "", text)


def group_reward_normalization(rewards, n_samples_per_prompt=4):
    """Z-score normalize rewards within consecutive groups of samples.

    ``rewards`` is taken to be laid out as consecutive blocks of
    ``n_samples_per_prompt`` samples. Each block is normalized independently.

    Groups with zero variance map to zeros. So do singleton groups: the
    unbiased std of one value is NaN, and would otherwise propagate NaNs
    into the rewards.

    Raises:
        ValueError: if ``n_samples_per_prompt`` is not positive or does not
            evenly divide ``len(rewards)``.
    """
    if n_samples_per_prompt <= 0:
        raise ValueError(f"n_samples_per_prompt must be positive, got {n_samples_per_prompt}")
    if len(rewards) % n_samples_per_prompt:
        raise ValueError(
            f"Got {len(rewards)} rewards, which is not a multiple of "
            f"n_samples_per_prompt={n_samples_per_prompt}"
        )
    rewards = torch.tensor(rewards, dtype=torch.float64).reshape(-1, n_samples_per_prompt)
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    degenerate = (std == 0) | torch.isnan(std)
    normalized = torch.where(degenerate, torch.zeros_like(rewards), (rewards - mean) / std)
    return normalized.flatten().tolist()


class RewardModelProxy:
    """Score policy rollouts by asking an LLM judge whether the final answer matches.

    Each query is a ChatML-formatted rollout. Scoring works as follows:
    - The user turn is looked up in ``qa_dict`` to find the reference answer.
    - The last non-empty line of the assistant turn is sent to the judge with
      ``PROMPT_critic_updated``.
    - The judge's YES/NO first-token probabilities become the reward.
    """

    def __init__(self, args, client=None):
        """Build the proxy from parsed CLI ``args``.

        Args:
            args: Namespace carrying tokenizer_path, answer_path, prob_reward,
                normalize_reward, group_normalize_reward, log_path,
                vllm_model. ``vllm_url`` and ``n_samples_per_prompt`` are
                optional.
            client: Optional ``AsyncOpenAI`` client for the judge server. If
                omitted, one is created from ``args.vllm_url``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, trust_remote_code=True)
        self.normalize_reward = args.normalize_reward
        self.group_normalize_reward = args.group_normalize_reward
        self.n_samples_per_prompt = getattr(args, "n_samples_per_prompt", 4)
        self.qa_dict = defaultdict(str)
        self.load_dict(args.answer_path)
        self.temperature = 0
        # Drop a missing eos token; a None entry would be sent to the API as-is.
        self.stop = [tok for tok in (self.tokenizer.eos_token, "<|im_end|>") if tok]
        self.max_tokens = 1  # only the first token (YES/NO) is needed
        self.prob_reward = args.prob_reward
        self.log_path = args.log_path
        self.vllm_model = args.vllm_model
        if client is None:
            client = AsyncOpenAI(api_key="EMPTY", base_url=getattr(args, "vllm_url", None))
        self.client = client

    def load_dict(self, path):
        """Fill ``self.qa_dict`` (question -> reference label) from a JSON file.

        Expects a JSON list of records with a ``"label"`` field and a
        ``"query"`` message list. The question text is ``query[1]["content"]``
        (presumably index 0 is a system message; confirm against the data
        pipeline). Keys are stripped because ``parse_im_sections`` strips the
        user content they are looked up with.
        """
        with open(path, "r", encoding="utf-8") as file:
            data = json.load(file)
        for unit in data:
            question = unit["query"][1]["content"].strip()
            self.qa_dict[question] = unit["label"]
        if self.qa_dict:
            sample_question, sample_label = next(iter(self.qa_dict.items()))
            print("Sample Question:", sample_question)
            print("Sample Label:", sample_label)
        else:
            print("qa_dict is empty.")

    async def process_sample(self, query):
        """Score one rollout string. Returns a float reward (0.0 when unscorable)."""
        eos = self.tokenizer.eos_token
        query = strip_sequence(query, self.tokenizer.pad_token, eos) + (eos or "")

        question = parse_im_sections(query).get("user")
        if question is None:
            # One bad sample must not fail the whole gathered batch.
            print(f"No user turn found in query, assigning 0.0: {query[:200]}...")
            return 0.0

        answer = extract_last_non_empty_line(query, role="assistant")
        if not answer.strip():
            return 0.0

        # .get() avoids inserting a new defaultdict entry for every unknown question.
        reference = self.qa_dict.get(question, "")
        if question not in self.qa_dict:
            print(f"No reference answer found for question: {question[:200]}...")

        prompt_question = PROMPT_critic_updated.format(
            question=question, reference=reference, response=answer
        )
        return await self.get_reward_from_vllm(prompt_question)

    async def get_reward_from_vllm(self, query):
        """Ask the judge for a YES/NO verdict on ``query`` and turn it into a reward.

        Only the HTTP request is retried: up to ``max_retries`` attempts,
        ``delay`` seconds apart, with no sleep after the last one. A response
        that arrives but cannot be interpreted is not retried. With
        temperature 0 a retry would most likely return the same thing.

        Returns:
            The reward from ``calculate_reward_from_logprobs``, or 0.0 when
            every attempt fails or the response cannot be interpreted.
        """
        max_retries = 10
        delay = 10
        for attempt in range(1, max_retries + 1):
            try:
                response = await self.client.chat.completions.create(
                    model=self.vllm_model,
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": query},
                    ],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                    stop=self.stop,
                    logprobs=True,
                    top_logprobs=10,  # top-10 candidates for the single generated token
                )
            except Exception as e:  # network / server errors are worth retrying
                if attempt == max_retries:
                    print(f"Attempt {attempt} failed: {e}")
                    break
                print(f"Attempt {attempt} failed: {e}, retrying in {delay} seconds...")
                await asyncio.sleep(delay)
                continue

            try:
                return self.calculate_reward_from_logprobs(response)
            except (AttributeError, IndexError, TypeError, ValueError) as e:
                print(f"Could not interpret judge response ({e}), assigning 0.0")
                return 0.0

        print(f"Failed after {max_retries} retries, query content: {query[:200]}...")
        return 0.0  # baseline value on failure

    def calculate_reward_from_logprobs(self, response):
        """Turn the judge's first-token log-probabilities into a reward.

        Sums the probability mass of every top-k candidate that reads
        "yes" / "no" after stripping and lower-casing, so "YES", " Yes" and
        "yes" all count.

        Returns:
            With ``prob_reward``: P(yes) / (P(yes) + P(no)).
            Otherwise: 1.0 if P(yes) > P(no) else 0.0.
            In both modes, 0.0 when no logprobs came back or neither label
            appears.
        """
        choices = getattr(response, "choices", None) or []
        logprobs = choices[0].logprobs if choices else None
        content = getattr(logprobs, "content", None) or []
        if not content:
            return 0.0

        yes_prob = 0.0
        no_prob = 0.0
        for candidate in content[0].top_logprobs or []:
            label = candidate.token.strip().lower()
            if label == "yes":
                yes_prob += math.exp(candidate.logprob)
            elif label == "no":
                no_prob += math.exp(candidate.logprob)

        total = yes_prob + no_prob
        if total == 0:
            return 0.0  # no valid judgment
        if self.prob_reward:
            return yes_prob / total
        return 1.0 if yes_prob > no_prob else 0.0  # hard judgment mode

    async def get_reward(self, queries):
        """Score a batch of rollout strings concurrently.

        Optionally appends the raw scores as one JSON line to ``log_path``.
        Applies batch or group normalization when enabled.

        Returns:
            A list of floats, one per query (``[]`` for an empty batch).
        """
        if not queries:
            return []
        print("Processing queries[0]: {}".format(queries[0]))
        scores = list(await asyncio.gather(*(self.process_sample(q) for q in queries)))
        print("Generated scores: {}".format(scores))

        if self.log_path:
            with open(self.log_path, "a", encoding="utf-8") as f:
                unit = {
                    "query_list": queries if isinstance(queries, list) else [],
                    "hard_score_list": scores,
                }
                json.dump(unit, f, ensure_ascii=False)
                f.write("\n")

        if self.normalize_reward:
            return reward_normalization(scores)
        if self.group_normalize_reward:
            return group_reward_normalization(scores, self.n_samples_per_prompt)
        return scores


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Reward Model
    parser.add_argument("--tokenizer_path", type=str, default=None)
    parser.add_argument("--answer_path", type=str, default=None)
    parser.add_argument("--prob_reward", action="store_true", default=False)
    parser.add_argument("--normalize_reward", action="store_true", default=False, help="Enable Reward Normalization")
    parser.add_argument("--group_normalize_reward", action="store_true", default=False, help="Enable Group Reward Normalization")
    parser.add_argument(
        "--n_samples_per_prompt",
        type=int,
        default=4,
        help="Group size used by --group_normalize_reward",
    )
    parser.add_argument("--port", type=int, default=5000, help="Port number for the server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="IP for the server")
    parser.add_argument("--log_path", type=str, default=None)
    parser.add_argument("--vllm_url", type=str, default=None)
    parser.add_argument("--vllm_model", type=str, default=None)
    args = parser.parse_args()

    # Set OpenAI's API key and API base to use vLLM's API server.
    openai_api_key = "EMPTY"
    openai_api_base = args.vllm_url
    client = AsyncOpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    # Server setup
    reward_model = RewardModelProxy(args, client=client)
    app = FastAPI()

    @app.post("/get_reward")
    async def get_reward(request: Request):
        data = await request.json()
        queries = data.get("query") if isinstance(data, dict) else None
        if not isinstance(queries, list):
            return JSONResponse({"error": "'query' must be a list of strings"}, status_code=400)
        rewards = await reward_model.get_reward(queries)
        result = {"rewards": rewards}
        print(f"Sent JSON response: {result}")
        return JSONResponse(result)

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")