freeCS-dot-org committed · verified
Commit 2fb89d3 · 1 Parent: 82baec6

Update app.py

Files changed (1): app.py (+72, -64)
app.py CHANGED
@@ -2,71 +2,85 @@ import os
 import time
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import gradio as gr
 from threading import Thread
 
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 MODEL = "AGI-0/Art-v0-3B"
 
-TITLE = """<h2>Link to the model: <a href="https://huggingface.co/AGI-0/Art-v0-3B">click here</a></h2>"""
-
-PLACEHOLDER = """
-<center>
-<p>Hi! How can I help you today?</p>
-</center>
-"""
-
-CSS = """
-.duplicate-button {
-    margin: auto !important;
-    color: white !important;
-    background: black !important;
-    border-radius: 100vh !important;
-}
-h3 {
-    text-align: center;
-}
-"""
-
-device = "cuda" # for GPU usage or "cpu" for CPU usage
+device = "cuda" # Use "cpu" if no GPU available
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL)
 model = AutoModelForCausalLM.from_pretrained(
-    MODEL,
-    torch_dtype=torch.bfloat16,
-    device_map="auto")
+    MODEL, torch_dtype=torch.bfloat16, device_map="auto"
+)
 end_of_sentence = tokenizer.convert_tokens_to_ids("<|im_end|>")
 
+class ConversationManager:
+    def __init__(self):
+        self.user_history = []   # User-facing history with formatting
+        self.model_history = []  # Model-facing history without formatting
+
+    def add_exchange(self, user_message, model_response):
+        formatted_response = self.format_response(model_response)
+        self.model_history.append((user_message, model_response))
+        self.user_history.append((user_message, formatted_response))
+        print(f"\nModel History Updated: {self.model_history}")
+        print(f"\nUser History Updated: {self.user_history}")
+
+    def format_response(self, response):
+        """Format response for UI while keeping raw text for model."""
+        if "<|end_reasoning|>" in response:
+            parts = response.split("<|end_reasoning|>")
+            reasoning, rest = parts[0], parts[1] if len(parts) > 1 else ""
+            return f"<details><summary>Click to see reasoning</summary>\n\n{reasoning}\n\n</details>\n\n{rest}"
+        return response
+
+    def get_user_history(self):
+        return self.user_history
+
+    def get_model_history(self):
+        return self.model_history
+
+conversation_manager = ConversationManager()
+
 @spaces.GPU()
 def stream_chat(
-    message: str,
+    message: str,
     history: list,
-    system_prompt: str = "", # Include system prompt
-    temperature: float = 0.2,
-    max_new_tokens: int = 4096,
-    top_p: float = 1.0,
-    top_k: int = 1,
+    system_prompt: str,
+    temperature: float = 0.2,
+    max_new_tokens: int = 4096,
+    top_p: float = 1.0,
+    top_k: int = 1,
     penalty: float = 1.1,
 ):
-    print(f'message: {message}')
-    print(f'history: {history}')
-
+    print(f'User Message: {message}')
+
+    model_history = conversation_manager.get_model_history()
+    print(f'Model History: {model_history}')
+
     conversation = []
-    for prompt, answer in history:
+    for prompt, answer in model_history:
         conversation.extend([
-            {"role": "user", "content": prompt},
+            {"role": "user", "content": prompt},
             {"role": "assistant", "content": answer},
         ])
-
     conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+    print(f'Formatted Conversation for Model: {conversation}')
+
+    input_ids = tokenizer.apply_chat_template(
+        conversation, add_generation_prompt=True, return_tensors="pt"
+    ).to(model.device)
+
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
+    )
+
     generate_kwargs = dict(
-        input_ids=input_ids,
+        input_ids=input_ids,
         max_new_tokens=max_new_tokens,
         do_sample=False if temperature == 0 else True,
         top_p=top_p,
@@ -77,43 +91,37 @@ def stream_chat(
         streamer=streamer,
     )
 
+    buffer = ""
+    original_response = ""
+
     with torch.no_grad():
         thread = Thread(target=model.generate, kwargs=generate_kwargs)
         thread.start()
-
-    buffer = ""
-    user_buffer = ""
-    found_token = False
-
-    for new_text in streamer:
-        buffer += new_text
-        user_buffer += new_text
 
-        if "<|end_reasoning|>" in user_buffer and not found_token:
-            parts = user_buffer.split("<|end_reasoning|>")
-            reasoning = parts[0]
-            rest = parts[1] if len(parts) > 1 else ""
-            user_buffer = f"<details><summary>Click to see reasoning</summary>\n\n{reasoning}\n\n</details>\n\n{rest}"
-            found_token = True
-
-        yield user_buffer
+    for new_text in streamer:
+        buffer += new_text
+        original_response += new_text
+        print(f'Streaming: {buffer}')
+        formatted_buffer = conversation_manager.format_response(buffer)
+        yield formatted_buffer, history + [[message, formatted_buffer]]
 
-    history.append((message, buffer)) # Crucial: Append the original buffer
+    conversation_manager.add_exchange(message, original_response)
 
-chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
+chatbot = gr.Chatbot(height=600, placeholder="<center><p>Hi! How can I help you today?</p></center>")
 
-with gr.Blocks(css=CSS, theme="soft") as demo:
-    gr.HTML(TITLE)
+demo = gr.Blocks()
+with demo:
+    gr.HTML("<h2>Link to the model: <a href='https://huggingface.co/AGI-0/Art-v0-3B'>click here</a></h2>")
     gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
     gr.ChatInterface(
         fn=stream_chat,
        chatbot=chatbot,
         fill_height=True,
-        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False), # Added system prompt textbox here
+        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
         additional_inputs=[
-            gr.Textbox(value="", label="System Prompt", lines=2, render=False), # Added system prompt textbox
-            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Temperature", render=False),
+            gr.Textbox(value="", label="System Prompt", render=False),
+            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Temperature", render=False),
             gr.Slider(minimum=128, maximum=8192, step=1, value=4096, label="Max new tokens", render=False),
             gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p", render=False),
             gr.Slider(minimum=1, maximum=50, step=1, value=1, label="top_k", render=False),
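
The substantive change in this commit is the new ConversationManager, which keeps two parallel histories: the raw model output (including the <|end_reasoning|> marker) is what gets rebuilt into the prompt on later turns, while the chat UI receives a copy with the reasoning folded into a collapsible <details> block. Since gr.ChatInterface records whatever the function yields, the previous version appears to have fed the <details>-formatted HTML back to the model on the next turn; routing prompt construction through the manager's model_history avoids that. A minimal sketch of the split, runnable without the model (the sample strings are invented for illustration, not real Art-v0-3B output):

    END_MARKER = "<|end_reasoning|>"

    def format_response(response: str) -> str:
        # UI-facing copy: fold everything before the marker into a <details> block.
        if END_MARKER in response:
            parts = response.split(END_MARKER)
            reasoning, rest = parts[0], parts[1] if len(parts) > 1 else ""
            return (f"<details><summary>Click to see reasoning</summary>\n\n"
                    f"{reasoning}\n\n</details>\n\n{rest}")
        return response  # Marker not generated yet (e.g. mid-stream): pass through.

    model_history = []  # raw text, rebuilt into the next prompt
    user_history = []   # formatted text, rendered in the Gradio Chatbot

    raw = "Six times seven is 42." + END_MARKER + "The answer is 42."  # illustrative
    model_history.append(("What is 6 x 7?", raw))
    user_history.append(("What is 6 x 7?", format_response(raw)))

Because stream_chat runs the partial buffer through format_response on every streamed chunk, the collapsible block appears in the UI as soon as the marker arrives, while original_response accumulates the unformatted text that add_exchange stores for the next turn. One trade-off worth noting: conversation_manager is module-level state, so every session of the Space shares a single history.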
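
The generation side is unchanged in spirit: model.generate runs on a background thread and TextIteratorStreamer turns its output into an iterator of decoded text chunks. A self-contained sketch of that pattern, using gpt2 as a small stand-in model (any Hugging Face causal LM works the same way):

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tok("The quick brown fox", return_tensors="pt")
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks until done, so it runs on a worker thread and feeds the
    # streamer, which the main thread consumes chunk by chunk as tokens decode.
    thread = Thread(target=lm.generate,
                    kwargs=dict(**inputs, max_new_tokens=20, streamer=streamer))
    thread.start()
    for piece in streamer:
        print(piece, end="", flush=True)
    thread.join()

In the Space itself, each decoded piece is appended to buffer and re-yielded, which is what makes the reply stream into the Gradio UI instead of arriving all at once.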