kaikaidai committed on
Commit
c82e0d6
·
verified ·
1 Parent(s): 6ef585c

Added Selene / Selene Mini API

Browse files
Files changed (1) hide show
  1. gen_api_answer.py +50 -43
gen_api_answer.py CHANGED
@@ -15,6 +15,7 @@ from prompts import (
15
  FLOW_JUDGE_PROMPT
16
  )
17
  from transformers import AutoTokenizer
 
18
 
19
  # Initialize clients
20
  anthropic_client = anthropic.Anthropic()
@@ -24,6 +25,10 @@ hf_api_key = os.getenv("HF_API_KEY")
24
  flow_judge_api_key = os.getenv("FLOW_JUDGE_API_KEY")
25
  cohere_client = cohere.ClientV2(os.getenv("CO_API_KEY"))
26
  salesforce_api_key = os.getenv("SALESFORCE_API_KEY")
 
 
 
 
27
  def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
28
  """Get response from OpenAI API"""
29
  try:
@@ -110,42 +115,33 @@ def get_prometheus_response(model_name, prompt, system_prompt=None, max_tokens=5
110
  return f"Error with Hugging Face model {model_name}: {str(e)}"
111
 
112
  def get_atla_response(model_name, prompt, system_prompt=None, max_tokens=500, temperature=0.01):
113
- """Get response from HF endpoint for Atla model"""
114
  try:
115
- headers = {
116
- "Accept": "application/json",
117
- "Authorization": f"Bearer {hf_api_key}",
118
- "Content-Type": "application/json"
119
- }
120
-
121
- # Create messages list for chat template
122
- messages = []
123
- if system_prompt:
124
- messages.append({"role": "system", "content": system_prompt})
125
- messages.append({"role": "user", "content": prompt})
126
-
127
- # Apply chat template
128
- model_id = "AtlaAI/Selene-1-Mini-Llama-3.1-8B"
129
- tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_key)
130
- formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
 
131
 
132
- payload = {
133
- "inputs": formatted_prompt,
134
- "parameters": {
135
- "max_new_tokens": max_tokens,
136
- "return_full_text": False,
137
- "temperature": temperature,
138
- "seed": 42,
139
- "add_generation_prompt": True
140
- }
141
  }
142
-
143
- response = requests.post(
144
- "https://bkp9p28gri93egqh.us-east-1.aws.endpoints.huggingface.cloud",
145
- headers=headers,
146
- json=payload
147
- )
148
- return response.json()[0]["generated_text"]
149
  except Exception as e:
150
  return f"Error with Atla model {model_name}: {str(e)}"
151
 
@@ -321,9 +317,16 @@ def get_model_response(
321
  api_model, final_prompt, system_prompt, max_tokens, temperature = 0.01
322
  )
323
  elif organization == "Atla":
324
- return get_atla_response(
325
- api_model, final_prompt, system_prompt, max_tokens, temperature = 0.01
326
  )
 
 
 
 
 
 
 
327
  elif organization == "Cohere":
328
  return get_cohere_response(
329
  api_model, final_prompt, system_prompt, max_tokens, temperature
@@ -350,6 +353,10 @@ def parse_model_response(response):
350
  # Debug print
351
  print(f"Raw model response: {response}")
352
 
 
 
 
 
353
  # If response is already a dictionary, use it directly
354
  if isinstance(response, dict):
355
  return str(response.get("result", "N/A")), response.get("feedback", "N/A")
@@ -359,10 +366,10 @@ def parse_model_response(response):
359
  data = json.loads(response)
360
  return str(data.get("result", "N/A")), data.get("feedback", "N/A")
361
  except json.JSONDecodeError:
362
- # If that fails, check if this is a Salesforce response (which uses ATLA format)
363
  if "**Reasoning:**" in response or "**Result:**" in response:
364
- # Use ATLA parser for Salesforce responses
365
- return atla_parse_model_response(response)
366
 
367
  # Otherwise try to find JSON within the response
368
  json_match = re.search(r"{.*}", response, re.DOTALL)
@@ -443,10 +450,10 @@ def prometheus_parse_model_response(output):
443
  print(f"Failed to parse response: {str(e)}")
444
  return "Error", f"Exception during parsing: {str(e)}"
445
 
446
- def atla_parse_model_response(output):
447
- """Parse response from ATLA model"""
448
  try:
449
- print(f"Raw Atla model response: {output}")
450
  output = output.strip()
451
 
452
  # Look for the Reasoning and Result sections
@@ -458,10 +465,10 @@ def atla_parse_model_response(output):
458
  score = result_match.group(1)
459
  return str(score), feedback
460
 
461
- return "Error", f"Failed to parse ATLA response format: {output}"
462
 
463
  except Exception as e:
464
- print(f"Failed to parse ATLA response: {str(e)}")
465
  return "Error", f"Exception during parsing: {str(e)}"
466
 
467
  def flow_judge_parse_model_response(output):
 
15
  FLOW_JUDGE_PROMPT
16
  )
