Deep8591 committed on
Commit
d5fa99a
·
verified ·
1 Parent(s): f1f0496

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -103
app.py CHANGED
@@ -1,103 +1,102 @@
1
- import torch
2
- from loguru import logger
3
- from pydantic import BaseModel, Field
4
- from fastapi import FastAPI, HTTPException
5
- from fastapi.middleware.cors import CORSMiddleware
6
- from sentence_transformers import CrossEncoder
7
- from typing import List, Optional
8
-
9
- # Initialize FastAPI app with documentation metadata
10
- app = FastAPI(
11
- title="Document Reranker API",
12
- description="An API for reranking documents using a CrossEncoder model.",
13
- version="1.0",
14
- docs_url="/docs", # Swagger UI
15
- redoc_url="/redoc", # ReDoc UI
16
- )
17
-
18
- # Enable CORS (optional but useful for frontend integration)
19
- app.add_middleware(
20
- CORSMiddleware,
21
- allow_origins=["*"], # Allow all origins (modify as needed)
22
- allow_credentials=True,
23
- allow_methods=["*"],
24
- allow_headers=["*"],
25
- )
26
-
27
- # Device selection
28
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
- logger.warning(
30
- f"Using device: {DEVICE} ({'GPU: ' + torch.cuda.get_device_name(0) if DEVICE.type == 'cuda' else 'Running on CPU'})"
31
- )
32
-
33
- # Load the model at startup to avoid reloading for each request
34
- model = CrossEncoder(
35
- "jinaai/jina-reranker-v1-turbo-en",
36
- trust_remote_code=True,
37
- device=DEVICE,
38
- cache_dir="models",
39
- )
40
-
41
-
42
- class RerankerRequest(BaseModel):
43
- query: str = Field(..., description="The search query string")
44
- documents: List[str] = Field(..., description="List of documents to rerank")
45
- return_documents: bool = Field(
46
- True, description="Whether to return document content in results"
47
- )
48
- top_k: int = Field(3, description="Number of top results to return")
49
-
50
-
51
- class RankedResult(BaseModel):
52
- score: float
53
- index: int
54
- document: Optional[str] = None
55
-
56
-
57
- class RerankerResponse(BaseModel):
58
- results: List[RankedResult]
59
-
60
-
61
- @app.post("/rerank", response_model=RerankerResponse, tags=["Reranker"])
62
- async def rerank_documents(request: RerankerRequest):
63
- """
64
- Reranks the given list of documents based on their relevance to the query.
65
-
66
- - **query**: The input query string.
67
- - **documents**: A list of documents to be reranked.
68
- - **return_documents**: Whether to include document content in results.
69
- - **top_k**: Number of top-ranked documents to return.
70
-
71
- Returns:
72
- - A list of ranked documents with scores and indexes.
73
- """
74
- try:
75
- # Call the model's rank method with the provided parameters
76
- results = model.rank(
77
- request.query,
78
- request.documents,
79
- return_documents=request.return_documents,
80
- top_k=request.top_k,
81
- )
82
-
83
- # Format the results based on the model's output
84
- formatted_results = [
85
- RankedResult(
86
- score=result["score"],
87
- index=result["corpus_id"],
88
- document=result["text"] if request.return_documents else None,
89
- )
90
- for result in results
91
- ]
92
-
93
- return RerankerResponse(results=formatted_results)
94
-
95
- except Exception as e:
96
- raise HTTPException(status_code=500, detail=f"Error in reranking: {str(e)}")
97
-
98
-
99
- # Run the FastAPI app with Uvicorn
100
- if __name__ == "__main__":
101
- import uvicorn
102
-
103
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import torch
2
+ from loguru import logger
3
+ from pydantic import BaseModel, Field
4
+ from fastapi import FastAPI, HTTPException
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ from sentence_transformers import CrossEncoder
7
+ from typing import List, Optional
8
+
9
+ # Initialize FastAPI app with documentation metadata
10
+ app = FastAPI(
11
+ title="Document Reranker API",
12
+ description="An API for reranking documents using a CrossEncoder model.",
13
+ version="1.0",
14
+ docs_url="/docs", # Swagger UI
15
+ redoc_url="/redoc", # ReDoc UI
16
+ )
17
+
18
+ # Enable CORS (optional but useful for frontend integration)
19
+ app.add_middleware(
20
+ CORSMiddleware,
21
+ allow_origins=["*"], # Allow all origins (modify as needed)
22
+ allow_credentials=True,
23
+ allow_methods=["*"],
24
+ allow_headers=["*"],
25
+ )
26
+
27
+ # Device selection
28
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ logger.warning(
30
+ f"Using device: {DEVICE} ({'GPU: ' + torch.cuda.get_device_name(0) if DEVICE.type == 'cuda' else 'Running on CPU'})"
31
+ )
32
+
33
+ # Load the model at startup to avoid reloading for each request
34
+ model = CrossEncoder(
35
+ "jinaai/jina-reranker-v1-turbo-en",
36
+ trust_remote_code=True,
37
+ device=DEVICE
38
+ )
39
+
40
+
41
+ class RerankerRequest(BaseModel):
42
+ query: str = Field(..., description="The search query string")
43
+ documents: List[str] = Field(..., description="List of documents to rerank")
44
+ return_documents: bool = Field(
45
+ True, description="Whether to return document content in results"
46
+ )
47
+ top_k: int = Field(3, description="Number of top results to return")
48
+
49
+
50
+ class RankedResult(BaseModel):
51
+ score: float
52
+ index: int
53
+ document: Optional[str] = None
54
+
55
+
56
+ class RerankerResponse(BaseModel):
57
+ results: List[RankedResult]
58
+
59
+
60
+ @app.post("/rerank", response_model=RerankerResponse, tags=["Reranker"])
61
+ async def rerank_documents(request: RerankerRequest):
62
+ """
63
+ Reranks the given list of documents based on their relevance to the query.
64
+
65
+ - **query**: The input query string.
66
+ - **documents**: A list of documents to be reranked.
67
+ - **return_documents**: Whether to include document content in results.
68
+ - **top_k**: Number of top-ranked documents to return.
69
+
70
+ Returns:
71
+ - A list of ranked documents with scores and indexes.
72
+ """
73
+ try:
74
+ # Call the model's rank method with the provided parameters
75
+ results = model.rank(
76
+ request.query,
77
+ request.documents,
78
+ return_documents=request.return_documents,
79
+ top_k=request.top_k,
80
+ )
81
+
82
+ # Format the results based on the model's output
83
+ formatted_results = [
84
+ RankedResult(
85
+ score=result["score"],
86
+ index=result["corpus_id"],
87
+ document=result["text"] if request.return_documents else None,
88
+ )
89
+ for result in results
90
+ ]
91
+
92
+ return RerankerResponse(results=formatted_results)
93
+
94
+ except Exception as e:
95
+ raise HTTPException(status_code=500, detail=f"Error in reranking: {str(e)}")
96
+
97
+
98
+ # Run the FastAPI app with Uvicorn
99
+ if __name__ == "__main__":
100
+ import uvicorn
101
+
102
+ uvicorn.run(app, host="0.0.0.0", port=7860)