"""Open Source Chat API.

A small FastAPI service that answers questions with a local extractive
question-answering model. When requested, DuckDuckGo search results are
used as the context from which the answer is extracted.
"""

import logging
import os
from typing import Any, Dict, List, Optional

import torch
from duckduckgo_search import DDGS
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# How many search hits are folded into the QA context, and how many
# characters of each hit's body are kept.
_MAX_CONTEXT_RESULTS = 5
_SNIPPET_CHARS = 200

# Initialize FastAPI app
app = FastAPI(
    title="Open Source Chat API",
    description="A fully open source alternative to HF API using local models",
    version="1.0.0",
)

# Configure CORS for Android app access
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["POST", "GET"],
    allow_headers=["*"],
)


# Request/Response models
class ChatRequest(BaseModel):
    """Body of a ``POST /chat`` request."""

    prompt: str
    # NOTE: ``max_new_tokens`` and ``temperature`` are accepted for API
    # compatibility, but nothing in this module reads them.
    max_new_tokens: int = 500
    use_search: bool = False
    temperature: float = 0.7


class ChatResponse(BaseModel):
    """Body of a ``POST /chat`` response."""

    response: str
    # Contains web search results if use_search is enabled
    search_results: Optional[List[Dict[str, Any]]] = None


class SearchRequest(BaseModel):
    """Body of a ``POST /search`` request."""

    query: str
    max_results: int = 5


# Set by load_model(); stays None if loading fails.
qa_pipeline = None


def load_model():
    """Load the local question-answering model into ``qa_pipeline``.

    Loading is best-effort: any failure is logged (with traceback) and
    ``qa_pipeline`` is left as ``None`` so the app can still start and
    report ``model_loaded: false`` from the health check.
    """
    global qa_pipeline
    try:
        # Check if GPU is available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Using device: %s", device)
        qa_pipeline = pipeline(
            "question-answering",
            model="distilbert-base-uncased-distilled-squad",
            device=0 if device == "cuda" else -1,
        )
        logger.info("Model loaded successfully!")
    except Exception as e:
        logger.exception("Error loading model: %s", e)


def search_web(query: str, max_results: int = 5) -> List[Dict[str, Any]]:
    """Search the web using DuckDuckGo.

    Args:
        query: Search terms.
        max_results: Maximum number of hits requested from DuckDuckGo.

    Returns:
        A list of ``{"title", "body", "href"}`` dicts, with missing fields
        replaced by empty strings. Returns an empty list if the search
        raises.
    """
    try:
        hits = DDGS().text(query, safesearch='off', max_results=max_results)
        return [
            {
                "title": hit.get("title", ""),
                "body": hit.get("body", ""),
                "href": hit.get("href", ""),
            }
            for hit in hits
        ]
    except Exception as e:
        logger.exception("Search error: %s", e)
        return []


def generate_response(prompt: str, search_context: str) -> str:
    """Answer ``prompt`` by extracting a span from ``search_context``.

    Args:
        prompt: The question passed to the QA pipeline.
        search_context: The text the answer is extracted from.

    Returns:
        The pipeline's ``answer`` string, or a human-readable error message
        if the model is not loaded, is of the wrong type, or raises.
    """
    try:
        if qa_pipeline is None:
            return "Model not loaded properly. Please try again."
        # Validate that qa_pipeline is a question-answering pipeline
        if getattr(qa_pipeline, "task", None) != "question-answering":
            return "Invalid pipeline type. Expected a question-answering pipeline."
        return qa_pipeline(question=prompt, context=search_context)['answer']
    except Exception as e:
        logger.exception("Generation error: %s", e)
        return f"Sorry, I encountered an error: {str(e)}"


def _build_search_context(search_results: List[Dict[str, Any]]) -> str:
    """Format the first few search hits as a bulleted context string."""
    return "\n".join(
        f"- {result['title']}: {result['body'][:_SNIPPET_CHARS]}..."
        for result in search_results[:_MAX_CONTEXT_RESULTS]
    )


# NOTE(review): FastAPI documents ``on_event`` as deprecated in favour of a
# ``lifespan`` handler; kept here to avoid raising the minimum FastAPI version.
@app.on_event("startup")
async def startup_event():
    """Load model on startup"""
    load_model()


@app.get("/")
async def root():
    """Health check endpoint"""
    return {
        "message": "Open Source Chat API is running!",
        "model_loaded": qa_pipeline is not None,
        "endpoints": {
            "chat": "/chat",
            "search": "/search",
            "docs": "/docs",
        },
    }


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Main chat endpoint.

    With ``use_search`` enabled, the prompt is searched on the web and the
    hits become the QA context; otherwise the prompt itself is the context.

    Raises:
        HTTPException: 500 with the error text if anything unexpected fails.
    """
    try:
        search_results = None
        if request.use_search:
            # search_web does blocking network I/O; run it off the event
            # loop so other requests are not stalled while it waits.
            search_results = await run_in_threadpool(search_web, request.prompt)
            search_context = _build_search_context(search_results)
        else:
            search_context = request.prompt

        if search_context:
            # NOTE(review): inference still runs on the event loop. It is
            # left un-threaded because concurrent use of the pipeline is not
            # known to be safe from this file alone.
            response = generate_response(request.prompt, search_context)
        else:
            response = "No context available to generate a response."

        return ChatResponse(response=response, search_results=search_results)
    except Exception as e:
        logger.exception("Chat endpoint error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/search")
async def search(request: SearchRequest):
    """Web search endpoint.

    Raises:
        HTTPException: 500 with the error text if anything unexpected fails.
    """
    try:
        # Blocking network I/O: keep it off the event loop.
        results = await run_in_threadpool(
            search_web, request.query, request.max_results
        )
        return {"results": results}
    except Exception as e:
        logger.exception("Search endpoint error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    # uvicorn is only needed when this file is run as a script.
    import uvicorn

    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))  # Changed from 8000 to 7860
    uvicorn.run(app, host=host, port=port)