Spaces:
Running
Running
Added Selene / Selene Mini API
Browse files- gen_api_answer.py +50 -43
gen_api_answer.py
CHANGED
@@ -15,6 +15,7 @@ from prompts import (
|
|
15 |
FLOW_JUDGE_PROMPT
|
16 |
)
|
17 |
from transformers import AutoTokenizer
|
|
|
18 |
|
19 |
# Initialize clients
|
20 |
anthropic_client = anthropic.Anthropic()
|
@@ -24,6 +25,10 @@ hf_api_key = os.getenv("HF_API_KEY")
|
|
24 |
flow_judge_api_key = os.getenv("FLOW_JUDGE_API_KEY")
|
25 |
cohere_client = cohere.ClientV2(os.getenv("CO_API_KEY"))
|
26 |
salesforce_api_key = os.getenv("SALESFORCE_API_KEY")
|
|
|
|
|
|
|
|
|
27 |
def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
|
28 |
"""Get response from OpenAI API"""
|
29 |
try:
|
@@ -110,42 +115,33 @@ def get_prometheus_response(model_name, prompt, system_prompt=None, max_tokens=5
|
|
110 |
return f"Error with Hugging Face model {model_name}: {str(e)}"
|
111 |
|
112 |
def get_atla_response(model_name, prompt, system_prompt=None, max_tokens=500, temperature=0.01):
|
113 |
-
"""Get response from
|
114 |
try:
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
#
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
|
|
|
|
|
|
131 |
|
132 |
-
|
133 |
-
|
134 |
-
"
|
135 |
-
|
136 |
-
"return_full_text": False,
|
137 |
-
"temperature": temperature,
|
138 |
-
"seed": 42,
|
139 |
-
"add_generation_prompt": True
|
140 |
-
}
|
141 |
}
|
142 |
-
|
143 |
-
response = requests.post(
|
144 |
-
"https://bkp9p28gri93egqh.us-east-1.aws.endpoints.huggingface.cloud",
|
145 |
-
headers=headers,
|
146 |
-
json=payload
|
147 |
-
)
|
148 |
-
return response.json()[0]["generated_text"]
|
149 |
except Exception as e:
|
150 |
return f"Error with Atla model {model_name}: {str(e)}"
|
151 |
|
@@ -321,9 +317,16 @@ def get_model_response(
|
|
321 |
api_model, final_prompt, system_prompt, max_tokens, temperature = 0.01
|
322 |
)
|
323 |
elif organization == "Atla":
|
324 |
-
|
325 |
-
api_model,
|
326 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
327 |
elif organization == "Cohere":
|
328 |
return get_cohere_response(
|
329 |
api_model, final_prompt, system_prompt, max_tokens, temperature
|
@@ -350,6 +353,10 @@ def parse_model_response(response):
|
|
350 |
# Debug print
|
351 |
print(f"Raw model response: {response}")
|
352 |
|
|
|
|
|
|
|
|
|
353 |
# If response is already a dictionary, use it directly
|
354 |
if isinstance(response, dict):
|
355 |
return str(response.get("result", "N/A")), response.get("feedback", "N/A")
|
@@ -359,10 +366,10 @@ def parse_model_response(response):
|
|
359 |
data = json.loads(response)
|
360 |
return str(data.get("result", "N/A")), data.get("feedback", "N/A")
|
361 |
except json.JSONDecodeError:
|
362 |
-
# If that fails, check if this is a Salesforce response
|
363 |
if "**Reasoning:**" in response or "**Result:**" in response:
|
364 |
-
# Use ATLA parser for Salesforce responses
|
365 |
-
return
|
366 |
|
367 |
# Otherwise try to find JSON within the response
|
368 |
json_match = re.search(r"{.*}", response, re.DOTALL)
|
@@ -443,10 +450,10 @@ def prometheus_parse_model_response(output):
|
|
443 |
print(f"Failed to parse response: {str(e)}")
|
444 |
return "Error", f"Exception during parsing: {str(e)}"
|
445 |
|
446 |
-
def
|
447 |
-
"""Parse response from
|
448 |
try:
|
449 |
-
print(f"Raw
|
450 |
output = output.strip()
|
451 |
|
452 |
# Look for the Reasoning and Result sections
|
@@ -458,10 +465,10 @@ def atla_parse_model_response(output):
|
|
458 |
score = result_match.group(1)
|
459 |
return str(score), feedback
|
460 |
|
461 |
-
return "Error", f"Failed to parse
|
462 |
|
463 |
except Exception as e:
|
464 |
-
print(f"Failed to parse
|
465 |
return "Error", f"Exception during parsing: {str(e)}"
|
466 |
|
467 |
def flow_judge_parse_model_response(output):
|
|
|
15 |
FLOW_JUDGE_PROMPT
|
16 |
)
|
17 |
from transformers import AutoTokenizer
|
18 |
+
from atla import Atla
|
19 |
|
20 |
# Initialize clients
|
21 |
anthropic_client = anthropic.Anthropic()
|
|
|
25 |
flow_judge_api_key = os.getenv("FLOW_JUDGE_API_KEY")
|
26 |
cohere_client = cohere.ClientV2(os.getenv("CO_API_KEY"))
|
27 |
salesforce_api_key = os.getenv("SALESFORCE_API_KEY")
|
28 |
+
|
29 |
+
# Initialize Atla client
|
30 |
+
atla_client = Atla()
|
31 |
+
|
32 |
def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
|
33 |
"""Get response from OpenAI API"""
|
34 |
try:
|
|
|
115 |
return f"Error with Hugging Face model {model_name}: {str(e)}"
|
116 |
|
117 |
def get_atla_response(model_name, prompt, system_prompt=None, max_tokens=500, temperature=0.01):
|
118 |
+
"""Get response from Atla API"""
|
119 |
try:
|
120 |
+
# Extract components from the prompt data
|
121 |
+
model_input = prompt.get('human_input', '')
|
122 |
+
model_output = prompt.get('ai_response', '')
|
123 |
+
expected_output = prompt.get('ground_truth_input', '')
|
124 |
+
evaluation_criteria = prompt.get('eval_criteria', '')
|
125 |
+
|
126 |
+
# Set model_id based on the model name
|
127 |
+
if "Mini" in model_name:
|
128 |
+
model_id = "atla-selene-mini"
|
129 |
+
else:
|
130 |
+
model_id = "atla-selene"
|
131 |
+
|
132 |
+
response = atla_client.evaluation.create(
|
133 |
+
model_id=model_id,
|
134 |
+
model_input=model_input,
|
135 |
+
model_output=model_output,
|
136 |
+
expected_model_output=expected_output if expected_output else None,
|
137 |
+
evaluation_criteria=evaluation_criteria,
|
138 |
+
)
|
139 |
|
140 |
+
# Return the score and critique directly
|
141 |
+
return {
|
142 |
+
"score": response.result.evaluation.score,
|
143 |
+
"critique": response.result.evaluation.critique
|
|
|
|
|
|
|
|
|
|
|
144 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
except Exception as e:
|
146 |
return f"Error with Atla model {model_name}: {str(e)}"
|
147 |
|
|
|
317 |
api_model, final_prompt, system_prompt, max_tokens, temperature = 0.01
|
318 |
)
|
319 |
elif organization == "Atla":
|
320 |
+
response = get_atla_response(
|
321 |
+
api_model, prompt_data, system_prompt, max_tokens, temperature
|
322 |
)
|
323 |
+
# Response now contains score and critique directly
|
324 |
+
if isinstance(response, dict) and 'score' in response and 'critique' in response:
|
325 |
+
score = str(response['score'])
|
326 |
+
critique = response['critique']
|
327 |
+
return score, critique
|
328 |
+
else:
|
329 |
+
return "Error", str(response)
|
330 |
elif organization == "Cohere":
|
331 |
return get_cohere_response(
|
332 |
api_model, final_prompt, system_prompt, max_tokens, temperature
|
|
|
353 |
# Debug print
|
354 |
print(f"Raw model response: {response}")
|
355 |
|
356 |
+
# If response is already a tuple (from Atla/Salesforce), use it directly
|
357 |
+
if isinstance(response, tuple):
|
358 |
+
return response
|
359 |
+
|
360 |
# If response is already a dictionary, use it directly
|
361 |
if isinstance(response, dict):
|
362 |
return str(response.get("result", "N/A")), response.get("feedback", "N/A")
|
|
|
366 |
data = json.loads(response)
|
367 |
return str(data.get("result", "N/A")), data.get("feedback", "N/A")
|
368 |
except json.JSONDecodeError:
|
369 |
+
# If that fails, check if this is a Salesforce response
|
370 |
if "**Reasoning:**" in response or "**Result:**" in response:
|
371 |
+
# Use ATLA parser for Salesforce responses only
|
372 |
+
return salesforce_parse_model_response(response)
|
373 |
|
374 |
# Otherwise try to find JSON within the response
|
375 |
json_match = re.search(r"{.*}", response, re.DOTALL)
|
|
|
450 |
print(f"Failed to parse response: {str(e)}")
|
451 |
return "Error", f"Exception during parsing: {str(e)}"
|
452 |
|
453 |
+
def salesforce_parse_model_response(output):
|
454 |
+
"""Parse response from Salesforce model"""
|
455 |
try:
|
456 |
+
print(f"Raw Salesforce model response: {output}")
|
457 |
output = output.strip()
|
458 |
|
459 |
# Look for the Reasoning and Result sections
|
|
|
465 |
score = result_match.group(1)
|
466 |
return str(score), feedback
|
467 |
|
468 |
+
return "Error", f"Failed to parse Salesforce response format: {output}"
|
469 |
|
470 |
except Exception as e:
|
471 |
+
print(f"Failed to parse Salesforce response: {str(e)}")
|
472 |
return "Error", f"Exception during parsing: {str(e)}"
|
473 |
|
474 |
def flow_judge_parse_model_response(output):
|