JimmyK300 commited on
Commit
baa7d24
·
verified ·
1 Parent(s): b4c6c06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -30
app.py CHANGED
@@ -1,7 +1,10 @@
1
  import os
2
  import torch
3
  import gradio as gr
4
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, Qwen2VLForConditionalGeneration
 
 
 
5
  from qwen_vl_utils import process_vision_info
6
  from PIL import Image
7
 
@@ -19,19 +22,30 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
19
 
20
  math_messages = []
21
 
22
- def process_image(image, should_convert=False):
23
  global math_messages
24
  math_messages = [] # Reset when uploading an image
25
 
26
- if should_convert:
 
 
 
27
  new_img = Image.new('RGB', size=(image.width, image.height), color=(255, 255, 255))
28
  new_img.paste(image, (0, 0), mask=image)
29
  image = new_img
30
 
 
31
  inputs = vl_processor(images=image, return_tensors="pt").to(device)
32
  generated_ids = vl_model.generate(**inputs)
33
- output = vl_processor.batch_decode(generated_ids, skip_special_tokens=True)
34
- return f"Math-related content detected: {output[0]}"
 
 
 
 
 
 
 
35
 
36
  def get_math_response(image_description, user_question):
37
  global math_messages
@@ -41,23 +55,21 @@ def get_math_response(image_description, user_question):
41
  content = f'Image description: {image_description}\n\n' if image_description else ''
42
  query = f"{content}User question: {user_question}"
43
  math_messages.append({'role': 'user', 'content': query})
44
-
45
  model_inputs = tokenizer(query, return_tensors="pt").to(device)
46
  output = model.generate(**model_inputs, max_new_tokens=512)
47
  answer = tokenizer.decode(output[0], skip_special_tokens=True)
48
-
49
  yield answer.replace("\\", "\\\\")
50
  math_messages.append({'role': 'assistant', 'content': answer})
51
 
52
  def math_chat_bot(image, sketchpad, question, state):
53
  current_tab_index = state["tab_index"]
54
  image_description = None
55
-
56
- if current_tab_index == 0 and image is not None:
57
- image_description = process_image(image)
58
- elif current_tab_index == 1 and sketchpad and sketchpad["composite"]:
59
- image_description = process_image(sketchpad["composite"], True)
60
-
61
  yield from get_math_response(image_description, question)
62
 
63
  css = """
@@ -69,41 +81,34 @@ css = """
69
  def tabs_select(e: gr.SelectData, _state):
70
  _state["tab_index"] = e.index
71
 
72
- # Create Gradio UI
73
  with gr.Blocks(css=css) as demo:
74
  gr.HTML("""
75
- <p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 60px"/></p>
76
- <center><font size=8>📖 Qwen2-Math Demo</font></center>
77
- <center><font size=3>This WebUI is based on Qwen2-VL for OCR and Qwen2-Math for mathematical reasoning. You can input either images or texts of mathematical or arithmetic problems.</font></center>
78
- """)
79
-
80
  state = gr.State({"tab_index": 0})
81
-
82
  with gr.Row():
83
  with gr.Column():
84
  with gr.Tabs() as input_tabs:
85
  with gr.Tab("Upload"):
86
  input_image = gr.Image(type="pil", label="Upload")
87
  with gr.Tab("Sketch"):
88
- input_sketchpad = gr.Sketchpad(label="Sketch", layers=False)
89
-
90
  input_tabs.select(fn=tabs_select, inputs=[state])
91
- input_text = gr.Textbox(label="Input your question")
92
-
93
  with gr.Row():
94
  with gr.Column():
95
  clear_btn = gr.ClearButton([input_image, input_sketchpad, input_text])
96
  with gr.Column():
97
  submit_btn = gr.Button("Submit", variant="primary")
98
-
99
  with gr.Column():
100
- output_md = gr.Markdown(label="Answer", elem_id="qwen-md")
101
-
102
  submit_btn.click(
103
  fn=math_chat_bot,
104
  inputs=[input_image, input_sketchpad, input_text, state],
105
- outputs=output_md
106
- )
107
 
108
  if __name__ == "__main__":
109
- demo.launch()
 
1
  import os
2
  import torch
