# NOTE(review): header originally contained Hugging Face Spaces UI residue
# ("Spaces:" / "Runtime error" x2) — the Space this was copied from was failing.
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Initialize the model and tokenizer.
model_name = "AI-Sweden-Models/gpt-sw3-126m-instruct"
# Prefer the first GPU when available; CPU generation works but is slower.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.to(device)
model.eval()  # inference mode: disables dropout and similar training-only layers
# The FastAPI application.
app = FastAPI()
class UserInput(BaseModel):
    """Request body for the generation endpoint.

    Attributes are validated by pydantic.
    """

    # Free-text user message forwarded to the chat model.
    prompt: str
@app.post("/generate")  # BUG FIX: the handler was never registered as a route
async def generate_response(user_input: UserInput):
    """Generate a chat reply for the user's prompt.

    Wraps the prompt in the gpt-sw3 instruct chat template, samples up to
    100 new tokens, and returns the bot reply as ``{"response": <text>}``.
    """
    # gpt-sw3 instruct chat template: <|endoftext|><s>\nUser:\n...\n<s>\nBot:
    prompt = f"<|endoftext|><s>\nUser:\n{user_input.prompt}\n<s>\nBot:"
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
    # Inference only — no_grad avoids building autograd graphs.
    with torch.no_grad():
        generated_token_ids = model.generate(
            inputs=input_ids,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.6,
            top_p=1,
        )[0]
    # Decode only the newly generated tokens; the final token is sliced off
    # (presumably the trailing <s>/eos delimiter — matches the model card's
    # example, but may clip a real token when max_new_tokens is hit; verify).
    generated_text = tokenizer.decode(
        generated_token_ids[len(input_ids[0]):-1], skip_special_tokens=True
    )
    return {"response": generated_text.strip()}