17
  from transformers import AutoTokenizer
18
+ from atla import Atla
19
 
20
  # Initialize clients
21
  anthropic_client = anthropic.Anthropic()
 
25
  flow_judge_api_key = os.getenv("FLOW_JUDGE_API_KEY")
26
  cohere_client = cohere.ClientV2(os.getenv("CO_API_KEY"))
27
  salesforce_api_key = os.getenv("SALESFORCE_API_KEY")
28
+
29
+ # Initialize Atla client
30
+ atla_client = Atla()
31
+
32
  def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
33
  """Get response from OpenAI API"""
34
  try:
 
115
  return f"Error with Hugging Face model {model_name}: {str(e)}"
116
 
117
  def get_atla_response(model_name, prompt, system_prompt=None, max_tokens=500, temperature=0.01):
118
+ """Get response from Atla API"""
119
  try:
120
+ # Extract components from the prompt data
121
+ model_input = prompt.get('human_input', '')
122
+ model_output = prompt.get('ai_response', '')
123
+ expected_output = prompt.get('ground_truth_input', '')
124
+ evaluation_criteria = prompt.get('eval_criteria', '')
125
+
126
+ # Set model_id based on the model name
127
+ if "Mini" in model_name:
128
+ model_id = "atla-selene-mini"
129
+ else:
130
+ model_id = "atla-selene"
131
+
132
+ response = atla_client.evaluation.create(
133
+ model_id=model_id,
134
+ model_input=model_input,
135
+ model_output=model_output,
136
+ expected_model_output=expected_output if expected_output else None,
137
+ evaluation_criteria=evaluation_criteria,
138
+ )
139
 
140
+ # Return the score and critique directly
141
+ return {
142
+ "score": response.result.evaluation.score,
143
+ "critique": response.result.evaluation.critique
 
 
 
 
 
144
  }
 
 
 
 
 
 
 
145
  except Exception as e:
146
  return f"Error with Atla model {model_name}: {str(e)}"
147
 
 
317
  api_model, final_prompt, system_prompt, max_tokens, temperature = 0.01
318
  )
319
  elif organization == "Atla":
320
+ response = get_atla_response(
321
+ api_model, prompt_data, system_prompt, max_tokens, temperature
322
  )
323
+ # Response now contains score and critique directly
324
+ if isinstance(response, dict) and 'score' in response and 'critique' in response:
325
+ score = str(response['score'])
326
+ critique = response['critique']
327
+ return score, critique
328
+ else:
329
+ return "Error", str(response)
330
  elif organization == "Cohere":
331
  return get_cohere_response(
332
  api_model, final_prompt, system_prompt, max_tokens, temperature
 
353
  # Debug print
354
  print(f"Raw model response: {response}")
355
 
356
+ # If response is already a tuple (from Atla/Salesforce), use it directly
357
+ if isinstance(response, tuple):
358
+ return response
359
+
360
  # If response is already a dictionary, use it directly
361
  if isinstance(response, dict):
362
  return str(response.get("result", "N/A")), response.get("feedback", "N/A")
 
366
  data = json.loads(response)
367
  return str(data.get("result", "N/A")), data.get("feedback", "N/A")
368
  except json.JSONDecodeError:
369
+ # If that fails, check if this is a Salesforce response
370
  if "**Reasoning:**" in response or "**Result:**" in response:
371
+ # Use ATLA parser for Salesforce responses only
372
+ return salesforce_parse_model_response(response)
373
 
374
  # Otherwise try to find JSON within the response
375
  json_match = re.search(r"{.*}", response, re.DOTALL)
 
450
  print(f"Failed to parse response: {str(e)}")
451
  return "Error", f"Exception during parsing: {str(e)}"
452
 
453
+ def salesforce_parse_model_response(output):
454
+ """Parse response from Salesforce model"""
455
  try:
456
+ print(f"Raw Salesforce model response: {output}")
457
  output = output.strip()
458
 
459
  # Look for the Reasoning and Result sections
 
465
  score = result_match.group(1)
466
  return str(score), feedback
467
 
468
+ return "Error", f"Failed to parse Salesforce response format: {output}"
469
 
470
  except Exception as e:
471
+ print(f"Failed to parse Salesforce response: {str(e)}")
472
  return "Error", f"Exception during parsing: {str(e)}"
473
 
474
  def flow_judge_parse_model_response(output):