# lolback / app.py
# roshcheeku's picture
# Update app.py
# 694bea5 verified
import os
import requests
import joblib
import logging
import zipfile
import pandas as pd
import numpy as np
import warnings
from flask import Flask, request, jsonify
from flask_cors import CORS
# Suppress UserWarning messages raised from sklearn modules.
warnings.filterwarnings('ignore', category=UserWarning, module='sklearn')
# Logging setup: the root logger emits INFO and above; the functions in this
# file log through the root logger directly.
logging.basicConfig(level=logging.INFO)
# Download URL for each artifact, read from environment variables.
# os.getenv returns None for an unset variable; download_model() treats a
# falsy URL as "missing", logs an error and returns False.
MODEL_URLS = {
    "DIABETES_MODEL": os.getenv("DIABETES_MODEL_URL"),
    "SCALER": os.getenv("SCALER_URL"),
    "MULTI_MODEL": os.getenv("MULTI_MODEL_URL"),
}
# File name each downloaded ZIP archive is stored under inside TMP_DIR.
MODEL_PATHS = {
    "DIABETES_MODEL": "finaliseddiabetes_model.zip",
    "SCALER": "finalisedscaler.zip",
    "MULTI_MODEL": "nodiabetes.zip",
}
# File name expected inside TMP_DIR after the matching archive is extracted.
# The three tables share the same keys.
EXTRACTED_MODELS = {
    "DIABETES_MODEL": "finaliseddiabetes_model.joblib",
    "SCALER": "finalisedscaler.joblib",
    "MULTI_MODEL": "nodiabetes.joblib",
}
# Directory used for both the downloaded archives and the extracted files.
TMP_DIR = "/tmp"
app = Flask(__name__)
# Enable CORS for the app with credentialed requests allowed.
# NOTE(review): no origin restriction is passed here -- confirm that allowing
# credentials without an explicit origin list is intended.
CORS(app, supports_credentials=True)
@app.route('/')
def index():
    """Serve the static HTML landing page for the root URL."""
    landing_page = """
<h1>Welcome to the Diabetes Health Predictor API 👋</h1>
<p>This Hugging Face Space provides health risk predictions including diabetes, hypertension, stroke, and cardiovascular conditions.</p>
<p>Use the <code>/predict</code> endpoint via POST request to get started with your health insights!</p>
"""
    return landing_page
