Commit 9b9e15b · sparkleman committed · 1 Parent(s): d761c6a

UPDATE: support stop tokens
Files changed (6)
  1. Dockerfile +2 -1
  2. app.py +26 -6
  3. config.production.yaml +4 -0
  4. config.py +3 -2
  5. pyproject.toml +1 -0
  6. uv.lock +2 -0
Dockerfile CHANGED
@@ -15,7 +15,8 @@ RUN ["cargo", "install", "wasm-pack"]
 WORKDIR /app
 ENV PATH=/root/.cargo/bin:$PATH
 RUN npm install -g pnpm
-RUN pnpm install && pnpm run build
+RUN pnpm install
+RUN pnpm run build --mode target-rwkv-hf-space
 
 FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04 AS Backend
 
app.py CHANGED
@@ -3,6 +3,7 @@ from config import CONFIG, ModelConfig
 import os, copy, types, gc, sys, re, time, collections, asyncio
 from huggingface_hub import hf_hub_download
 from loguru import logger
+from rich import print
 
 from snowflake import SnowflakeGenerator
 
@@ -92,6 +93,8 @@ for model_config in CONFIG.MODELS:
     else:
         DEFALUT_MODEL_NAME = model_config.SERVICE_NAME
 
+    print(model_config.DEFAULT_SAMPLER)
+
     MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
     MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
     MODEL_STORAGE[model_config.SERVICE_NAME].model = tmp_model
@@ -119,6 +122,7 @@ class ChatCompletionRequest(BaseModel):
     state_name: Optional[str] = Field(default=None)
     include_usage: Optional[bool] = Field(default=False)
     stop: Optional[list[str]] = Field(["\n\n"])
+    stop_tokens: Optional[list[int]] = Field([0])
 
     @model_validator(mode="before")
     @classmethod
@@ -169,7 +173,7 @@ async def runPrefill(
 def generate(
     request: ChatCompletionRequest,
     out,
-    model_tokens,
+    model_tokens: List[int],
     model_state,
     stops=["\n\n"],
     max_tokens=2048,
@@ -184,7 +188,7 @@ def generate(
     ) # stop generation whenever you see any token here
 
     occurrence = {}
-    out_tokens = []
+    out_tokens: List[int] = []
     out_last = 0
 
     output_cache = collections.deque(maxlen=5)
@@ -192,7 +196,7 @@
     for i in range(max_tokens):
         for n in occurrence:
             out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
-        out[0] -= 1e10 # disable END_OF_TEXT
+        # out[0] -= 1e10 # disable END_OF_TEXT
 
         token = MODEL_STORAGE[request.model].pipeline.sample_logits(
             out, temperature=args.temperature, top_p=args.top_p
@@ -201,9 +205,21 @@
         out, model_state = MODEL_STORAGE[request.model].model.forward(
             [token], model_state
         )
-        model_tokens += [token]
+        model_tokens.append(token)
+
+        out_tokens.append(token)
+
+        if token in request.stop_tokens:
+            yield {
+                "content": "",
+                "tokens": out_tokens[out_last:],
+                "finish_reason": "stop",
+                "state": model_state,
+            }
 
-        out_tokens += [token]
+            del out
+            gc.collect()
+            return
 
         for xxx in occurrence:
             occurrence[xxx] *= request.penalty_decay
@@ -260,6 +276,7 @@ async def chatResponse(
         if request.prompt == None
         else request.prompt.strip()
     )
+    logger.info(f"[REQ] {completionId} - prompt - {prompt}")
 
     out, model_tokens, model_state = await runPrefill(request, prompt, [], model_state)
 
@@ -343,6 +360,8 @@ async def chatResponseStream(
         else request.prompt.strip()
     )
 
+    # logger.info(f"[REQ] {completionId} - prompt - {prompt}")
+
     out, model_tokens, model_state = await runPrefill(request, prompt, [], model_state)
 
     prefillTime = time.time()
@@ -465,7 +484,7 @@ async def chatResponseStream(
                 streamConfig["fullTextCursor"] = len(fullText)
 
             markEnd = fullText.find(">", streamConfig["fullTextCursor"])
-            if streamConfig["isChecking"] and markEnd != -1:
+            if (streamConfig["isChecking"] and markEnd != -1) or finishReason != None:
                 streamConfig["isChecking"] = False
 
             if (
@@ -626,6 +645,7 @@ async def chat_completions(request: ChatCompletionRequest):
 
     return r
 
+
 app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")
 
 if __name__ == "__main__":
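For readers skimming the diff, the core of the app.py change is the early-exit path added to `generate()`: as soon as a sampled token appears in `request.stop_tokens`, the generator yields a final chunk with `finish_reason` set to `"stop"` and returns. Below is a minimal, self-contained sketch of that control flow; the toy sampler and chunk shape are illustrative stand-ins, not repo code.

```python
import random
from typing import Iterator, List


def generate_sketch(stop_tokens: List[int], max_tokens: int = 8) -> Iterator[dict]:
    """Toy stand-in for generate(): exit as soon as a stop token is sampled."""
    out_tokens: List[int] = []
    for _ in range(max_tokens):
        token = random.randint(0, 4)  # stands in for pipeline.sample_logits(...)
        out_tokens.append(token)
        if token in stop_tokens:
            # Mirrors the new branch: empty content, finish_reason "stop", then return.
            yield {"content": "", "tokens": out_tokens, "finish_reason": "stop"}
            return
        yield {"content": f"<tok {token}>", "tokens": [token], "finish_reason": None}


if __name__ == "__main__":
    # Token id 0 plays the END_OF_TEXT role, matching the default stop_tokens=[0].
    for chunk in generate_sketch(stop_tokens=[0]):
        print(chunk)
```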
config.production.yaml CHANGED
@@ -18,6 +18,8 @@ MODELS:
       penalty_decay: 0.996
       stop:
         - "\n\n"
+      stop_tokens:
+        - 0
   - SERVICE_NAME: "rwkv7-g1-0.1b-20250307-ctx4096"
     DOWNLOAD_MODEL_FILE_NAME: "rwkv7-g1-0.1b-20250307-ctx4096.pth"
     DOWNLOAD_MODEL_REPO_ID: "BlinkDL/rwkv7-g1"
@@ -32,3 +34,5 @@ MODELS:
       penalty_decay: 0.996
       stop:
         - "\n\n"
+      stop_tokens:
+        - 0
config.py CHANGED
@@ -23,8 +23,9 @@ class SamplerConfig(BaseModel):
     top_p: float = Field(0.3, description="Top-p sampling threshold.")
     presence_penalty: float = Field(0.5, description="Presence penalty.")
     count_penalty: float = Field(0.5, description="Count penalty.")
-    penalty_decay: float = Field(0.5, description="Penalty decay factor.")
-    stop: List[str] = Field(0.996, description="List of stop sequences.")
+    penalty_decay: float = Field(0.996, description="Penalty decay factor.")
+    stop: List[str] = Field(["\n\n"], description="List of stop sequences.")
+    stop_tokens: List[int] = Field([0], description="List of stop tokens.")
 
 
 class ModelConfig(BaseModel):
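The corrected defaults in `SamplerConfig` now line up with the per-model sampler settings declared in config.production.yaml. A minimal sketch, trimmed to the fields touched by this commit, of how those values validate with Pydantic; the inline dict stands in for a loaded YAML section and is illustrative, not repo code.

```python
from typing import List

from pydantic import BaseModel, Field


class SamplerConfig(BaseModel):
    # Trimmed copy of the fields changed in this commit (temperature, top_p, etc. omitted).
    penalty_decay: float = Field(0.996, description="Penalty decay factor.")
    stop: List[str] = Field(["\n\n"], description="List of stop sequences.")
    stop_tokens: List[int] = Field([0], description="List of stop tokens.")


# Stand-in for one model's sampler section after yaml.safe_load(config.production.yaml).
sampler = SamplerConfig(**{"penalty_decay": 0.996, "stop": ["\n\n"], "stop_tokens": [0]})
print(sampler.stop_tokens)  # [0] -> token id 0 (END_OF_TEXT) now ends generation
```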
pyproject.toml CHANGED
@@ -13,6 +13,7 @@ dependencies = [
     "pydantic>=2.10.6",
     "pydantic-settings>=2.8.1",
     "pynvml>=12.0.0",
+    "rich>=13.9.4",
     "rwkv==0.8.28",
     "setuptools>=75.8.2",
     "snowflake-id>=1.0.2",
uv.lock CHANGED
@@ -944,6 +944,7 @@ dependencies = [
     { name = "pydantic" },
     { name = "pydantic-settings" },
     { name = "pynvml" },
+    { name = "rich" },
     { name = "rwkv" },
     { name = "setuptools" },
     { name = "snowflake-id" },
@@ -971,6 +972,7 @@ requires-dist = [
     { name = "pydantic", specifier = ">=2.10.6" },
     { name = "pydantic-settings", specifier = ">=2.8.1" },
     { name = "pynvml", specifier = ">=12.0.0" },
+    { name = "rich", specifier = ">=13.9.4" },
     { name = "rwkv", specifier = "==0.8.28" },
     { name = "setuptools", specifier = ">=75.8.2" },
     { name = "snowflake-id", specifier = ">=1.0.2" },