Spaces:
Runtime error
Delete main.py
main.py
DELETED
@@ -1,56 +0,0 @@
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModel
import uvicorn, json, datetime
import torch

DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE


def torch_gc():
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


app = FastAPI()


@app.post("/")
async def create_item(request: Request):
    global model, tokenizer
    json_post_raw = await request.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    prompt = json_post_list.get('prompt')
    history = json_post_list.get('history')
    max_length = json_post_list.get('max_length')
    top_p = json_post_list.get('top_p')
    temperature = json_post_list.get('temperature')
    response, history = model.chat(tokenizer,
                                   prompt,
                                   history=history,
                                   max_length=max_length if max_length else 2048,
                                   top_p=top_p if top_p else 0.7,
                                   temperature=temperature if temperature else 0.95)
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": time
    }
    log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
    print(log)
    torch_gc()
    return answer


if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained("fb700/chatglm-fitness-RLHF", trust_remote_code=True)
    model = AutoModel.from_pretrained("fb700/chatglm-fitness-RLHF", trust_remote_code=True).half().quantize(8).cuda()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=5)
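
A note on the final line: uvicorn exits immediately when workers is greater than 1 (or reload is enabled) and the application is passed as an object rather than an import string, which is a plausible cause of this Space's "Runtime error" status; multiple workers would also each load a separate copy of the model onto the GPU. A minimal sketch of a corrected entry point, assuming the file is saved as main.py:

if __name__ == '__main__':
    # Load the tokenizer and 8-bit quantized model once, in the single server process.
    tokenizer = AutoTokenizer.from_pretrained("fb700/chatglm-fitness-RLHF", trust_remote_code=True)
    model = AutoModel.from_pretrained("fb700/chatglm-fitness-RLHF", trust_remote_code=True).half().quantize(8).cuda()
    model.eval()
    # A single worker keeps exactly one model instance resident on the GPU;
    # uvicorn only accepts workers > 1 when given an import string like "main:app".
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)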
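
For reference, the endpoint accepts a JSON body with prompt, history, max_length, top_p, and temperature, and returns the model's reply together with the updated history. A minimal client sketch, assuming the server above is reachable on localhost port 8000:

import requests

# Hypothetical call against a local instance of the server above.
payload = {
    "prompt": "Hello, how can I improve my sleep?",
    "history": [],          # (query, reply) pairs from earlier turns
    "max_length": 2048,
    "top_p": 0.7,
    "temperature": 0.95,
}
resp = requests.post("http://localhost:8000/", json=payload)
data = resp.json()
print(data["response"])     # model reply
history = data["history"]   # pass back on the next request to keep context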