freeCS-dot-org commited on
Commit
82baec6
·
verified ·
1 Parent(s): f779be9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -72
app.py CHANGED
@@ -2,85 +2,71 @@ import os
2
  import time
3
  import spaces
4
  import torch
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
  import gradio as gr
7
  from threading import Thread
8
 
9
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
10
  MODEL = "AGI-0/Art-v0-3B"
11
 
12
- device = "cuda" # Use "cpu" if no GPU available
13
 
14
- tokenizer = AutoTokenizer.from_pretrained(MODEL)
15
- model = AutoModelForCausalLM.from_pretrained(
16
- MODEL, torch_dtype=torch.bfloat16, device_map="auto"
17
- )
18
- end_of_sentence = tokenizer.convert_tokens_to_ids("<|im_end|>")
19
-
20
- class ConversationManager:
21
- def __init__(self):
22
- self.user_history = [] # User-facing history with formatting
23
- self.model_history = [] # Model-facing history without formatting
24
-
25
- def add_exchange(self, user_message, model_response):
26
- formatted_response = self.format_response(model_response)
27
- self.model_history.append((user_message, model_response))
28
- self.user_history.append((user_message, formatted_response))
29
- print(f"\nModel History Updated: {self.model_history}")
30
- print(f"\nUser History Updated: {self.user_history}")
31
-
32
- def format_response(self, response):
33
- """Format response for UI while keeping raw text for model."""
34
- if "<|end_reasoning|>" in response:
35
- parts = response.split("<|end_reasoning|>")
36
- reasoning, rest = parts[0], parts[1] if len(parts) > 1 else ""
37
- return f"<details><summary>Click to see reasoning</summary>\n\n{reasoning}\n\n</details>\n\n{rest}"
38
- return response
39
 
40
- def get_user_history(self):
41
- return self.user_history
 
 
 
 
 
 
 
 
 
42
 
43
- def get_model_history(self):
44
- return self.model_history
45
 
46
- conversation_manager = ConversationManager()
 
 
 
 
 
47
 
48
  @spaces.GPU()
49
  def stream_chat(
50
- message: str,
51
  history: list,
52
- system_prompt: str,
53
- temperature: float = 0.2,
54
- max_new_tokens: int = 4096,
55
- top_p: float = 1.0,
56
- top_k: int = 1,
57
  penalty: float = 1.1,
58
  ):
59
- print(f'User Message: {message}')
60
-
61
- model_history = conversation_manager.get_model_history()
62
- print(f'Model History: {model_history}')
63
-
64
  conversation = []
65
- for prompt, answer in model_history:
66
  conversation.extend([
67
- {"role": "user", "content": prompt},
68
  {"role": "assistant", "content": answer},
69
  ])
 
70
  conversation.append({"role": "user", "content": message})
71
 
72
- print(f'Formatted Conversation for Model: {conversation}')
 
 
73
 
