File size: 9,789 Bytes
687789c
 
 
 
 
 
 
 
 
 
 
 
 
 
268c5f7
687789c
 
694bea5
 
 
 
 
 
687789c
694bea5
687789c
 
 
 
 
 
694bea5
687789c
 
 
 
 
 
268c5f7
687789c
 
3e44fe1
687789c
 
 
 
 
 
 
 
 
 
268c5f7
687789c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268c5f7
 
687789c
 
 
 
 
 
 
 
268c5f7
 
687789c
 
 
 
 
 
268c5f7
687789c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
694bea5
 
268c5f7
 
694bea5
 
 
 
 
 
687789c
694bea5
687789c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
694bea5
687789c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
import logging
import math
import os
import warnings
import zipfile

import joblib
import numpy as np
import pandas as pd
import requests
from flask import Flask, request, jsonify
from flask_cors import CORS

# Suppress sklearn warnings
# Only UserWarning raised from sklearn modules is silenced; other warning
# categories and other modules still surface.
warnings.filterwarnings('ignore', category=UserWarning, module='sklearn')

# Logging setup
# Root logger at INFO so download/extract/load progress messages are visible.
logging.basicConfig(level=logging.INFO)

# Load URLs from environment
# Each value is None when the corresponding environment variable is unset.
# The keys are shared with MODEL_PATHS and EXTRACTED_MODELS below.
MODEL_URLS = {
    "DIABETES_MODEL": os.getenv("DIABETES_MODEL_URL"),
    "SCALER": os.getenv("SCALER_URL"),
    "MULTI_MODEL": os.getenv("MULTI_MODEL_URL"),
}

# Model ZIP filenames
# Names the downloaded archives are stored under inside TMP_DIR.
MODEL_PATHS = {
    "DIABETES_MODEL": "finaliseddiabetes_model.zip",
    "SCALER": "finalisedscaler.zip",
    "MULTI_MODEL": "nodiabetes.zip",
}

# Extracted model file names
# File each archive is expected to yield once extracted into TMP_DIR.
EXTRACTED_MODELS = {
    "DIABETES_MODEL": "finaliseddiabetes_model.joblib",
    "SCALER": "finalisedscaler.joblib",
    "MULTI_MODEL": "nodiabetes.joblib",
}

# Working directory for downloaded archives and extracted model files.
TMP_DIR = "/tmp"

app = Flask(__name__)
# CORS with credentials enabled. NOTE(review): no explicit origin allow-list is
# passed here, so flask-cors' default origins apply - confirm that is intended
# for a credentialed API.
CORS(app, supports_credentials=True)

@app.route('/')
def index():
    """Serve a small HTML landing page that points clients at ``/predict``."""
    landing_html = """
    <h1>Welcome to the Diabetes Health Predictor API 👋</h1>
    <p>This Hugging Face Space provides health risk predictions including diabetes, hypertension, stroke, and cardiovascular conditions.</p>
    <p>Use the <code>/predict</code> endpoint via POST request to get started with your health insights!</p>
    """
    return landing_html

