WhotookNima committed (verified)
Commit 84d64bf · 1 Parent(s): d7bd9f2

Update app.py

Files changed (1)
  1. app.py +27 -19
app.py CHANGED
@@ -1,27 +1,35 @@
- from fastapi import FastAPI, Request
  from pydantic import BaseModel
- from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch

  app = FastAPI()

- # Load the model
- model_id = "AI-Sweden/gpt-sw3-126m"
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(model_id)

- # If you are running on CPU, add this
- device = torch.device("cpu")
- model.to(device)

- # Input model
- class Prompt(BaseModel):
-     text: str
-     max_new_tokens: int = 50

- @app.post("/generate")
- async def generate_text(prompt: Prompt):
-     inputs = tokenizer(prompt.text, return_tensors="pt").to(device)
-     outputs = model.generate(**inputs, max_new_tokens=prompt.max_new_tokens)
-     generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     return {"response": generated}
+ from fastapi import FastAPI
  from pydantic import BaseModel
  import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # Initialize the model and tokenizer
+ model_name = "AI-Sweden-Models/gpt-sw3-126m-instruct"
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"

+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+ model.to(device)
+ model.eval()
+
+ # The FastAPI application
  app = FastAPI()

+ class UserInput(BaseModel):
+     prompt: str

+ @app.post("/generate/")
+ async def generate_response(user_input: UserInput):
+     prompt = f"<|endoftext|><s>\nUser:\n{user_input.prompt}\n<s>\nBot:"
+     input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)

+     generated_token_ids = model.generate(
+         inputs=input_ids,
+         max_new_tokens=100,
+         do_sample=True,
+         temperature=0.6,
+         top_p=1
+     )[0]

+     generated_text = tokenizer.decode(generated_token_ids[len(input_ids[0]):-1], skip_special_tokens=True)
+     return {"response": generated_text.strip()}
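Since the updated file exposes a single POST endpoint, a quick way to sanity-check the change is FastAPI's built-in test client. The snippet below is a minimal sketch, not part of the commit: it assumes the new file is saved as app.py in the working directory, that httpx is installed (required by TestClient), and uses an arbitrary example prompt.

```python
# Minimal sketch: exercise the /generate/ endpoint without starting a server.
# Assumes the updated file is saved as app.py; importing it also downloads and
# loads the model, so the first run takes a while.
from fastapi.testclient import TestClient

from app import app

client = TestClient(app)

resp = client.post("/generate/", json={"prompt": "Vad heter huvudstaden i Sverige?"})
print(resp.status_code)  # 200 on success
print(resp.json())       # {"response": "..."}
```

For actually serving the app, something like `uvicorn app:app --host 0.0.0.0 --port 7860` would work; the port is an assumption (7860 is the usual Spaces convention) and not specified by this commit.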