import os, copy, types, gc, sys, re, time, collections, asyncio

from huggingface_hub import hf_hub_download
from loguru import logger
from snowflake import SnowflakeGenerator

# snowflake-id generator: time-sortable unique IDs for each completion
CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)

from pynvml import *

# NVML handle for the first GPU, used to log VRAM usage after each request
nvmlInit()
gpu_h = nvmlDeviceGetHandleByIndex(0)

from typing import List, Optional, Union
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings


class Config(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups=True):
    HOST: str = Field("127.0.0.1", description="Host")
    PORT: int = Field(8000, description="Port")
    DEBUG: bool = Field(False, description="Debug mode")
    STRATEGY: str = Field("cpu", description="Strategy")
    MODEL_TITLE: str = Field("RWKV-x070-World-0.1B-v2.8-20241210-ctx4096")
    DOWNLOAD_REPO_ID: str = Field("BlinkDL/rwkv-7-world")
    DOWNLOAD_MODEL_DIR: Union[str, None] = Field(None, description="Model Download Dir")
    MODEL_FILE_PATH: Union[str, None] = Field(None, description="Model Path")
    GEN_penalty_decay: float = Field(0.996, description="Default penalty decay")
    CHUNK_LEN: int = Field(
        256,
        description="split input into chunks to save VRAM (shorter -> slower, but saves VRAM)",
    )
    VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")


CONFIG = Config()

import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

os.environ["RWKV_V7_ON"] = "1"  # enable this for rwkv-7 models
os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_CUDA_ON"] = (
    "0"  # !!! '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries !!!
)

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware

from api_types import (
    ChatMessage,
    ChatCompletion,
    ChatCompletionChunk,
    Usage,
    PromptTokensDetails,
    ChatCompletionChoice,
    ChatCompletionMessage,
)
from utils import cleanMessages, parse_think_response

logger.info(f"STRATEGY - {CONFIG.STRATEGY}")

# Download the model from Hugging Face unless an explicit local path was given
if CONFIG.MODEL_FILE_PATH == None:
    CONFIG.MODEL_FILE_PATH = hf_hub_download(
        repo_id=CONFIG.DOWNLOAD_REPO_ID,
        filename=f"{CONFIG.MODEL_TITLE}.pth",
        local_dir=CONFIG.DOWNLOAD_MODEL_DIR,
    )

logger.info(f"Load Model - {CONFIG.MODEL_FILE_PATH}")

model = RWKV(model=CONFIG.MODEL_FILE_PATH.replace(".pth", ""), strategy=CONFIG.STRATEGY)
pipeline = PIPELINE(model, CONFIG.VOCAB)
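
# Launch sketch (assumed): with cli_parse_args=True, pydantic-settings exposes each
# Config field above as a command-line option of the same name, e.g.
#
#   python app.py --STRATEGY "cuda fp16" --PORT 8000
#
# "app.py" is a placeholder for whatever this file is saved as; "cpu" and
# "cuda fp16" are the usual rwkv-pip strategy strings.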


class ChatCompletionRequest(BaseModel):
    model: str = Field(
        default="rwkv-latest",
        description="Add `:thinking` suffix to the model name to enable reasoning. "
        "Example: `rwkv-latest:thinking`",
    )
    messages: List[ChatMessage]
    prompt: Union[str, None] = Field(default=None)
    max_tokens: int = Field(default=512)
    temperature: float = Field(default=1.0)
    top_p: float = Field(default=0.3)
    presencePenalty: float = Field(default=0.5)
    countPenalty: float = Field(default=0.5)
    stream: bool = Field(default=False)
    state_name: Union[str, None] = Field(default=None)
    include_usage: bool = Field(default=False)


app = FastAPI(title="RWKV OpenAI-Compatible API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def runPrefill(ctx: str, model_tokens: List[int], model_state):
    ctx = ctx.replace("\r\n", "\n")

    tokens = pipeline.encode(ctx)
    tokens = [int(x) for x in tokens]
    model_tokens += tokens

    # feed the prompt in CHUNK_LEN slices to bound peak VRAM usage
    while len(tokens) > 0:
        out, model_state = model.forward(tokens[: CONFIG.CHUNK_LEN], model_state)
        tokens = tokens[CONFIG.CHUNK_LEN :]

    return out, model_tokens, model_state


def generate(
    request: ChatCompletionRequest,
    out,
    model_tokens,
    model_state,
    stops=["\n\n"],
    max_tokens=2048,
):
    args = PIPELINE_ARGS(
        temperature=max(0.2, request.temperature),
        top_p=request.top_p,
        alpha_frequency=request.countPenalty,
        alpha_presence=request.presencePenalty,
        token_ban=[],  # ban the generation of some tokens
        token_stop=[0],
    )  # stop generation whenever you see any token here

    occurrence = {}
    out_tokens = []
    out_last = 0

    # keep the last few decoded pieces so multi-token stop words can be detected
    output_cache = collections.deque(maxlen=5)

    for i in range(max_tokens):
        # apply presence/frequency penalties to already-seen tokens
        for n in occurrence:
            out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
        out[0] -= 1e10  # disable END_OF_TEXT

        token = pipeline.sample_logits(
            out, temperature=args.temperature, top_p=args.top_p
        )

        out, model_state = model.forward([token], model_state)
        model_tokens += [token]

        out_tokens += [token]

        for xxx in occurrence:
            occurrence[xxx] *= CONFIG.GEN_penalty_decay
        occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)

        tmp: str = pipeline.decode(out_tokens[out_last:])
        if "\ufffd" in tmp:
            # incomplete UTF-8 sequence, wait for more tokens
            continue

        output_cache.append(tmp)
        output_cache_str = "".join(output_cache)
        for stop_words in stops:
            if stop_words in output_cache_str:
                yield {
                    "content": tmp.replace(stop_words, ""),
                    "tokens": out_tokens[out_last:],
                    "finish_reason": "stop",
                    "state": model_state,
                }

                del out
                gc.collect()
                return

        yield {
            "content": tmp,
            "tokens": out_tokens[out_last:],
            "finish_reason": None,
        }
        out_last = i + 1
    else:
        # for-else: the token budget ran out before any stop word appeared
        yield {
            "content": "",
            "tokens": [],
            "finish_reason": "length",
        }


async def chatResponse(
    request: ChatCompletionRequest, model_state: any, completionId: str
) -> ChatCompletion:
    createTimestamp = time.time()

    enableReasoning = request.model.endswith(":thinking")

    # Prime the model with " <think" when reasoning is requested so it opens a
    # <think>...</think> block
    prompt = (
        f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
        if request.prompt == None
        else request.prompt.strip()
    )

    out, model_tokens, model_state = runPrefill(prompt, [], model_state)
    prefillTime = time.time()
    promptTokenCount = len(model_tokens)

    fullResponse = " <think" if enableReasoning else ""
    completionTokenCount = 0
    finishReason = None

    for chunk in generate(
        request, out, model_tokens, model_state, max_tokens=request.max_tokens
    ):
        fullResponse += chunk["content"]
        completionTokenCount += 1
        if chunk["finish_reason"]:
            finishReason = chunk["finish_reason"]

    generateTime = time.time()

    responseLog = {
        "content": fullResponse,
        "finish": finishReason,
        "prefill_len": promptTokenCount,
        "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
        "gen_len": completionTokenCount,
        "gen_tps": round(completionTokenCount / (generateTime - prefillTime), 2),
    }
    logger.info(f"[RES] {completionId} - {responseLog}")

    reasoning_content = None
    content = fullResponse
    if enableReasoning:
        # parse_think_response is assumed to split the <think>...</think> block
        # from the visible answer
        reasoning_content, content = parse_think_response(fullResponse)

    # NOTE: the ChatCompletion field names here are assumed to mirror the
    # ChatCompletionChunk fields used in chatResponseStream below
    return ChatCompletion(
        id=completionId,
        created=int(createTimestamp),
        model=request.model,
        usage=Usage(
            prompt_tokens=promptTokenCount,
            completion_tokens=completionTokenCount,
            total_tokens=promptTokenCount + completionTokenCount,
            prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
        ),
        choices=[
            ChatCompletionChoice(
                index=0,
                message=ChatCompletionMessage(
                    content=content, reasoning_content=reasoning_content
                ),
                logprobs=None,
                finish_reason=finishReason,
            )
        ],
    )


async def chatResponseStream(
    request: ChatCompletionRequest, model_state: any, completionId: str
):
    createTimestamp = int(time.time())

    enableReasoning = request.model.endswith(":thinking")

    prompt = (
        f"{cleanMessages(request.messages)}\n\nAssistant:{' <think' if enableReasoning else ''}"
        if request.prompt == None
        else request.prompt.strip()
    )

    out, model_tokens, model_state = runPrefill(prompt, [], model_state)
    prefillTime = time.time()
    promptTokenCount = len(model_tokens)

    completionTokenCount = 0
    finishReason = None
    buffer = []

    if enableReasoning:
        # streamConfig drives the <think> tag scanner: plain text is routed
        # straight to the delta, while anything after a "<" is buffered in
        # cacheStr until the matching ">" arrives
        streamConfig = {
            "isChecking": False,  # currently buffering a candidate tag
            "in_think": False,  # inside the <think> block
            "cacheStr": "",  # text buffered since the last "<"
            "fullTextCursor": 0,  # position in fullText already emitted
        }
        # fullText starts with the " <think" appended to the prompt, so the
        # scanner also sees the opening tag
        fullText = " <think"
        buffer.append(fullText)

        for chunk in generate(request, out, model_tokens, model_state):
            completionTokenCount += 1
            fullText += chunk["content"]
            buffer.append(chunk["content"])
            if chunk["finish_reason"]:
                finishReason = chunk["finish_reason"]

            response = ChatCompletionChunk(
                id=completionId,
                created=createTimestamp,
                model=request.model,
                usage=(
                    Usage(
                        prompt_tokens=promptTokenCount,
                        completion_tokens=completionTokenCount,
                        total_tokens=promptTokenCount + completionTokenCount,
                        prompt_tokens_details={"cached_tokens": 0},
                    )
                    if request.include_usage
                    else None
                ),
                choices=[
                    ChatCompletionChoice(
                        index=0,
                        delta=ChatCompletionMessage(
                            content=None, reasoning_content=None
                        ),
                        logprobs=None,
                        finish_reason=finishReason,
                    )
                ],
            )

            if streamConfig["isChecking"]:
                # a candidate tag is open: keep buffering instead of emitting
                streamConfig["cacheStr"] += chunk["content"]

            # NOTE: the pre-scan bookkeeping below is an assumption, rebuilt
            # around the tag handling that follows it
            markStart = fullText.find("<", streamConfig["fullTextCursor"])
            if not streamConfig["isChecking"]:
                if markStart != -1:
                    # a possible tag starts here: emit what precedes it, buffer the rest
                    streamConfig["isChecking"] = True
                    pending = fullText[streamConfig["fullTextCursor"] : markStart]
                    streamConfig["cacheStr"] = fullText[markStart:]
                    streamConfig["fullTextCursor"] = markStart
                else:
                    pending = fullText[streamConfig["fullTextCursor"] :]
                    streamConfig["fullTextCursor"] = len(fullText)
                if pending:
                    if streamConfig["in_think"]:
                        response.choices[0].delta.reasoning_content = pending
                    else:
                        response.choices[0].delta.content = pending

            markEnd = fullText.find(">", streamConfig["fullTextCursor"])
            if streamConfig["isChecking"] and markEnd != -1:
                streamConfig["isChecking"] = False
                if (
                    not streamConfig["in_think"]
                    and streamConfig["cacheStr"].find("<think>") != -1
                ):
                    streamConfig["in_think"] = True
                    response.choices[0].delta.reasoning_content = (
                        response.choices[0].delta.reasoning_content
                        if response.choices[0].delta.reasoning_content != None
                        else ""
                    ) + streamConfig["cacheStr"].replace("<think>", "")
                elif (
                    streamConfig["in_think"]
                    and streamConfig["cacheStr"].find("</think>") != -1
                ):
                    streamConfig["in_think"] = False
                    response.choices[0].delta.content = (
                        response.choices[0].delta.content
                        if response.choices[0].delta.content != None
                        else ""
                    ) + streamConfig["cacheStr"].replace("</think>", "")
                else:
                    if streamConfig["in_think"]:
                        response.choices[0].delta.reasoning_content = (
                            response.choices[0].delta.reasoning_content
                            if response.choices[0].delta.reasoning_content != None
                            else ""
                        ) + streamConfig["cacheStr"]
                    else:
                        response.choices[0].delta.content = (
                            response.choices[0].delta.content
                            if response.choices[0].delta.content != None
                            else ""
                        ) + streamConfig["cacheStr"]
                streamConfig["fullTextCursor"] = len(fullText)

            if (
                response.choices[0].delta.content != None
                or response.choices[0].delta.reasoning_content != None
            ):
                yield f"data: {response.model_dump_json()}\n\n"
                await asyncio.sleep(0)

        del streamConfig
    else:
        for chunk in generate(request, out, model_tokens, model_state):
            completionTokenCount += 1
            buffer.append(chunk["content"])
            if chunk["finish_reason"]:
                finishReason = chunk["finish_reason"]

            response = ChatCompletionChunk(
                id=completionId,
                created=createTimestamp,
                model=request.model,
                usage=(
                    Usage(
                        prompt_tokens=promptTokenCount,
                        completion_tokens=completionTokenCount,
                        total_tokens=promptTokenCount + completionTokenCount,
                        prompt_tokens_details={"cached_tokens": 0},
                    )
                    if request.include_usage
                    else None
                ),
                choices=[
                    ChatCompletionChoice(
                        index=0,
                        delta=ChatCompletionMessage(content=chunk["content"]),
                        logprobs=None,
                        finish_reason=finishReason,
                    )
                ],
            )
            yield f"data: {response.model_dump_json()}\n\n"
            await asyncio.sleep(0)

    generateTime = time.time()

    responseLog = {
        "content": "".join(buffer),
        "finish": finishReason,
        "prefill_len": promptTokenCount,
        "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
        "gen_len": completionTokenCount,
        "gen_tps": round(completionTokenCount / (generateTime - prefillTime), 2),
    }
    logger.info(f"[RES] {completionId} - {responseLog}")

    del buffer
    yield "data: [DONE]\n\n"


@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    completionId = str(next(CompletionIdGenerator))
    logger.info(f"[REQ] {completionId} - {request.model_dump()}")

    # async so Starlette can await it as the response's background callback
    async def chatResponseStreamDisconnect():
        gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
        logger.info(
            f"[STATUS] vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}"
        )

    model_state = None
    if request.stream:
        r = StreamingResponse(
            chatResponseStream(request, model_state, completionId),
            media_type="text/event-stream",
            background=chatResponseStreamDisconnect,
        )
    else:
        r = await chatResponse(request, model_state, completionId)

    return r


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=CONFIG.HOST, port=CONFIG.PORT)
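
# Request sketch (assumed client-side usage; the exact ChatMessage fields come from
# the local api_types module, here taken to be OpenAI-style "role"/"content"):
#
#   curl http://127.0.0.1:8000/api/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{
#           "model": "rwkv-latest:thinking",
#           "messages": [{"role": "user", "content": "Why is the sky blue?"}],
#           "stream": true,
#           "include_usage": true
#         }'
#
# With "stream": true the endpoint answers as a text/event-stream of
# ChatCompletionChunk JSON lines terminated by "data: [DONE]"; without it a
# single ChatCompletion object is returned. Dropping the ":thinking" suffix
# disables the <think> reasoning block.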