Commit 9b9e15b · sparkleman committed · 1 parent: d761c6a

UPDATE: support stop tokens

Files changed:
- Dockerfile +2 -1
- app.py +26 -6
- config.production.yaml +4 -0
- config.py +3 -2
- pyproject.toml +1 -0
- uv.lock +2 -0
Dockerfile
CHANGED
@@ -15,7 +15,8 @@ RUN ["cargo", "install", "wasm-pack"]
 WORKDIR /app
 ENV PATH=/root/.cargo/bin:$PATH
 RUN npm install -g pnpm
 RUN pnpm install
+RUN pnpm run build --mode target-rwkv-hf-space
 
 FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04 AS Backend
 
app.py
CHANGED
@@ -3,6 +3,7 @@ from config import CONFIG, ModelConfig
 import os, copy, types, gc, sys, re, time, collections, asyncio
 from huggingface_hub import hf_hub_download
 from loguru import logger
+from rich import print
 
 from snowflake import SnowflakeGenerator
 
@@ -92,6 +93,8 @@ for model_config in CONFIG.MODELS:
     else:
         DEFALUT_MODEL_NAME = model_config.SERVICE_NAME
 
+    print(model_config.DEFAULT_SAMPLER)
+
     MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
     MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
     MODEL_STORAGE[model_config.SERVICE_NAME].model = tmp_model
@@ -119,6 +122,7 @@ class ChatCompletionRequest(BaseModel):
     state_name: Optional[str] = Field(default=None)
     include_usage: Optional[bool] = Field(default=False)
     stop: Optional[list[str]] = Field(["\n\n"])
+    stop_tokens: Optional[list[int]] = Field([0])
 
     @model_validator(mode="before")
     @classmethod
@@ -169,7 +173,7 @@ async def runPrefill(
 def generate(
     request: ChatCompletionRequest,
     out,
-    model_tokens,
+    model_tokens: List[int],
     model_state,
     stops=["\n\n"],
     max_tokens=2048,
@@ -184,7 +188,7 @@ def generate(
     )  # stop generation whenever you see any token here
 
     occurrence = {}
-    out_tokens = []
+    out_tokens: List[int] = []
     out_last = 0
 
     output_cache = collections.deque(maxlen=5)
@@ -192,7 +196,7 @@ def generate(
     for i in range(max_tokens):
         for n in occurrence:
             out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
-        out[0] -= 1e10  # disable END_OF_TEXT
+        # out[0] -= 1e10  # disable END_OF_TEXT
 
         token = MODEL_STORAGE[request.model].pipeline.sample_logits(
             out, temperature=args.temperature, top_p=args.top_p
@@ -201,9 +205,21 @@ def generate(
         out, model_state = MODEL_STORAGE[request.model].model.forward(
             [token], model_state
         )
-        model_tokens += [token]
+        model_tokens.append(token)
+
+        out_tokens.append(token)
+
+        if token in request.stop_tokens:
+            yield {
+                "content": "",
+                "tokens": out_tokens[out_last:],
+                "finish_reason": "stop",
+                "state": model_state,
+            }
 
-        out_tokens += [token]
+            del out
+            gc.collect()
+            return
 
         for xxx in occurrence:
             occurrence[xxx] *= request.penalty_decay
@@ -260,6 +276,7 @@ async def chatResponse(
         if request.prompt == None
         else request.prompt.strip()
     )
+    logger.info(f"[REQ] {completionId} - prompt - {prompt}")
 
     out, model_tokens, model_state = await runPrefill(request, prompt, [], model_state)
 
@@ -343,6 +360,8 @@ async def chatResponseStream(
         else request.prompt.strip()
     )
 
+    # logger.info(f"[REQ] {completionId} - prompt - {prompt}")
+
     out, model_tokens, model_state = await runPrefill(request, prompt, [], model_state)
 
     prefillTime = time.time()
@@ -465,7 +484,7 @@ async def chatResponseStream(
                 streamConfig["fullTextCursor"] = len(fullText)
 
             markEnd = fullText.find(">", streamConfig["fullTextCursor"])
-            if streamConfig["isChecking"] and markEnd != -1:
+            if (streamConfig["isChecking"] and markEnd != -1) or finishReason != None:
                 streamConfig["isChecking"] = False
 
                 if (
@@ -626,6 +645,7 @@ async def chat_completions(request: ChatCompletionRequest):
 
     return r
 
+
 app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")
 
 if __name__ == "__main__":
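For anyone trying the change out: a minimal client sketch exercising the new stop_tokens field. The base URL and route path are assumptions (only the chat_completions handler name and the ChatCompletionRequest fields appear in this diff); the model name is the SERVICE_NAME from config.production.yaml below.

import requests

payload = {
    "model": "rwkv7-g1-0.1b-20250307-ctx4096",  # SERVICE_NAME from config.production.yaml
    "prompt": "User: What is RWKV?\n\nAssistant:",
    "stop": ["\n\n"],    # string stop sequences, unchanged behavior
    "stop_tokens": [0],  # new field: halt as soon as token id 0 is sampled
    "include_usage": True,
}

# Hypothetical URL and route; adjust to wherever the app is served.
r = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(r.json())

On a stop-token hit the server now yields a final chunk with finish_reason "stop" and an empty content string, then frees the logits buffer and returns.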
config.production.yaml
CHANGED
@@ -18,6 +18,8 @@ MODELS:
     penalty_decay: 0.996
     stop:
       - "\n\n"
+    stop_tokens:
+      - 0
 - SERVICE_NAME: "rwkv7-g1-0.1b-20250307-ctx4096"
   DOWNLOAD_MODEL_FILE_NAME: "rwkv7-g1-0.1b-20250307-ctx4096.pth"
   DOWNLOAD_MODEL_REPO_ID: "BlinkDL/rwkv7-g1"
@@ -32,3 +34,5 @@ MODELS:
     penalty_decay: 0.996
     stop:
       - "\n\n"
+    stop_tokens:
+      - 0
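Token id 0 is the END_OF_TEXT token that app.py previously banned outright (the now-commented-out out[0] -= 1e10). With this yaml default the model may emit it, and generation stops cleanly instead. A toy sketch of that control flow (hypothetical token stream, not the real sampler):

STOP_TOKENS = [0]  # the new yaml default

def consume(sampled_tokens):
    # Stand-in for the sampling loop in app.py's generate():
    # collect tokens until one of STOP_TOKENS appears.
    out_tokens = []
    for token in sampled_tokens:
        out_tokens.append(token)
        if token in STOP_TOKENS:
            return out_tokens, "stop"
    return out_tokens, "length"

print(consume([11, 42, 0, 99]))  # ([11, 42, 0], 'stop') -- the trailing 99 is never emitted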
config.py
CHANGED
@@ -23,8 +23,9 @@ class SamplerConfig(BaseModel):
     top_p: float = Field(0.3, description="Top-p sampling threshold.")
     presence_penalty: float = Field(0.5, description="Presence penalty.")
     count_penalty: float = Field(0.5, description="Count penalty.")
-    penalty_decay: float = Field(0.
-    stop: List[str] = Field(
+    penalty_decay: float = Field(0.996, description="Penalty decay factor.")
+    stop: List[str] = Field(["\n\n"], description="List of stop sequences.")
+    stop_tokens: List[int] = Field([0], description="List of stop tokens.")
 
 
 class ModelConfig(BaseModel):
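These are the per-model defaults that yaml values (and, upstream, request payloads) override. A minimal sketch of how the pydantic fields behave, using only the definitions from this hunk; the override values are made up:

from typing import List
from pydantic import BaseModel, Field

class SamplerConfig(BaseModel):
    penalty_decay: float = Field(0.996, description="Penalty decay factor.")
    stop: List[str] = Field(["\n\n"], description="List of stop sequences.")
    stop_tokens: List[int] = Field([0], description="List of stop tokens.")

print(SamplerConfig().stop_tokens)                     # [0] -- the class default
print(SamplerConfig(stop_tokens=[0, 261]).stop_tokens) # [0, 261] -- e.g. loaded from yaml

Pydantic copies the mutable list defaults per instance, so one model's config mutating its stop list cannot leak into another's.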
pyproject.toml
CHANGED
@@ -13,6 +13,7 @@ dependencies = [
     "pydantic>=2.10.6",
     "pydantic-settings>=2.8.1",
     "pynvml>=12.0.0",
+    "rich>=13.9.4",
     "rwkv==0.8.28",
     "setuptools>=75.8.2",
     "snowflake-id>=1.0.2",
uv.lock
CHANGED
@@ -944,6 +944,7 @@ dependencies = [
     { name = "pydantic" },
     { name = "pydantic-settings" },
     { name = "pynvml" },
+    { name = "rich" },
     { name = "rwkv" },
     { name = "setuptools" },
     { name = "snowflake-id" },
@@ -971,6 +972,7 @@ requires-dist = [
     { name = "pydantic", specifier = ">=2.10.6" },
     { name = "pydantic-settings", specifier = ">=2.8.1" },
     { name = "pynvml", specifier = ">=12.0.0" },
+    { name = "rich", specifier = ">=13.9.4" },
    { name = "rwkv", specifier = "==0.8.28" },
     { name = "setuptools", specifier = ">=75.8.2" },
     { name = "snowflake-id", specifier = ">=1.0.2" },