JimmyK300 commited on
Commit
cfee34b
·
verified ·
1 Parent(s): a869700

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -5
app.py CHANGED
@@ -4,13 +4,18 @@ import gradio as gr
4
  import tempfile
5
  import secrets
6
  from pathlib import Path
7
- from transformers import AutoModelForCausalLM, AutoTokenizer, BlipForConditionalGeneration, AutoProcessor
 
8
  from PIL import Image
9
 
10
  # Load Vision-Language Model
11
- vl_model_name = "Salesforce/blip-image-captioning-large"
12
- vl_model = BlipForConditionalGeneration.from_pretrained(vl_model_name)
13
- vl_processor = AutoProcessor.from_pretrained(vl_model_name)
 
 
 
 
14
 
15
  # Load Text Model
16
  model_name = "Qwen/Qwen2.5-Math-1.5B-Instruct"
@@ -31,7 +36,13 @@ def process_image(image, shouldConvert=False):
31
 
32
  # Convert the image to tensor
33
  inputs = vl_processor(images=image, return_tensors="pt")
34
- output = vl_model.generate(**inputs)
 
 
 
 
 
 
35
  description = vl_processor.batch_decode(output, skip_special_tokens=True)[0]
36
 
37
  return f"Math-related content detected: {description}"
 
4
  import tempfile
5
  import secrets
6
  from pathlib import Path
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BlipForConditionalGeneration, AutoProcessor, Qwen2VLForConditionalGeneration
8
+ from qwen_vl_utils import process_vision_info
9
  from PIL import Image
10
 
11
  # Load Vision-Language Model
12
+ vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
13
+ "Qwen/Qwen2-VL-2B-Instruct",
14
+ torch_dtype=torch.bfloat16,
15
+ attn_implementation="flash_attention_2",
16
+ device_map="auto",
17
+ )
18
+ vl_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
19
 
20
  # Load Text Model
21
  model_name = "Qwen/Qwen2.5-Math-1.5B-Instruct"
 
36
 
37
  # Convert the image to tensor
38
  inputs = vl_processor(images=image, return_tensors="pt")
39
+ generated_ids = vl_model.generate(**inputs)
40
+ generated_ids_trimmed = [
41
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
42
+ ]
43
+ output = processor.batch_decode(
44
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
45
+ )
46
  description = vl_processor.batch_decode(output, skip_special_tokens=True)[0]
47
 
48
  return f"Math-related content detected: {description}"