74
- input_ids = tokenizer.apply_chat_template(
75
- conversation, add_generation_prompt=True, return_tensors="pt"
76
- ).to(model.device)
77
-
78
- streamer = TextIteratorStreamer(
79
- tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
80
- )
81
-
82
  generate_kwargs = dict(
83
- input_ids=input_ids,
84
  max_new_tokens=max_new_tokens,
85
  do_sample=False if temperature == 0 else True,
86
  top_p=top_p,
@@ -91,35 +77,43 @@ def stream_chat(
91
  streamer=streamer,
92
  )
93
 
94
- buffer = ""
95
- original_response = ""
96
-
97
  with torch.no_grad():
98
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
99
  thread.start()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- for new_text in streamer:
102
- buffer += new_text
103
- original_response += new_text
104
- print(f'Streaming: {buffer}')
105
- yield conversation_manager.format_response(buffer)
106
 
107
- conversation_manager.add_exchange(message, original_response)
108
 
109
- chatbot = gr.Chatbot(height=600, placeholder="<center><p>Hi! How can I help you today?</p></center>")
110
 
111
- demo = gr.Blocks()
112
- with demo:
113
- gr.HTML("<h2>Link to the model: <a href='https://huggingface.co/AGI-0/Art-v0-3B'>click here</a></h2>")
114
  gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
115
  gr.ChatInterface(
116
  fn=stream_chat,
117
  chatbot=chatbot,
118
  fill_height=True,
119
- additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
120
  additional_inputs=[
121
- gr.Textbox(value="", label="System Prompt", render=False),
122
- gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Temperature", render=False),
123
  gr.Slider(minimum=128, maximum=8192, step=1, value=4096, label="Max new tokens", render=False),
124
  gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p", render=False),
125
  gr.Slider(minimum=1, maximum=50, step=1, value=1, label="top_k", render=False),
@@ -135,4 +129,4 @@ with demo:
135
  )
136
 
137
  if __name__ == "__main__":
138
- demo.launch()
 
2
  import time
3
  import spaces
4
  import torch
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
6
  import gradio as gr
7
  from threading import Thread
8
 
9
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
10
  MODEL = "AGI-0/Art-v0-3B"
11
 
12
+ TITLE = """<h2>Link to the model: <a href="https://huggingface.co/AGI-0/Art-v0-3B">click here</a></h2>"""
13
 
14
+ PLACEHOLDER = """
15
+ <center>
16
+ <p>Hi! How can I help you today?</p>
17
+ </center>
18
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ CSS = """
21
+ .duplicate-button {
22
+ margin: auto !important;
23
+ color: white !important;
24
+ background: black !important;
25
+ border-radius: 100vh !important;
26
+ }
27
+ h3 {
28
+ text-align: center;
29
+ }
30
+ """
31
 
32
+ device = "cuda" # for GPU usage or "cpu" for CPU usage
 
33
 
34
+ tokenizer = AutoTokenizer.from_pretrained(MODEL)
35
+ model = AutoModelForCausalLM.from_pretrained(
36
+ MODEL,
37
+ torch_dtype=torch.bfloat16,
38
+ device_map="auto")
39
+ end_of_sentence = tokenizer.convert_tokens_to_ids("<|im_end|>")
40
 
41
  @spaces.GPU()
42
  def stream_chat(
43
+ message: str,
44
  history: list,
45
+ system_prompt: str = "", # Include system prompt
46
+ temperature: float = 0.2,
47
+ max_new_tokens: int = 4096,
48
+ top_p: float = 1.0,
49
+ top_k: int = 1,
50
  penalty: float = 1.1,
51
  ):
52
+ print(f'message: {message}')
53
+ print(f'history: {history}')
54
+
 
 
55
  conversation = []
56
+ for prompt, answer in history:
57
  conversation.extend([
58
+ {"role": "user", "content": prompt},
59
  {"role": "assistant", "content": answer},
60
  ])
61
+
62
  conversation.append({"role": "user", "content": message})
63
 
64
+ input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
65
+
66
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
67
 
 
 
 
 
 
 
 
 
68
  generate_kwargs = dict(
69
+ input_ids=input_ids,
70
  max_new_tokens=max_new_tokens,
71
  do_sample=False if temperature == 0 else True,
72
  top_p=top_p,
 
77
  streamer=streamer,
78
  )
79
 
 
 
 
80
  with torch.no_grad():
81
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
82
  thread.start()
83
+
84
+ buffer = ""
85
+ user_buffer = ""
86
+ found_token = False
87
+
88
+ for new_text in streamer:
89
+ buffer += new_text
90
+ user_buffer += new_text
91
+
92
+ if "<|end_reasoning|>" in user_buffer and not found_token:
93
+ parts = user_buffer.split("<|end_reasoning|>")
94
+ reasoning = parts[0]
95
+ rest = parts[1] if len(parts) > 1 else ""
96
+ user_buffer = f"<details><summary>Click to see reasoning</summary>\n\n{reasoning}\n\n</details>\n\n{rest}"
97
+ found_token = True
98
+
99
+ yield user_buffer
100
 
101
+ history.append((message, buffer)) # Crucial: Append the original buffer
 
 
 
 
102
 
 
103
 
104
+ chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
105
 
106
+ with gr.Blocks(css=CSS, theme="soft") as demo:
107
+ gr.HTML(TITLE)
 
108
  gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
109
  gr.ChatInterface(
110
  fn=stream_chat,
111
  chatbot=chatbot,
112
  fill_height=True,
113
+ additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False), # Added system prompt textbox here
114
  additional_inputs=[
115
+ gr.Textbox(value="", label="System Prompt", lines=2, render=False), # Added system prompt textbox
116
+ gr.Slider(minimum=0, maximum=1, step=0.1, value=0.2, label="Temperature", render=False),
117
  gr.Slider(minimum=128, maximum=8192, step=1, value=4096, label="Max new tokens", render=False),
118
  gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p", render=False),
119
  gr.Slider(minimum=1, maximum=50, step=1, value=1, label="top_k", render=False),
 
129
  )
130
 
131
  if __name__ == "__main__":
132
+ demo.launch()