freeCS-dot-org committed (verified)
Commit 3bce535 · Parent(s): 0916816

Update app.py

Files changed (1): app.py (+43 -15)
app.py CHANGED
@@ -6,11 +6,10 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import gradio as gr
 from threading import Thread
 
-
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 MODEL = "AGI-0/Art-v0-3B"
 
-TITLE = """<h2>Link to the model: <a href="https://huggingface.co/AGI-0/Art-v0-3B"</h2>"""
+TITLE = """<h2>Link to the model: <a href="https://huggingface.co/AGI-0/Art-v0-3B">click here</a></h2>"""
 
 PLACEHOLDER = """
 <center>
@@ -18,7 +17,6 @@ PLACEHOLDER = """
 </center>
 """
 
-
 CSS = """
 .duplicate-button {
     margin: auto !important;
@@ -31,7 +29,7 @@ h3 {
 }
 """
 
-device = "cuda" # for GPU usage or "cpu" for CPU usage
+device = "cuda" # for GPU usage or "cpu" for CPU usage
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL)
 model = AutoModelForCausalLM.from_pretrained(
@@ -39,6 +37,8 @@ model = AutoModelForCausalLM.from_pretrained(
     torch_dtype=torch.bfloat16,
     device_map="auto")
 end_of_sentence = tokenizer.convert_tokens_to_ids("<|im_end|>")
+end_reasoning_token = "<|end_reasoning|>"
+
 @spaces.GPU()
 def stream_chat(
     message: str,
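The new `end_reasoning_token` is compared against decoded text in the streaming loop below, so it need not be a single vocabulary token; `end_of_sentence`, by contrast, must resolve to a real token id. A quick check along these lines (an illustrative sketch, not part of the commit) can confirm both assumptions:

```python
# Illustrative sanity check, not part of the commit.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("AGI-0/Art-v0-3B")

# <|im_end|> is used as eos_token_id, so it must map to a real vocab entry.
eos_id = tok.convert_tokens_to_ids("<|im_end|>")
assert eos_id is not None and eos_id != tok.unk_token_id

# <|end_reasoning|> is matched as plain text in the decoded stream,
# so it may tokenize into several pieces without breaking anything.
print(tok.tokenize("<|end_reasoning|>"))
```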
@@ -53,8 +53,7 @@ def stream_chat(
     print(f'message: {message}')
     print(f'history: {history}')
 
-    conversation = [
-    ]
+    conversation = []
     for prompt, answer in history:
         conversation.extend([
             {"role": "user", "content": prompt},
@@ -69,11 +68,11 @@
 
     generate_kwargs = dict(
         input_ids=input_ids,
-        max_new_tokens = max_new_tokens,
-        do_sample = False if temperature == 0 else True,
-        top_p = top_p,
-        top_k = top_k,
-        temperature = temperature,
+        max_new_tokens=max_new_tokens,
+        do_sample=False if temperature == 0 else True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
         repetition_penalty=penalty,
         eos_token_id=[end_of_sentence],
         streamer=streamer,
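Note that `do_sample=False if temperature == 0 else True` falls back to greedy decoding at temperature 0, where the sampling parameters are ignored. The `streamer` itself is constructed outside these hunks; the usual `TextIteratorStreamer` setup for this pattern looks roughly like the following (construction arguments are assumed):

```python
# Assumed construction of the streamer referenced in generate_kwargs.
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(
    tokenizer,
    timeout=60.0,              # fail fast if generation stalls
    skip_prompt=True,          # stream only newly generated tokens
    skip_special_tokens=True,  # passed through to tokenizer.decode
)
```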
@@ -82,12 +81,42 @@ def stream_chat(
     with torch.no_grad():
         thread = Thread(target=model.generate, kwargs=generate_kwargs)
         thread.start()
-
+
     buffer = ""
+    reasoning_text = ""
+    final_text = ""
+    in_reasoning = True
+
     for new_text in streamer:
         buffer += new_text
-        yield buffer
-
+
+        if end_reasoning_token in buffer and in_reasoning:
+            # Split the buffer at the end_reasoning_token
+            parts = buffer.split(end_reasoning_token)
+            reasoning_text = parts[0]
+            final_text = parts[1] if len(parts) > 1 else ""
+
+            # Format the output with the details tag
+            formatted_output = (
+                "<details><summary>Click to see reasoning</summary>\n\n"
+                f"{reasoning_text}\n\n"
+                "</details>\n\n"
+                f"{final_text}"
+            )
+            in_reasoning = False
+            yield formatted_output
+        elif in_reasoning:
+            # Still collecting reasoning text
+            yield "<details><summary>Click to see reasoning</summary>\n\n" + buffer + "\n\n</details>"
+        else:
+            # After end_reasoning_token, just append to the existing formatted output
+            formatted_output = (
+                "<details><summary>Click to see reasoning</summary>\n\n"
+                f"{reasoning_text}\n\n"
+                "</details>\n\n"
+                f"{buffer}"
+            )
+            yield formatted_output
 
 chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
 
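Review note on the loop above: once `in_reasoning` flips to False, the `else` branch yields the whole `buffer`, which still begins with the reasoning text and the marker, so the reasoning is rendered a second time inside the visible answer. Re-partitioning the buffer on every tick avoids that and collapses the three branches into one; the helper below is an untested refactoring sketch, not code from the commit:

```python
# Refactoring sketch (not from the commit): one pure function handles
# both the "still reasoning" and "answering" phases.
END_REASONING = "<|end_reasoning|>"

def format_stream(buffer: str) -> str:
    """Wrap the reasoning prefix in a collapsible <details> block."""
    reasoning, _, answer = buffer.partition(END_REASONING)
    return (
        "<details><summary>Click to see reasoning</summary>\n\n"
        f"{reasoning}\n\n</details>\n\n{answer}"
    )

# The streaming loop then reduces to:
#     for new_text in streamer:
#         buffer += new_text
#         yield format_stream(buffer)
```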
@@ -155,6 +184,5 @@ with gr.Blocks(css=CSS, theme="soft") as demo:
         cache_examples=False,
     )
 
-
 if __name__ == "__main__":
     demo.launch()