freeCS-dot-org committed
Commit bacf4cd · verified · 1 Parent(s): f80f6ce

Update app.py

Files changed (1)
  1. app.py +77 -35
app.py CHANGED
@@ -29,44 +29,79 @@ h3 {
 }
 """
 
+class ConversationManager:
+    def __init__(self):
+        self.user_history = []   # For displaying to user (with markdown)
+        self.model_history = []  # For feeding back to model (with original tags)
+
+    def add_exchange(self, user_message, assistant_response, formatted_response):
+        self.model_history.append((user_message, assistant_response))
+        self.user_history.append((user_message, formatted_response))
+
+    def get_model_history(self):
+        return self.model_history
+
+    def get_user_history(self):
+        return self.user_history
+
+conversation_manager = ConversationManager()
+
 device = "cuda" # for GPU usage or "cpu" for CPU usage
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL)
 model = AutoModelForCausalLM.from_pretrained(
     MODEL,
     torch_dtype=torch.bfloat16,
-    device_map="auto")
+    device_map="auto"
+)
 end_of_sentence = tokenizer.convert_tokens_to_ids("<|im_end|>")
 
+def format_response(response):
+    """Format the response for user display"""
+    if "<|end_reasoning|>" in response:
+        parts = response.split("<|end_reasoning|>")
+        reasoning = parts[0]
+        rest = parts[1] if len(parts) > 1 else ""
+        return f"<details><summary>Click to see reasoning</summary>\n\n{reasoning}\n\n</details>\n\n{rest}"
+    return response
+
 @spaces.GPU()
 def stream_chat(
-    message: str,
+    message: str,
     history: list,
     system_prompt: str,
-    temperature: float = 0.2,
-    max_new_tokens: int = 4096,
-    top_p: float = 1.0,
-    top_k: int = 1,
+    temperature: float = 0.2,
+    max_new_tokens: int = 4096,
+    top_p: float = 1.0,
+    top_k: int = 1,
     penalty: float = 1.1,
 ):
-    print(f'message: {message}')
-    print(f'history: {history}')
-
+    model_history = conversation_manager.get_model_history()
+
     conversation = []
-    for prompt, answer in history:
+    for prompt, answer in model_history:
         conversation.extend([
-            {"role": "user", "content": prompt},
+            {"role": "user", "content": prompt},
             {"role": "assistant", "content": answer},
         ])
-
+
     conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+    input_ids = tokenizer.apply_chat_template(
+        conversation,
+        add_generation_prompt=True,
+        return_tensors="pt"
+    ).to(model.device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer,
+        timeout=60.0,
+        skip_prompt=True,
+        skip_special_tokens=True
+    )
 
     generate_kwargs = dict(
-        input_ids=input_ids,
+        input_ids=input_ids,
         max_new_tokens=max_new_tokens,
         do_sample=False if temperature == 0 else True,
         top_p=top_p,
@@ -76,43 +111,50 @@ def stream_chat(
         eos_token_id=[end_of_sentence],
         streamer=streamer,
     )
-
-    with torch.no_grad():
-        thread = Thread(target=model.generate, kwargs=generate_kwargs)
-        thread.start()
 
     buffer = ""
-    found_token = False
+    original_response = ""
 
-    for new_text in streamer:
-        buffer += new_text
+    with torch.no_grad():
+        thread = Thread(target=model.generate, kwargs=generate_kwargs)
+        thread.start()
 
-        if "<|end_reasoning|>" in buffer and not found_token:
-            # Split at the token
-            parts = buffer.split("<|end_reasoning|>")
-            reasoning = parts[0]
-            rest = parts[1] if len(parts) > 1 else ""
+    for new_text in streamer:
+        buffer += new_text
+        original_response += new_text
 
-            # Format with markdown and continue
-            buffer = f"<details><summary>Click to see reasoning</summary>\n\n{reasoning}\n\n</details>\n\n{rest}"
-            found_token = True
+        formatted_buffer = format_response(buffer)
 
-            yield buffer
+        if thread.is_alive() is False:
+            conversation_manager.add_exchange(
+                message,
+                original_response, # Original for model
+                formatted_buffer # Formatted for user
+            )
+
+        yield formatted_buffer
 
 chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
 
 with gr.Blocks(css=CSS, theme="soft") as demo:
     gr.HTML(TITLE)
-    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
+    gr.DuplicateButton(
+        value="Duplicate Space for private use",
+        elem_classes="duplicate-button"
+    )
     gr.ChatInterface(
         fn=stream_chat,
         chatbot=chatbot,
         fill_height=True,
-        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
+        additional_inputs_accordion=gr.Accordion(
+            label="⚙️ Parameters",
+            open=False,
+            render=False
+        ),
         additional_inputs=[
             gr.Textbox(
                 value="",
-                label="",
+                label="System Prompt",
                 render=False,
            ),
            gr.Slider(
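For reference, here is what the new format_response helper does to a finished generation. This is a minimal standalone sketch: the function body is copied from the diff above, and the sample response string is illustrative, not real model output.

def format_response(response):
    """Format the response for user display"""
    if "<|end_reasoning|>" in response:
        parts = response.split("<|end_reasoning|>")
        reasoning = parts[0]
        rest = parts[1] if len(parts) > 1 else ""
        return f"<details><summary>Click to see reasoning</summary>\n\n{reasoning}\n\n</details>\n\n{rest}"
    return response

# Illustrative input: reasoning text, the end-of-reasoning tag, then the answer.
sample = "Let me compute 6 * 7 step by step.<|end_reasoning|>The answer is 42."
print(format_response(sample))
# The reasoning ends up inside a collapsible <details> block:
# <details><summary>Click to see reasoning</summary>
#
# Let me compute 6 * 7 step by step.
#
# </details>
#
# The answer is 42.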
 
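A note on the data flow this commit introduces: the formatted response (with the <details> wrapper) is only for the Gradio display, while the raw response (with the original <|end_reasoning|> tag) is what stream_chat feeds back through the chat template on the next turn. Below is a minimal sketch of that bookkeeping, with the ConversationManager class copied from the diff and illustrative message strings.

class ConversationManager:
    def __init__(self):
        self.user_history = []   # For displaying to user (with markdown)
        self.model_history = []  # For feeding back to model (with original tags)

    def add_exchange(self, user_message, assistant_response, formatted_response):
        self.model_history.append((user_message, assistant_response))
        self.user_history.append((user_message, formatted_response))

    def get_model_history(self):
        return self.model_history

    def get_user_history(self):
        return self.user_history

manager = ConversationManager()
manager.add_exchange(
    "What is 6 * 7?",                                # user message
    "6 * 7 = 42<|end_reasoning|>The answer is 42.",  # raw text, kept for the model
    "<details>...</details>\n\nThe answer is 42.",   # formatted text, shown to the user
)

# The model-facing history preserves the reasoning tag, so the chat template
# sees exactly what the model generated...
assert "<|end_reasoning|>" in manager.get_model_history()[0][1]
# ...while the user-facing history carries the collapsible markdown instead.
assert "<details>" in manager.get_user_history()[0][1]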