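"""Preprocess agent evaluation traces into a local SQLite database.

Reads per-run JSON files (raw logs, failure reports, and parsed metric
results) from a directory, stores them in `preprocessed_traces.db`, and
exposes cached accessors plus a leaderboard-style DataFrame of parsed
results.
"""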
import json
import pickle
import sqlite3
import threading
from functools import lru_cache
from pathlib import Path

import pandas as pd

class TracePreprocessor:
    def __init__(self, db_path='preprocessed_traces.db'):
        self.db_path = db_path
        # sqlite3 connections cannot be shared across threads, so each
        # thread lazily opens (and reuses) its own connection.
        self.local = threading.local()
        self.create_tables()

    def get_conn(self):
        if not hasattr(self.local, 'conn'):
            self.local.conn = sqlite3.connect(self.db_path)
        return self.local.conn
    def create_tables(self):
        with self.get_conn() as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS preprocessed_traces (
                    benchmark_name TEXT,
                    agent_name TEXT,
                    raw_logging_results BLOB,
                    PRIMARY KEY (benchmark_name, agent_name)
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS failure_reports (
                    benchmark_name TEXT,
                    agent_name TEXT,
                    failure_report BLOB,
                    PRIMARY KEY (benchmark_name, agent_name)
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS parsed_results (
                    benchmark_name TEXT,
                    agent_name TEXT,
                    date TEXT,
                    total_cost REAL,
                    accuracy REAL,
                    precision REAL,
                    recall REAL,
                    f1_score REAL,
                    auc REAL,
                    overall_score REAL,
                    vectorization_score REAL,
                    fathomnet_score REAL,
                    feedback_score REAL,
                    house_price_score REAL,
                    spaceship_titanic_score REAL,
                    amp_parkinsons_disease_progression_prediction_score REAL,
                    cifar10_score REAL,
                    imdb_score REAL,
                    PRIMARY KEY (benchmark_name, agent_name)
                )
            ''')
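
    # Expected shape of each eval JSON file, inferred from the parsing in
    # preprocess_traces below (a sketch; only the keys referenced there are
    # grounded in this file, the rest of the schema is an assumption):
    #
    #   {
    #       "config": {"agent_name": ..., "benchmark_name": ..., "date": ...},
    #       "results": {"total_cost": ..., "accuracy": ..., "overall_score": ..., ...},
    #       "raw_logging_results": <arbitrary structure, pickled as-is>,
    #       "failure_report": <arbitrary structure, pickled as-is>
    #   }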
    def preprocess_traces(self, processed_dir="evals_live"):
        processed_dir = Path(processed_dir)
        for file in processed_dir.glob('*.json'):
            with open(file, 'r') as f:
                data = json.load(f)
            agent_name = data['config']['agent_name']
            benchmark_name = data['config']['benchmark_name']

            # Raw logs and failure reports are stored as pickled blobs; each
            # step gets its own try/except so one malformed field does not
            # abort processing of the rest of the file.
            try:
                raw_logging_results = pickle.dumps(data['raw_logging_results'])
                with self.get_conn() as conn:
                    conn.execute('''
                        INSERT OR REPLACE INTO preprocessed_traces
                        (benchmark_name, agent_name, raw_logging_results)
                        VALUES (?, ?, ?)
                    ''', (benchmark_name, agent_name, raw_logging_results))
            except Exception as e:
                print(f"Error preprocessing raw_logging_results in {file}: {e}")

            try:
                failure_report = pickle.dumps(data['failure_report'])
                with self.get_conn() as conn:
                    conn.execute('''
                        INSERT OR REPLACE INTO failure_reports
                        (benchmark_name, agent_name, failure_report)
                        VALUES (?, ?, ?)
                    ''', (benchmark_name, agent_name, failure_report))
            except Exception as e:
                print(f"Error preprocessing failure_report in {file}: {e}")

            try:
                config = data['config']
                results = data['results']
                with self.get_conn() as conn:
                    conn.execute('''
                        INSERT OR REPLACE INTO parsed_results
                        (benchmark_name, agent_name, date, total_cost, accuracy,
                         precision, recall, f1_score, auc, overall_score,
                         vectorization_score, fathomnet_score, feedback_score,
                         house_price_score, spaceship_titanic_score,
                         amp_parkinsons_disease_progression_prediction_score,
                         cifar10_score, imdb_score)
                        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    ''', (
                        benchmark_name,
                        agent_name,
                        config['date'],
                        results.get('total_cost'),
                        results.get('accuracy'),
                        results.get('precision'),
                        results.get('recall'),
                        results.get('f1_score'),
                        results.get('auc'),
                        results.get('overall_score'),
                        results.get('vectorization_score'),
                        results.get('fathomnet_score'),
                        results.get('feedback_score'),
                        # Hyphenated keys in the source JSON map to the
                        # underscored SQL column names above.
                        results.get('house-price_score'),
                        results.get('spaceship-titanic_score'),
                        results.get('amp-parkinsons-disease-progression-prediction_score'),
                        results.get('cifar10_score'),
                        results.get('imdb_score')
                    ))
            except Exception as e:
                print(f"Error preprocessing parsed results in {file}: {e}")
    @lru_cache(maxsize=100)
    def get_analyzed_traces(self, agent_name, benchmark_name):
        # Cached so repeated lookups skip re-querying and re-unpickling.
        with self.get_conn() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                SELECT raw_logging_results FROM preprocessed_traces
                WHERE benchmark_name = ? AND agent_name = ?
            ''', (benchmark_name, agent_name))
            result = cursor.fetchone()
            if result:
                return pickle.loads(result[0])
            return None

    @lru_cache(maxsize=100)
    def get_failure_report(self, agent_name, benchmark_name):
        with self.get_conn() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                SELECT failure_report FROM failure_reports
                WHERE benchmark_name = ? AND agent_name = ?
            ''', (benchmark_name, agent_name))
            result = cursor.fetchone()
            if result:
                return pickle.loads(result[0])
            return None
    def get_parsed_results(self, benchmark_name):
        with self.get_conn() as conn:
            query = '''
                SELECT * FROM parsed_results
                WHERE benchmark_name = ?
                ORDER BY accuracy DESC
            '''
            df = pd.read_sql_query(query, conn, params=(benchmark_name,))

            # Round float columns to 3 decimal places. These names must match
            # the underscored SQL column names (not the hyphenated JSON keys),
            # or the rounding silently skips those columns.
            float_columns = [
                'total_cost', 'accuracy', 'precision', 'recall', 'f1_score',
                'auc', 'overall_score', 'vectorization_score',
                'fathomnet_score', 'feedback_score', 'house_price_score',
                'spaceship_titanic_score',
                'amp_parkinsons_disease_progression_prediction_score',
                'cifar10_score', 'imdb_score'
            ]
            for column in float_columns:
                if column in df.columns:
                    df[column] = df[column].round(3)

            # Rename columns to display-friendly names.
            df = df.rename(columns={
                'agent_name': 'Agent Name',
                'date': 'Date',
                'total_cost': 'Total Cost',
                'accuracy': 'Accuracy',
                'precision': 'Precision',
                'recall': 'Recall',
                'f1_score': 'F1 Score',
                'auc': 'AUC',
                'overall_score': 'Overall Score',
                'vectorization_score': 'Vectorization Score',
                'fathomnet_score': 'Fathomnet Score',
                'feedback_score': 'Feedback Score',
                'house_price_score': 'House Price Score',
                'spaceship_titanic_score': 'Spaceship Titanic Score',
                'amp_parkinsons_disease_progression_prediction_score': 'AMP Parkinsons Disease Progression Prediction Score',
                'cifar10_score': 'CIFAR10 Score',
                'imdb_score': 'IMDB Score'
            })

        return df

if __name__ == '__main__':
    preprocessor = TracePreprocessor()
    preprocessor.preprocess_traces()
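
    # A minimal usage sketch for reading results back (the benchmark name
    # "mlagentbench" and the column selection are illustrative assumptions,
    # not values fixed by this file):
    #
    #   df = preprocessor.get_parsed_results("mlagentbench")
    #   print(df[['Agent Name', 'Accuracy', 'Total Cost']].head())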