def download_model(url, zip_filename, timeout=(10, 120)):
    """Download a model archive into TMP_DIR.

    Args:
        url: HTTP(S) location of the archive. A falsy value means the URL was
            not configured.
        zip_filename: File name to store the archive under inside TMP_DIR.
        timeout: ``(connect, read)`` timeout in seconds passed to
            ``requests.get``.

    Returns:
        True when the archive was completely written to disk; False when the
        URL is missing, the server did not answer 200, or any error occurred
        while downloading or writing.
    """
    zip_path = os.path.join(TMP_DIR, zip_filename)
    if not url:
        logging.error(f"URL for {zip_filename} is missing!")
        return False

    # Download into a sibling temp file and rename at the end, so that an
    # interrupted transfer never leaves a truncated archive at zip_path
    # (start-up skips the download whenever zip_path already exists).
    partial_path = zip_path + ".part"
    try:
        # stream=True avoids buffering the whole archive in memory; the
        # timeout keeps a stalled server from blocking start-up forever.
        with requests.get(url, allow_redirects=True, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                logging.error(f"Failed to download {zip_filename}. HTTP Status: {response.status_code}")
                return False
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        os.replace(partial_path, zip_path)
        logging.info(f"Downloaded {zip_filename} successfully.")
        return True
    except Exception as e:
        # Deliberately broad: a failed download must not abort start-up.
        logging.error(f"Error downloading {zip_filename}: {e}")
        return False
    finally:
        # Best-effort cleanup of a leftover partial file.
        if os.path.exists(partial_path):
            try:
                os.remove(partial_path)
            except OSError as cleanup_error:
                logging.warning(f"Could not remove partial download {partial_path}: {cleanup_error}")
def extract_if_needed(zip_filename, extracted_filename):
    """Extract an archive in TMP_DIR unless its expected output already exists.

    Args:
        zip_filename: Archive file name inside TMP_DIR.
        extracted_filename: File name that must exist inside TMP_DIR once the
            archive has been extracted.

    Returns:
        True when ``extracted_filename`` is present in TMP_DIR (already there
        or just extracted); False when the archive is missing, cannot be
        extracted, or does not produce the expected file.
    """
    zip_path = os.path.join(TMP_DIR, zip_filename)
    extracted_path = os.path.join(TMP_DIR, extracted_filename)

    if os.path.exists(extracted_path):
        logging.info(f"{extracted_filename} already exists. Skipping extraction.")
        return True
    if not os.path.exists(zip_path):
        logging.error(f"Zip file missing: {zip_path}")
        return False

    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(TMP_DIR)
    except zipfile.BadZipFile as e:
        logging.error(f"Error extracting {zip_filename}: {e}")
        # A corrupt archive left in place would be treated as "already
        # downloaded" on every later start-up, so remove it to allow a retry.
        try:
            os.remove(zip_path)
        except OSError as remove_error:
            logging.warning(f"Could not remove corrupt archive {zip_path}: {remove_error}")
        return False
    except Exception as e:
        # Deliberately broad: extraction problems must not abort start-up.
        logging.error(f"Error extracting {zip_filename}: {e}")
        return False

    # Extraction succeeding does not prove the archive contained the file the
    # caller is about to load.
    if not os.path.exists(extracted_path):
        logging.error(f"{extracted_filename} not found after extracting {zip_filename}")
        return False

    logging.info(f"Extracted {zip_filename} to {TMP_DIR}")
    return True
def load_model(model_filename):
    """Load a joblib artifact from TMP_DIR.

    Args:
        model_filename: File name of the artifact inside TMP_DIR.

    Returns:
        The object returned by ``joblib.load``, or None when the file does not
        exist or loading raises.
    """
    location = os.path.join(TMP_DIR, model_filename)
    if os.path.exists(location):
        # NOTE(review): joblib.load unpickles the file, and this file's
        # archives come from URLs configured via the environment -- confirm
        # those sources are trusted.
        try:
            loaded = joblib.load(location)
            logging.info(f"Loaded {model_filename} successfully.")
            return loaded
        except Exception as err:
            logging.error(f"Error loading {model_filename}: {err}")
            return None
    logging.error(f"Model file not found: {location}")
    return None
def initialize_models():
    """Download, extract and load every artifact listed in MODEL_PATHS.

    Returns:
        A dict keyed like MODEL_PATHS. A key is absent when its download or
        extraction failed; its value is whatever ``load_model`` returned
        (None when loading failed).
    """
    loaded = {}
    for key in MODEL_PATHS:
        archive_name = MODEL_PATHS[key]
        artifact_name = EXTRACTED_MODELS[key]
        source_url = MODEL_URLS.get(key)

        # Only hit the network when the archive is not already on disk.
        archive_missing = not os.path.exists(os.path.join(TMP_DIR, archive_name))
        if archive_missing and not download_model(source_url, archive_name):
            continue

        if extract_if_needed(archive_name, artifact_name):
            loaded[key] = load_model(artifact_name)
    return loaded
# Load all artifacts once, at import time. A key is missing from this dict
# when its download or extraction failed, and maps to None when loading failed.
models = initialize_models()
# Column order applied to the diabetes feature frame before it is scaled.
FEATURE_ORDER = [
    'Pregnancies', 'Glucose', 'BloodPressure', 'Insulin',
    'BMI', 'DiabetesPedigreeFunction', 'Age'
]
def validate_input(value, input_type=float, min_value=0, max_value=None):
    """Convert ``value`` with ``input_type`` and check it against the bounds.

    Args:
        value: Raw value to validate (typically taken from a JSON payload).
        input_type: Callable used for the conversion, e.g. ``float`` or ``int``.
        min_value: Inclusive lower bound.
        max_value: Inclusive upper bound, or None for no upper bound.

    Returns:
        The converted value when it is a finite number within the bounds,
        otherwise None. Booleans are rejected rather than read as 1/0.
    """
    import math  # function-scope import keeps this block self-contained

    # bool is a subclass of int, so float(True) == 1.0 would silently turn a
    # JSON true/false into a measurement.
    if isinstance(value, bool):
        return None
    try:
        converted = input_type(value)
        # NaN compares False against every bound and would slip through the
        # range checks below; infinities pass when max_value is None.
        if not math.isfinite(converted):
            return None
        if converted < min_value:
            return None
        if max_value is not None and converted > max_value:
            return None
        return converted
    except (ValueError, TypeError, OverflowError):
        # OverflowError: e.g. int(float('inf')).
        return None
def validate_blood_pressure(systolic, diastolic):
    """Validate a systolic/diastolic pair.

    Systolic is accepted in 0..300 and diastolic in 0..200 (both inclusive).

    Returns:
        ``(systolic, diastolic)`` as floats when both readings are valid,
        otherwise ``(None, None)``.
    """
    checked_systolic = validate_input(systolic, float, 0, 300)
    checked_diastolic = validate_input(diastolic, float, 0, 200)
    both_valid = checked_systolic is not None and checked_diastolic is not None
    return (checked_systolic, checked_diastolic) if both_valid else (None, None)
def validate_gender(gender):
    """Map a gender string to its numeric code.

    Returns:
        1 for 'male', 0 for 'female' (case-insensitive); None for any other
        string or for a non-string value.
    """
    if not isinstance(gender, str):
        return None
    codes = {'male': 1, 'female': 0}
    return codes.get(gender.lower())
def calculate_diabetes_pedigree(family_history, first_degree=0, second_degree=0):
    """Derive a pedigree score from the number of affected relatives.

    Each first-degree relative contributes 0.5 and each second-degree relative
    0.25; the total is capped at 1.0.

    Returns:
        0.0 when ``family_history`` is falsy, otherwise the capped score.
    """
    if family_history:
        first_part = first_degree * 0.5
        second_part = second_degree * 0.25
        return min(first_part + second_part, 1.0)
    return 0.0
def get_multi_condition_predictions(model, df):
    """Run the multi-output model on ``df`` and summarise its first row.

    ``model.predict_proba(df)`` is indexed as ``[output][row][class]``.
    Hypertension is reported as the boolean label of output 0; the other
    conditions are reported as the class-1 probability of outputs 1 to 3.

    Returns:
        A dict with keys 'hypertension', 'cardiovascular', 'stroke' and
        'diabetes', or None when anything raises.
    """
    try:
        first_row_labels = model.predict(df)[0]
        per_output_probs = model.predict_proba(df)

        def positive_probability(output_index):
            # Probability of class 1 for the first row of one output.
            return float(per_output_probs[output_index][0][1])

        summary = {}
        summary['hypertension'] = bool(first_row_labels[0])
        summary['cardiovascular'] = positive_probability(1)
        summary['stroke'] = positive_probability(2)
        summary['diabetes'] = positive_probability(3)
        return summary
    except Exception as err:
        logging.error(f"Error in multi-condition prediction: {str(err)}")
        return None
def get_diabetes_prediction(model, df):
    """Run the diabetes classifier on ``df`` and describe its first row.

    Returns:
        ``(label, probability)`` where label is 'Diabetes' or 'No Diabetes'
        and probability is the class-1 probability as a percentage. Returns
        ``(None, 0.0)`` when anything raises.
    """
    try:
        predicted_label = model.predict(df)[0]
        positive_percent = float(model.predict_proba(df)[0][1] * 100)
        if predicted_label:
            label = 'Diabetes'
        else:
            label = 'No Diabetes'
        return label, positive_percent
    except Exception as err:
        logging.error(f"Error in diabetes prediction: {str(err)}")
        return None, 0.0
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness endpoint: always answers with a fixed 'healthy' JSON body."""
    payload = {'status': 'healthy', 'message': 'Service is running'}
    return jsonify(payload)
def _error_response(message, status_code):
    """Build the JSON error body and status shared by every /predict failure."""
    return jsonify({'status': 'error', 'error': message}), status_code


@app.route('/predict', methods=['POST'])
def predict_health():
    """Validate a JSON health payload and return a model prediction.

    Required fields: gender, systolic, diastolic, age, glucose, bmi.
    When systolic < 90 or diastolic < 60 the multi-condition model is used;
    otherwise the diabetes model is used, which additionally needs insulin
    and (unless gender is 'male') pregnancies, plus optional family_history,
    first_degree_relatives and second_degree_relatives.

    Returns:
        A JSON response: 200 with the prediction, 400 for an invalid payload,
        503 when a required model is not loaded, 500 for prediction failures
        or unexpected errors.
    """
    try:
        # silent=True: a missing or malformed JSON body yields None (answered
        # with 400 below) instead of raising an HTTP error that the broad
        # handler at the bottom would turn into a 500.
        data = request.get_json(silent=True)
        if not data or not isinstance(data, dict):
            return _error_response('Invalid JSON payload', 400)
        # Log field names only: the values are personal health data.
        logging.info(f"Received fields: {sorted(data)}")

        gender = validate_gender(data.get('gender'))
        if gender is None:
            return _error_response('Invalid gender value. Must be "male" or "female"', 400)

        systolic, diastolic = validate_blood_pressure(data.get('systolic'), data.get('diastolic'))
        if systolic is None or diastolic is None:
            return _error_response('Invalid blood pressure values', 400)

        age = validate_input(data.get('age'), float, 0, 120)
        glucose = validate_input(data.get('glucose'), float, 0, 1000)
        bmi = validate_input(data.get('bmi'), float, 0, 100)
        if any(v is None for v in [age, glucose, bmi]):
            return _error_response('Invalid values for age, glucose, or BMI', 400)

        use_multi_condition = systolic < 90 or diastolic < 60
        if use_multi_condition:
            # The model may be absent (download/extraction failed) or None
            # (loading failed); report that explicitly instead of a KeyError.
            multi_model = models.get('MULTI_MODEL')
            if multi_model is None:
                return _error_response('Multi-condition model is not available', 503)

            df_multi = pd.DataFrame([{
                'Age': age,
                'Gender': gender,
                'Systolic_bp': systolic,
                'Diastolic_bp': diastolic,
                'Glucose': glucose,
                'BMI': bmi
            }])
            results = get_multi_condition_predictions(multi_model, df_multi)
            if results is None:
                return _error_response('Error in multi-condition prediction', 500)
            return jsonify({
                'status': 'success',
                'model': 'multi-condition',
                'predictions': {
                    'hypertension': results['hypertension'],
                    'cardiovascular_risk': results['cardiovascular'],
                    'stroke_risk': results['stroke'],
                    'diabetes_risk': results['diabetes']
                }
            })

        # Pregnancies defaults to 0 for gender code 1 ('male'); otherwise it
        # is required.
        pregnancies = validate_input(data.get('pregnancies', 0 if gender == 1 else None), float, 0, 20)
        insulin = validate_input(data.get('insulin'), float, 0, 1000)
        if any(v is None for v in [pregnancies, insulin]):
            return _error_response('Invalid values for pregnancies or insulin', 400)

        family_history = data.get('family_history', False)
        first_degree = validate_input(data.get('first_degree_relatives', 0), float, 0, 10)
        second_degree = validate_input(data.get('second_degree_relatives', 0), float, 0, 20)
        # Invalid relative counts are treated as 0 rather than rejected.
        diabetes_pedigree = calculate_diabetes_pedigree(
            family_history,
            first_degree if first_degree is not None else 0,
            second_degree if second_degree is not None else 0
        )

        scaler = models.get('SCALER')
        diabetes_model = models.get('DIABETES_MODEL')
        if scaler is None or diabetes_model is None:
            return _error_response('Diabetes model is not available', 503)

        df_diabetes = pd.DataFrame([{
            'Pregnancies': pregnancies,
            'Glucose': glucose,
            # NOTE(review): the BloodPressure feature is fed the systolic
            # reading -- confirm this matches how the model was trained.
            'BloodPressure': systolic,
            'Insulin': insulin,
            'BMI': bmi,
            'DiabetesPedigreeFunction': diabetes_pedigree,
            'Age': age
        }])
        df_diabetes = df_diabetes[FEATURE_ORDER]
        df_scaled = scaler.transform(df_diabetes)

        prediction, probability = get_diabetes_prediction(diabetes_model, df_scaled)
        # get_diabetes_prediction signals failure with (None, 0.0); do not
        # report that as a successful LOW-risk result.
        if prediction is None:
            return _error_response('Error in diabetes prediction', 500)

        return jsonify({
            'status': 'success',
            'model': 'diabetes',
            'prediction': prediction,
            'probability': probability,
            'risk_level': 'HIGH' if probability > 70 else 'MODERATE' if probability > 40 else 'LOW'
        })
    except Exception:
        # Keep the details in the server log; do not leak exception text to
        # the client.
        logging.exception("Unhandled error in /predict")
        return _error_response('Internal server error', 500)
# When the file is executed directly, start Flask's built-in server listening
# on all interfaces at port 7860.
if __name__ == '__main__':
    app.run(host="0.0.0.0", port=7860)