3
  import gradio as gr
4
+ import tempfile
5
+ import secrets
6
+ from pathlib import Path
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BlipForConditionalGeneration, AutoProcessor, Qwen2VLForConditionalGeneration
8
  from qwen_vl_utils import process_vision_info
9
  from PIL import Image
10
 
 
22
 
23
  math_messages = []
24
 
25
+ def process_image(image, shouldConvert=False):
26
  global math_messages
27
  math_messages = [] # Reset when uploading an image
28
 
29
+ if image is None:
30
+ return "No image provided."
31
+
32
+ if shouldConvert:
33
  new_img = Image.new('RGB', size=(image.width, image.height), color=(255, 255, 255))
34
  new_img.paste(image, (0, 0), mask=image)
35
  image = new_img
36
 
37
+ # Convert the image to tensor
38
  inputs = vl_processor(images=image, return_tensors="pt").to(device)
39
  generated_ids = vl_model.generate(**inputs)
40
+ generated_ids_trimmed = [
41
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
42
+ ]
43
+ output = vl_processor.batch_decode(
44
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
45
+ )
46
+ description = output[0] if output else ""
47
+
48
+ return f"Math-related content detected: {description}"
49
 
50
  def get_math_response(image_description, user_question):
51
  global math_messages
 
55
  content = f'Image description: {image_description}\n\n' if image_description else ''
56
  query = f"{content}User question: {user_question}"
57
  math_messages.append({'role': 'user', 'content': query})
 
58
  model_inputs = tokenizer(query, return_tensors="pt").to(device)
59
  output = model.generate(**model_inputs, max_new_tokens=512)
60
  answer = tokenizer.decode(output[0], skip_special_tokens=True)
 
61
  yield answer.replace("\\", "\\\\")
62
  math_messages.append({'role': 'assistant', 'content': answer})
63
 
64
  def math_chat_bot(image, sketchpad, question, state):
65
  current_tab_index = state["tab_index"]
66
  image_description = None
67
+ if current_tab_index == 0:
68
+ if image is not None:
69
+ image_description = process_image(image)
70
+ elif current_tab_index == 1:
71
+ if sketchpad and sketchpad.get("composite"):
72
+ image_description = process_image(sketchpad["composite"], True)
73
  yield from get_math_response(image_description, question)
74
 
75
  css = """
 
81
  def tabs_select(e: gr.SelectData, _state):
82
  _state["tab_index"] = e.index
83
 
84
+ # 创建Gradio接口
85
  with gr.Blocks(css=css) as demo:
86
  gr.HTML("""
87
+ <p align="center"><img src="https://modelscope.oss-cn-beijing.aliyuncs.com/resource/qwen.png" style="height: 60px"/><p>"""
88
+ """<center><font size=8>📖 Qwen2-Math Demo</center>"""
89
+ """
90
+ <center><font size=3>This WebUI is based on Qwen2-VL for OCR and Qwen2-Math for mathematical reasoning. You can input either images or texts of mathematical or arithmetic problems.</center>""")
 
91
  state = gr.State({"tab_index": 0})
 
92
  with gr.Row():
93
  with gr.Column():
94
  with gr.Tabs() as input_tabs:
95
  with gr.Tab("Upload"):
96
  input_image = gr.Image(type="pil", label="Upload")
97
  with gr.Tab("Sketch"):
98
+ input_sketchpad = gr.Sketchpad(type="pil", label="Sketch", layers=False)
 
99
  input_tabs.select(fn=tabs_select, inputs=[state])
100
+ input_text = gr.Textbox(label="input your question")
 
101
  with gr.Row():
102
  with gr.Column():
103
  clear_btn = gr.ClearButton([input_image, input_sketchpad, input_text])
104
  with gr.Column():
105
  submit_btn = gr.Button("Submit", variant="primary")
 
106
  with gr.Column():
107
+ output_md = gr.Markdown(label="answer", elem_id="qwen-md")
 
108
  submit_btn.click(
109
  fn=math_chat_bot,
110
  inputs=[input_image, input_sketchpad, input_text, state],
111
+ outputs=output_md)
 
112
 
113
  if __name__ == "__main__":
114
+ demo.launch()