def download_model(url, zip_filename, *, timeout=60):
    """Download a zipped model archive from *url* into ``TMP_DIR``.

    The response body is streamed to a temporary ``<name>.part`` file, which
    is renamed to the final path only after it has been written completely.
    An interrupted or failed download therefore never leaves a truncated
    archive at the final path, where it would later look like a finished
    download.

    Args:
        url: Source URL. A missing/empty URL is logged and treated as failure.
        zip_filename: File name the archive is saved under inside ``TMP_DIR``.
        timeout: Connect/read timeout in seconds passed to ``requests.get``,
            so a stalled server cannot block start-up indefinitely.

    Returns:
        True if the archive was saved, False otherwise (the reason is logged).
    """
    if not url:
        logging.error(f"URL for {zip_filename} is missing!")
        return False

    zip_path = os.path.join(TMP_DIR, zip_filename)
    part_path = zip_path + ".part"
    try:
        with requests.get(url, allow_redirects=True, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                logging.error(f"Failed to download {zip_filename}. HTTP Status: {response.status_code}")
                return False
            with open(part_path, 'wb') as f:
                # Stream in 1 MiB chunks instead of holding the whole model in memory.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        # Atomic on the same filesystem: the final path is either absent or complete.
        os.replace(part_path, zip_path)
    except Exception as e:
        # Best-effort start-up: report the failure to the caller, don't crash.
        logging.error(f"Error downloading {zip_filename}: {e}")
        try:
            os.remove(part_path)
        except OSError:
            pass  # nothing was written, or cleanup is impossible; either way not fatal
        return False

    logging.info(f"Downloaded {zip_filename} successfully.")
    return True

def extract_if_needed(zip_filename, extracted_filename, tmp_dir=None):
    """Extract *zip_filename* unless *extracted_filename* already exists.

    Both names are resolved inside *tmp_dir* (``TMP_DIR`` when omitted), and
    the whole archive is extracted into that directory.

    Args:
        zip_filename: Archive file name inside the working directory.
        extracted_filename: Model file the archive is expected to contain.
        tmp_dir: Optional working directory overriding ``TMP_DIR``.

    Returns:
        True if *extracted_filename* exists afterwards. False if the archive
        is missing, cannot be extracted, or does not contain that file. Every
        failure is logged.
    """
    target_dir = TMP_DIR if tmp_dir is None else tmp_dir
    zip_path = os.path.join(target_dir, zip_filename)
    extracted_path = os.path.join(target_dir, extracted_filename)

    if os.path.exists(extracted_path):
        logging.info(f"{extracted_filename} already exists. Skipping extraction.")
        return True
    if not os.path.exists(zip_path):
        logging.error(f"Zip file missing: {zip_path}")
        return False

    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            # extractall() strips absolute paths and ".." components from
            # member names, so a downloaded archive cannot escape target_dir.
            zip_ref.extractall(target_dir)
    except Exception as e:
        logging.error(f"Error extracting {zip_filename}: {e}")
        return False
    logging.info(f"Extracted {zip_filename} to {target_dir}")

    # Extraction succeeding is not enough: the archive must actually contain
    # the expected model file, otherwise report failure here instead of later.
    if not os.path.exists(extracted_path):
        logging.error(f"{extracted_filename} not found in {zip_filename} after extraction")
        return False
    return True

def load_model(model_filename):
    """Deserialize *model_filename* from ``TMP_DIR`` with joblib.

    Security: ``joblib.load`` unpickles the file, which can execute arbitrary
    code. The files loaded here are presumably extracted from archives fetched
    from the ``*_URL`` environment variables, so those sources must be trusted.

    Returns:
        The loaded object, or None (after logging) when the file is missing or
        cannot be loaded.
    """
    target = os.path.join(TMP_DIR, model_filename)
    if os.path.exists(target):
        try:
            loaded = joblib.load(target)
            logging.info(f"Loaded {model_filename} successfully.")
            return loaded
        except Exception as err:
            logging.error(f"Error loading {model_filename}: {err}")
            return None
    logging.error(f"Model file not found: {target}")
    return None

def initialize_models():
    """Download, extract and load every model listed in ``MODEL_PATHS``.

    For each key, an already-extracted model file in ``TMP_DIR`` is used
    directly. Otherwise the archive is downloaded from ``MODEL_URLS`` (unless
    it is already present) and extracted. Failures are logged, and that key
    is left out of the result.

    Returns:
        dict mapping model keys (e.g. ``"SCALER"``) to loaded objects. Only
        models that loaded successfully are included, so a missing key means
        "not available".
    """
    loaded = {}
    for model_key, zip_filename in MODEL_PATHS.items():
        extracted_filename = EXTRACTED_MODELS[model_key]
        extracted_path = os.path.join(TMP_DIR, extracted_filename)
        zip_path = os.path.join(TMP_DIR, zip_filename)

        # Only fetch when neither the extracted model nor its archive is on
        # disk: a present .joblib must not depend on the URL being configured.
        if not os.path.exists(extracted_path) and not os.path.exists(zip_path):
            if not download_model(MODEL_URLS.get(model_key), zip_filename):
                continue

        if not extract_if_needed(zip_filename, extracted_filename):
            continue

        model = load_model(extracted_filename)
        if model is None:
            logging.warning(f"{model_key} could not be loaded and will be unavailable.")
            continue
        loaded[model_key] = model

    return loaded

# Load all models once at import time; the /predict handler looks them up in
# this dict by key ("DIABETES_MODEL", "SCALER", "MULTI_MODEL").
models = initialize_models()

# Column order used for the diabetes model's input frame; the frame is
# reindexed with this list before being passed to the scaler.
FEATURE_ORDER = [
    'Pregnancies',
    'Glucose',
    'BloodPressure',
    'Insulin',
    'BMI',
    'DiabetesPedigreeFunction',
    'Age',
]

def validate_input(value, input_type=float, min_value=0, max_value=None):
    """Convert *value* with *input_type* and check that it lies in range.

    Args:
        value: Raw value, typically taken from the request JSON.
        input_type: Callable used for conversion (e.g. ``float`` or ``int``).
        min_value: Inclusive lower bound.
        max_value: Inclusive upper bound, or None for no upper bound.

    Returns:
        The converted value, or None if conversion fails, the result is NaN or
        infinite, or it falls outside ``[min_value, max_value]``.
    """
    try:
        converted = input_type(value)
        # NaN compares False against every bound and would slip through the
        # range checks below. Python's json.loads accepts NaN/Infinity literals
        # by default, and float("nan") parses too, so reject non-finite floats.
        if isinstance(converted, float) and not math.isfinite(converted):
            return None
        if converted < min_value:
            return None
        if max_value is not None and converted > max_value:
            return None
        return converted
    except (ValueError, TypeError, OverflowError):
        # OverflowError covers e.g. int(float('inf')).
        return None

def validate_blood_pressure(systolic, diastolic):
    """Validate a systolic/diastolic reading pair.

    Systolic must convert to a float in [0, 300] and diastolic to a float in
    [0, 200], both bounds inclusive, as checked by ``validate_input``.

    Returns:
        ``(systolic, diastolic)`` as floats when both are valid, otherwise
        ``(None, None)``.
    """
    checked = (
        validate_input(systolic, float, 0, 300),
        validate_input(diastolic, float, 0, 200),
    )
    if any(reading is None for reading in checked):
        return None, None
    return checked

def validate_gender(gender):
    """Encode a case-insensitive gender string as 1 ('male') or 0 ('female').

    Returns None for any other string and for non-string input.
    """
    if not isinstance(gender, str):
        return None
    return {'male': 1, 'female': 0}.get(gender.lower())

def calculate_diabetes_pedigree(family_history, first_degree=0, second_degree=0):
    """Approximate a diabetes pedigree score from relative counts.

    This is a simple weighted heuristic: each first-degree relative adds 0.5
    and each second-degree relative adds 0.25, capped at 1.0. A falsy
    *family_history* short-circuits to 0.0 regardless of the counts.
    """
    if family_history:
        weighted = first_degree * 0.5 + second_degree * 0.25
        return 1.0 if weighted > 1.0 else weighted
    return 0.0

def get_multi_condition_predictions(model, df):
    """Run the multi-condition model on the first row of *df*.

    Output 0 of ``predict`` is reported as the hypertension label. Outputs 1-3
    of ``predict_proba`` give the positive-class probability for
    cardiovascular, stroke and diabetes. The model presumably returns one
    probability array per output, as sklearn's MultiOutputClassifier does
    (TODO confirm).

    Returns:
        dict with 'hypertension' (bool) and 'cardiovascular', 'stroke',
        'diabetes' (floats), or None after logging if anything fails.
    """
    try:
        first_row_labels = model.predict(df)[0]
        per_output_probs = model.predict_proba(df)

        def positive_prob(output_idx):
            return float(per_output_probs[output_idx][0][1])

        return {
            'hypertension': bool(first_row_labels[0]),
            'cardiovascular': positive_prob(1),
            'stroke': positive_prob(2),
            'diabetes': positive_prob(3),
        }
    except Exception as e:
        logging.error(f"Error in multi-condition prediction: {str(e)}")
        return None

def get_diabetes_prediction(model, df):
    """Classify the first sample of *df* with the diabetes model.

    Args:
        model: Fitted classifier exposing ``predict`` and ``predict_proba``.
        df: Model input (already scaled by the caller).

    Returns:
        ``(label, probability)``: label is 'Diabetes' or 'No Diabetes' and
        probability is the class-1 probability as a percentage. On any error
        the failure is logged and ``(None, 0.0)`` is returned.
    """
    try:
        first_label = model.predict(df)[0]
        class_probs = model.predict_proba(df)[0]
        percent = float(class_probs[1] * 100)
        if first_label:
            verdict = 'Diabetes'
        else:
            verdict = 'No Diabetes'
        return verdict, percent
    except Exception as e:
        logging.error(f"Error in diabetes prediction: {str(e)}")
        return None, 0.0

@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: always reports healthy without checking loaded models."""
    payload = {'status': 'healthy', 'message': 'Service is running'}
    return jsonify(payload)

@app.route('/predict', methods=['POST'])
def predict_health():
    """Return health-risk predictions for the patient record in the JSON body.

    Required fields: ``gender`` ("male"/"female"), ``systolic``, ``diastolic``,
    ``age``, ``glucose`` and ``bmi``. When systolic < 90 or diastolic < 60 the
    multi-condition model is used. Otherwise the diabetes model is used, which
    also reads ``insulin``, ``pregnancies`` (defaults to 0 for males only),
    ``family_history``, ``first_degree_relatives`` and
    ``second_degree_relatives``.

    Returns:
        A JSON response:
        - 200 with predictions.
        - 400 for invalid input.
        - 503 when a required model was not loaded at start-up.
        - 500 when prediction fails.
    """
    try:
        # silent=True yields None (-> 400 below) for a malformed body or a
        # non-JSON Content-Type instead of raising, which the catch-all at the
        # bottom would otherwise report as a 500.
        data = request.get_json(silent=True)
        logging.info(f"Received data: {data}")
        if not isinstance(data, dict) or not data:
            return jsonify({'status': 'error', 'error': 'Invalid JSON payload'}), 400

        gender = validate_gender(data.get('gender'))
        if gender is None:
            return jsonify({'status': 'error', 'error': 'Invalid gender value. Must be \"male\" or \"female\"'}), 400

        systolic, diastolic = validate_blood_pressure(data.get('systolic'), data.get('diastolic'))
        if systolic is None or diastolic is None:
            return jsonify({'status': 'error', 'error': 'Invalid blood pressure values'}), 400

        age = validate_input(data.get('age'), float, 0, 120)
        glucose = validate_input(data.get('glucose'), float, 0, 1000)
        bmi = validate_input(data.get('bmi'), float, 0, 100)
        if any(v is None for v in (age, glucose, bmi)):
            return jsonify({'status': 'error', 'error': 'Invalid values for age, glucose, or BMI'}), 400

        # NOTE(review): *low* blood pressure (below 90/60) routes to the
        # multi-condition model, whose outputs include hypertension - confirm
        # this threshold and direction are intended.
        if systolic < 90 or diastolic < 60:
            multi_model = models.get('MULTI_MODEL')
            if multi_model is None:
                return _model_unavailable('MULTI_MODEL')

            df_multi = pd.DataFrame([{
                'Age': age,
                'Gender': gender,
                'Systolic_bp': systolic,
                'Diastolic_bp': diastolic,
                'Glucose': glucose,
                'BMI': bmi
            }])
            results = get_multi_condition_predictions(multi_model, df_multi)
            if results is None:
                return jsonify({'status': 'error', 'error': 'Error in multi-condition prediction'}), 500

            return jsonify({
                'status': 'success',
                'model': 'multi-condition',
                'predictions': {
                    'hypertension': results['hypertension'],
                    'cardiovascular_risk': results['cardiovascular'],
                    'stroke_risk': results['stroke'],
                    'diabetes_risk': results['diabetes']
                }
            })

        # Males may omit pregnancies (defaults to 0); for females it is required.
        pregnancies = validate_input(data.get('pregnancies', 0 if gender == 1 else None), float, 0, 20)
        insulin = validate_input(data.get('insulin'), float, 0, 1000)

        family_history = _parse_bool_flag(data.get('family_history', False))
        first_degree = validate_input(data.get('first_degree_relatives', 0), float, 0, 10)
        second_degree = validate_input(data.get('second_degree_relatives', 0), float, 0, 20)

        # Invalid relative counts are deliberately treated as 0 rather than rejected.
        diabetes_pedigree = calculate_diabetes_pedigree(
            family_history,
            first_degree if first_degree is not None else 0,
            second_degree if second_degree is not None else 0
        )

        if any(v is None for v in (pregnancies, insulin)):
            return jsonify({'status': 'error', 'error': 'Invalid values for pregnancies or insulin'}), 400

        scaler = models.get('SCALER')
        if scaler is None:
            return _model_unavailable('SCALER')
        diabetes_model = models.get('DIABETES_MODEL')
        if diabetes_model is None:
            return _model_unavailable('DIABETES_MODEL')

        df_diabetes = pd.DataFrame([{
            'Pregnancies': pregnancies,
            'Glucose': glucose,
            # NOTE(review): systolic pressure is fed as 'BloodPressure'. In the
            # Pima Indians dataset this feature set resembles, BloodPressure is
            # diastolic - confirm against the model's training data.
            'BloodPressure': systolic,
            'Insulin': insulin,
            'BMI': bmi,
            'DiabetesPedigreeFunction': diabetes_pedigree,
            'Age': age
        }])[FEATURE_ORDER]

        df_scaled = scaler.transform(df_diabetes)
        prediction, probability = get_diabetes_prediction(diabetes_model, df_scaled)
        if prediction is None:
            # get_diabetes_prediction signals failure with (None, 0.0); without
            # this check the client would receive a "success" with LOW risk.
            return jsonify({'status': 'error', 'error': 'Error in diabetes prediction'}), 500

        if probability > 70:
            risk_level = 'HIGH'
        elif probability > 40:
            risk_level = 'MODERATE'
        else:
            risk_level = 'LOW'

        return jsonify({
            'status': 'success',
            'model': 'diabetes',
            'prediction': prediction,
            'probability': probability,
            'risk_level': risk_level
        })

    except Exception as e:
        # Top-level boundary: keep the traceback in the logs, answer with 500.
        logging.exception(f"Error: {e}")
        return jsonify({'status': 'error', 'error': str(e)}), 500


def _parse_bool_flag(value):
    """Interpret a JSON flag that may arrive as a bool, a number or a string.

    Strings are matched case-insensitively against 'true', '1', 'yes' and 'y',
    so a client sending "false" is not treated as True. Other types use
    ``bool()``.
    """
    if isinstance(value, str):
        return value.strip().lower() in ('true', '1', 'yes', 'y')
    return bool(value)


def _model_unavailable(model_key):
    """Log and build the 503 response used when a required model is not loaded."""
    logging.error(f"Model {model_key} is not loaded; cannot serve prediction.")
    return jsonify({'status': 'error', 'error': f'Model {model_key} is not available'}), 503

# Direct-run entry point: serve on all interfaces, port 7860.
if __name__ == '__main__':
    app.run(host="0.0.0.0", port=7860)