# MedicineOCR / app.py
# Author: shukdevdatta123 - commit beecb06 ("Update app.py", verified), 4.07 kB
import gradio as gr
from transformers.image_utils import load_image
from threading import Thread
import time
import torch
from PIL import Image
from transformers import (
Qwen2VLForConditionalGeneration,
AutoProcessor,
TextIteratorStreamer,
)
# ---------------------------
# Helper Functions
# ---------------------------
def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
    """
    Build the HTML for a small animated "working..." indicator.

    The snippet is a flex row with ``label`` on the left and a 110x5 px track
    on the right. The track's fill slides across it forever using the
    ``loading`` CSS keyframes, which are defined in the same snippet.

    Args:
        label: Text shown next to the bar. It is inserted verbatim and is
            not HTML-escaped.
        primary_color: CSS colour of the moving fill.
        secondary_color: CSS colour of the track behind the fill.

    Returns:
        The HTML/CSS markup as one string. The string starts and ends with a
        newline.
    """
    track_style = (
        f"width: 110px; height: 5px; background-color: {secondary_color}; "
        "border-radius: 2px; overflow: hidden;"
    )
    fill_style = (
        f"width: 100%; height: 100%; background-color: {primary_color}; "
        "animation: loading 1.5s linear infinite;"
    )
    markup_lines = [
        '<div style="display: flex; align-items: center;">',
        f'<span style="margin-right: 10px; font-size: 14px;">{label}</span>',
        f'<div style="{track_style}">',
        f'<div style="{fill_style}"></div>',
        "</div>",
        "</div>",
        "<style>",
        "@keyframes loading {",
        "0% { transform: translateX(-100%); }",
        "100% { transform: translateX(100%); }",
        "}",
        "</style>",
    ]
    return "\n" + "\n".join(markup_lines) + "\n"
# ---------------------------
# Model and Processor Setup - CPU version
# ---------------------------
MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"

# The processor is used below to build the chat prompt, to preprocess text
# and images into tensors, and (via the streamer) to decode generated tokens.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# Load full-precision (float32) weights for CPU compatibility; this app
# is set up to run on CPU only.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    trust_remote_code=True,
)
# Keep the weights on the CPU and switch the module to inference mode.
model = model.to("cpu")
model.eval()
# ---------------------------
# Main Inference Function
# ---------------------------
def extract_medicines(image_files):
    """Stream the medicine names the model reads from prescription images.

    This is a generator, so Gradio can update the output component while the
    model is still producing tokens.

    Args:
        image_files: Value of the ``gr.File`` input. Normally this is a list of
            file paths. A single path, or file objects exposing ``.name`` (as
            some Gradio versions return), are also accepted.

    Yields:
        str: If nothing was uploaded, a single message asking for an image.
        Otherwise, first an HTML progress bar, then the accumulated model
        output after each streamed chunk. If the model produced no text at
        all, a final fallback message is yielded.
    """
    if not image_files:
        # A ``return "..."`` inside a generator only sets StopIteration.value,
        # which Gradio never displays, so the message has to be yielded.
        yield "Please upload a prescription image."
        return

    # A bare path (e.g. from a cached example) must not be iterated
    # character by character.
    if isinstance(image_files, str) or hasattr(image_files, "name"):
        image_files = [image_files]
    # Older Gradio versions hand over temp-file wrappers; load_image wants the path.
    images = [load_image(getattr(f, "name", f)) for f in image_files]

    # Specific prompt to extract only medicine names
    text = "Extract ONLY the names of medications/medicines from this prescription image. Format the output as a numbered list of medicine names only, without dosages or instructions."
    messages = [{
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in images],
            {"type": "text", "text": text},
        ],
    }]
    prompt_full = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[prompt_full],
        images=images,
        return_tensors="pt",
        padding=True,
    ).to("cpu")

    # Run generate() in a worker thread and consume the decoded text from the
    # streamer here, so partial results can be yielded as they arrive.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    yield progress_bar_html("Extracting Medicine Names")
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Defensive: strip the chat end-of-turn marker if it leaks through.
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)  # presumably throttles UI updates slightly
        yield buffer
    # The streamer is exhausted once generation ends; reap the worker thread.
    thread.join()

    if not buffer.strip():
        # Without this, the progress bar would remain as the final output.
        yield "No medicine names could be extracted from the image(s)."
# ---------------------------
# Gradio Interface
# ---------------------------
_EXAMPLE_IMAGES = [
    "examples/prescription1.jpg",
    "examples/prescription2.jpg",
]

with gr.Blocks() as demo:
    gr.Markdown("# Medicine Name Extractor")
    gr.Markdown("Upload prescription images to extract medicine names")

    with gr.Row():
        # Left column: upload widget and trigger button.
        with gr.Column():
            image_input = gr.File(
                file_types=["image"],
                file_count="multiple",
                label="Upload Prescription Image(s)",
            )
            extract_btn = gr.Button("Extract Medicine Names", variant="primary")
        # Right column: streamed model output (progress bar, then text).
        with gr.Column():
            output = gr.Markdown(label="Extracted Medicine Names")

    extract_btn.click(
        outputs=output,
        inputs=image_input,
        fn=extract_medicines,
    )

    # Each example row holds one value for the single input component.
    gr.Examples(
        examples=[[example_path] for example_path in _EXAMPLE_IMAGES],
        fn=extract_medicines,
        inputs=image_input,
        outputs=output,
        cache_examples=True,
    )

    notes_md = (
        "\n"
        "### Notes:\n"
        "- This app is optimized to run on CPU\n"
        "- Upload clear images of prescriptions for best results\n"
        "- Only medicine names will be extracted\n"
    )
    gr.Markdown(notes_md)

# Enable the request queue, then start the server.
demo.queue()
demo.launch(debug=True)