import gradio as gr
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
from PIL import Image
from transformers import (
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)

# ---------------------------
# Helper Functions
# ---------------------------
def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
    """
    Returns an HTML snippet for a thin animated progress bar with a label.
    """
    return f'''
<div style="display: flex; align-items: center;">
  <span style="margin-right: 10px; font-size: 14px;">{label}</span>
  <div style="width: 110px; height: 5px; background-color: {secondary_color}; border-radius: 2px; overflow: hidden;">
    <div style="width: 100%; height: 100%; background-color: {primary_color}; animation: loading 1.5s linear infinite;"></div>
  </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
    '''
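
# The snippet above is yielded as the first chunk from extract_medicines() so the
# output area shows an animated bar while tokens stream in (assumption: the
# installed Gradio version renders inline HTML inside gr.Markdown).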

# Model and Processor Setup - CPU version
MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float32,  # float32 for CPU compatibility
).to("cpu").eval()
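
# If a CUDA GPU were available, half precision would be considerably faster
# (assumption: enough GPU memory; this Space targets CPU-only hardware):
#
#     model = Qwen2VLForConditionalGeneration.from_pretrained(
#         MODEL_ID, trust_remote_code=True, torch_dtype=torch.float16
#     ).to("cuda").eval()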

# Main Inference Function
def extract_medicines(image_files):
    """Extract medicine names from prescription images."""
    if not image_files:
        # The function is a generator (it yields below), so the early-exit
        # message must be yielded rather than returned.
        yield "Please upload a prescription image."
        return
    images = [load_image(image) for image in image_files]

    # Specific prompt to extract only medicine names
    text = "Extract ONLY the names of medications/medicines from this prescription image. Format the output as a numbered list of medicine names only, without dosages or instructions."
    messages = [{
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in images],
            {"type": "text", "text": text},
        ],
    }]
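
    # apply_chat_template turns the messages above into the model's chat-format
    # prompt string with image placeholder tokens; the processor call below then
    # pairs that prompt with the actual pixel data.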
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt_full],
        images=images,
        return_tensors="pt",
        padding=True,
    ).to("cpu")
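
    # `inputs` is a BatchFeature holding the tokenized prompt plus the image
    # tensors (for Qwen2-VL this typically includes input_ids, attention_mask,
    # pixel_values and image_grid_thw), all kept on the CPU.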
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    yield progress_bar_html("Extracting Medicine Names")
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer
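
# Standalone usage sketch (not part of the app; the file name below is
# illustrative only). Because extract_medicines() is a generator, the last
# value it yields is the complete extraction result:
#
#     result = None
#     for chunk in extract_medicines(["prescription.jpg"]):
#         result = chunk
#     print(result)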

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Medicine Name Extractor")
    gr.Markdown("Upload prescription images to extract medicine names")
    with gr.Row():
        with gr.Column():
            image_input = gr.File(
                label="Upload Prescription Image(s)",
                file_count="multiple",
                file_types=["image"],
            )
            extract_btn = gr.Button("Extract Medicine Names", variant="primary")
        with gr.Column():
            output = gr.Markdown(label="Extracted Medicine Names")

    extract_btn.click(
        fn=extract_medicines,
        inputs=image_input,
        outputs=output,
    )

    gr.Examples(
        examples=[
            ["examples/prescription1.jpg"],
            ["examples/prescription2.jpg"],
        ],
        inputs=image_input,
        outputs=output,
        fn=extract_medicines,
        cache_examples=True,
    )

    gr.Markdown("""
### Notes:
- This app is optimized to run on CPU
- Upload clear images of prescriptions for best results
- Only medicine names will be extracted
""")
demo.queue()
demo.launch(debug=True)