sparkleman
INIT
109a0c8
raw
history blame
18.3 kB
import os, copy, types, gc, sys, re, time, collections, asyncio
from huggingface_hub import hf_hub_download
from loguru import logger
from snowflake import SnowflakeGenerator
CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
from pynvml import *
nvmlInit()
gpu_h = nvmlDeviceGetHandleByIndex(0)
from typing import List, Optional, Union
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
class Config(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups=True):
HOST: str = Field("127.0.0.1", description="Host")
PORT: int = Field(8000, description="Port")
DEBUG: bool = Field(False, description="Debug mode")
STRATEGY: str = Field("cpu", description="Stratergy")
MODEL_TITLE: str = Field("RWKV-x070-World-0.1B-v2.8-20241210-ctx4096")
DOWNLOAD_REPO_ID: str = Field("BlinkDL/rwkv-7-world")
DOWNLOAD_MODEL_DIR: Union[str, None] = Field(None, description="Model Download Dir")
MODEL_FILE_PATH: Union[str, None] = Field(None, description="Model Path")
GEN_penalty_decay: float = Field(0.996, description="Default penalty decay")
CHUNK_LEN: int = Field(
256,
description="split input into chunks to save VRAM (shorter -> slower, but saves VRAM)",
)
VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")
CONFIG = Config()
import numpy as np
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
os.environ["RWKV_V7_ON"] = "1" # enable this for rwkv-7 models
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = (
"0" # !!! '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries !!!
)
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from api_types import (
ChatMessage,
ChatCompletion,
ChatCompletionChunk,
Usage,
PromptTokensDetails,
ChatCompletionChoice,
ChatCompletionMessage,
)
from utils import cleanMessages, parse_think_response
logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
if CONFIG.MODEL_FILE_PATH == None:
CONFIG.MODEL_FILE_PATH = hf_hub_download(
repo_id=CONFIG.DOWNLOAD_REPO_ID,
filename=f"{CONFIG.MODEL_TITLE}.pth",
local_dir=CONFIG.DOWNLOAD_MODEL_DIR,
)
logger.info(f"Load Model - {CONFIG.MODEL_FILE_PATH}")
model = RWKV(model=CONFIG.MODEL_FILE_PATH.replace(".pth", ""), strategy=CONFIG.STRATEGY)
pipeline = PIPELINE(model, CONFIG.VOCAB)
class ChatCompletionRequest(BaseModel):
model: str = Field(
default="rwkv-latest",
description="Add `:thinking` suffix to the model name to enable reasoning. Example: `rwkv-latest:thinking`",
)
messages: List[ChatMessage]
prompt: Union[str, None] = Field(default=None)
max_tokens: int = Field(default=512)
temperature: float = Field(default=1.0)
top_p: float = Field(default=0.3)
presencePenalty: float = Field(default=0.5)
countPenalty: float = Field(default=0.5)
stream: bool = Field(default=False)
state_name: str = Field(default=None)
include_usage: bool = Field(default=False)
app = FastAPI(title="RWKV OpenAI-Compatible API")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
def runPrefill(ctx: str, model_tokens: List[int], model_state):
ctx = ctx.replace("\r\n", "\n")
tokens = pipeline.encode(ctx)
tokens = [int(x) for x in tokens]
model_tokens += tokens
while len(tokens) > 0:
out, model_state = model.forward(tokens[: CONFIG.CHUNK_LEN], model_state)
tokens = tokens[CONFIG.CHUNK_LEN :]
return out, model_tokens, model_state
def generate(
request: ChatCompletionRequest,
out,
model_tokens,
model_state,
stops=["\n\n"],
max_tokens=2048,
):
args = PIPELINE_ARGS(
temperature=max(0.2, request.temperature),
top_p=request.top_p,
alpha_frequency=request.countPenalty,
alpha_presence=request.presencePenalty,
token_ban=[], # ban the generation of some tokens
token_stop=[0],
) # stop generation whenever you see any token here
occurrence = {}
out_tokens = []
out_last = 0
output_cache = collections.deque(maxlen=5)
for i in range(max_tokens):
for n in occurrence:
out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
out[0] -= 1e10 # disable END_OF_TEXT
token = pipeline.sample_logits(
out, temperature=args.temperature, top_p=args.top_p
)
out, model_state = model.forward([token], model_state)
model_tokens += [token]
out_tokens += [token]
for xxx in occurrence:
occurrence[xxx] *= CONFIG.GEN_penalty_decay
occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
tmp: str = pipeline.decode(out_tokens[out_last:])
if "\ufffd" in tmp:
continue
output_cache.append(tmp)
output_cache_str = "".join(output_cache)
for stop_words in stops:
if stop_words in output_cache_str:
yield {
"content": tmp.replace(stop_words, ""),
"tokens": out_tokens[out_last:],
"finish_reason": "stop",
"state": model_state,
}
del out
gc.collect()
return
yield {
"content": tmp,
"tokens": out_tokens[out_last:],
"finish_reason": None,
}
out_last = i + 1
else:
yield {
"content": "",
"tokens": [],
"finish_reason": "length",
}
async def chatResponse(
request: ChatCompletionRequest, model_state: any, completionId: str
) -> ChatCompletion:
createTimestamp = time.time()
enableReasoning = request.model.endswith(":thinking")
prompt = (
f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
if request.prompt == None
else request.prompt.strip()
)
out, model_tokens, model_state = runPrefill(prompt, [], model_state)
prefillTime = time.time()
promptTokenCount = len(model_tokens)
fullResponse = " <think" if enableReasoning else ""
completionTokenCount = 0
finishReason = None
for chunk in generate(
request,
out,
model_tokens,
model_state,
max_tokens=(
64000
if "max_tokens" not in request.model_fields_set and enableReasoning
else request.max_tokens
),
):
fullResponse += chunk["content"]
completionTokenCount += 1
if chunk["finish_reason"]:
finishReason = chunk["finish_reason"]
await asyncio.sleep(0)
genenrateTime = time.time()
responseLog = {
"content": fullResponse,
"finish": finishReason,
"prefill_len": promptTokenCount,
"prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
"gen_len": completionTokenCount,
"gen_tps": round(completionTokenCount / (genenrateTime - prefillTime), 2),
}
logger.info(f"[RES] {completionId} - {responseLog}")
reasoning_content, content = parse_think_response(fullResponse)
response = ChatCompletion(
id=completionId,
created=int(createTimestamp),
model=request.model,
usage=Usage(
prompt_tokens=promptTokenCount,
completion_tokens=completionTokenCount,
total_tokens=promptTokenCount + completionTokenCount,
prompt_tokens_details={"cached_tokens": 0},
),
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="Assistant",
content=content,
reasoning_content=reasoning_content if reasoning_content else None,
),
logprobs=None,
finish_reason=finishReason,
)
],
)
return response
async def chatResponseStream(
request: ChatCompletionRequest, model_state: any, completionId: str
):
createTimestamp = int(time.time())
enableReasoning = request.model.endswith(":thinking")
prompt = (
f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
if request.prompt == None
else request.prompt.strip()
)
out, model_tokens, model_state = runPrefill(prompt, [], model_state)
prefillTime = time.time()
promptTokenCount = len(model_tokens)
completionTokenCount = 0
finishReason = None
response = ChatCompletionChunk(
id=completionId,
created=createTimestamp,
model=request.model,
usage=(
Usage(
prompt_tokens=promptTokenCount,
completion_tokens=completionTokenCount,
total_tokens=promptTokenCount + completionTokenCount,
prompt_tokens_details={"cached_tokens": 0},
)
if request.include_usage
else None
),
choices=[
ChatCompletionChoice(
index=0,
delta=ChatCompletionMessage(
role="Assistant",
content="",
reasoning_content="" if enableReasoning else None,
),
logprobs=None,
finish_reason=finishReason,
)
],
)
yield f"data: {response.model_dump_json()}\n\n"
buffer = []
if enableReasoning:
buffer.append(" <think")
streamConfig = {
"isChecking": False,
"fullTextCursor": 0,
"in_think": False,
"cacheStr": "",
}
for chunk in generate(
request,
out,
model_tokens,
model_state,
max_tokens=(
64000
if "max_tokens" not in request.model_fields_set and enableReasoning
else request.max_tokens
),
):
completionTokenCount += 1
chunkContent: str = chunk["content"]
buffer.append(chunkContent)
fullText = "".join(buffer)
if chunk["finish_reason"]:
finishReason = chunk["finish_reason"]
response = ChatCompletionChunk(
id=completionId,
created=createTimestamp,
model=request.model,
usage=(
Usage(
prompt_tokens=promptTokenCount,
completion_tokens=completionTokenCount,
total_tokens=promptTokenCount + completionTokenCount,
prompt_tokens_details={"cached_tokens": 0},
)
if request.include_usage
else None
),
choices=[
ChatCompletionChoice(
index=0,
delta=ChatCompletionMessage(
content=None, reasoning_content=None
),
logprobs=None,
finish_reason=finishReason,
)
],
)
markStart = fullText.find("<", streamConfig["fullTextCursor"])
if not streamConfig["isChecking"] and markStart != -1:
streamConfig["isChecking"] = True
if streamConfig["in_think"]:
response.choices[0].delta.reasoning_content = fullText[
streamConfig["fullTextCursor"] : markStart
]
else:
response.choices[0].delta.content = fullText[
streamConfig["fullTextCursor"] : markStart
]
streamConfig["cacheStr"] = ""
streamConfig["fullTextCursor"] = markStart
if streamConfig["isChecking"]:
streamConfig["cacheStr"] = fullText[streamConfig["fullTextCursor"] :]
else:
if streamConfig["in_think"]:
response.choices[0].delta.reasoning_content = chunkContent
else:
response.choices[0].delta.content = chunkContent
streamConfig["fullTextCursor"] = len(fullText)
markEnd = fullText.find(">", streamConfig["fullTextCursor"])
if streamConfig["isChecking"] and markEnd != -1:
streamConfig["isChecking"] = False
if (
not streamConfig["in_think"]
and streamConfig["cacheStr"].find("<think>") != -1
):
streamConfig["in_think"] = True
response.choices[0].delta.reasoning_content = (
response.choices[0].delta.reasoning_content
if response.choices[0].delta.reasoning_content != None
else "" + streamConfig["cacheStr"].replace("<think>", "")
)
elif (
streamConfig["in_think"]
and streamConfig["cacheStr"].find("</think>") != -1
):
streamConfig["in_think"] = False
response.choices[0].delta.content = (
response.choices[0].delta.content
if response.choices[0].delta.content != None
else "" + streamConfig["cacheStr"].replace("</think>", "")
)
else:
if streamConfig["in_think"]:
response.choices[0].delta.reasoning_content = (
response.choices[0].delta.reasoning_content
if response.choices[0].delta.reasoning_content != None
else "" + streamConfig["cacheStr"]
)
else:
response.choices[0].delta.content = (
response.choices[0].delta.content
if response.choices[0].delta.content != None
else "" + streamConfig["cacheStr"]
)
streamConfig["fullTextCursor"] = len(fullText)
if (
response.choices[0].delta.content != None
or response.choices[0].delta.reasoning_content != None
):
yield f"data: {response.model_dump_json()}\n\n"
await asyncio.sleep(0)
del streamConfig
else:
for chunk in generate(request, out, model_tokens, model_state):
completionTokenCount += 1
buffer.append(chunk["content"])
if chunk["finish_reason"]:
finishReason = chunk["finish_reason"]
response = ChatCompletionChunk(
id=completionId,
created=createTimestamp,
model=request.model,
usage=(
Usage(
prompt_tokens=promptTokenCount,
completion_tokens=completionTokenCount,
total_tokens=promptTokenCount + completionTokenCount,
prompt_tokens_details={"cached_tokens": 0},
)
if request.include_usage
else None
),
choices=[
ChatCompletionChoice(
index=0,
delta=ChatCompletionMessage(content=chunk["content"]),
logprobs=None,
finish_reason=finishReason,
)
],
)
yield f"data: {response.model_dump_json()}\n\n"
await asyncio.sleep(0)
genenrateTime = time.time()
responseLog = {
"content": "".join(buffer),
"finish": finishReason,
"prefill_len": promptTokenCount,
"prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
"gen_len": completionTokenCount,
"gen_tps": round(completionTokenCount / (genenrateTime - prefillTime), 2),
}
logger.info(f"[RES] {completionId} - {responseLog}")
del buffer
yield "data: [DONE]\n\n"
@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
completionId = str(next(CompletionIdGenerator))
logger.info(f"[REQ] {completionId} - {request.model_dump()}")
def chatResponseStreamDisconnect():
gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
logger.info(
f"[STATUS] vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}"
)
model_state = None
if request.stream:
r = StreamingResponse(
chatResponseStream(request, model_state, completionId),
media_type="text/event-stream",
background=chatResponseStreamDisconnect,
)
else:
r = await chatResponse(request, model_state, completionId)
return r
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=CONFIG.HOST, port=CONFIG.PORT)