import html
import os
import time
from threading import Thread

import gradio as gr
import torch
from PIL import Image
from transformers import (
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)
from transformers.image_utils import load_image

# ---------------------------
# Helper Functions
# ---------------------------


def progress_bar_html(label: str, primary_color: str = "#4B0082", secondary_color: str = "#9370DB") -> str:
    """Return an HTML snippet for a thin animated progress bar with a label.

    Args:
        label: Text shown to the left of the bar (HTML-escaped).
        primary_color: CSS color of the moving inner bar.
        secondary_color: CSS color of the bar's track.

    Returns:
        An HTML string intended to be rendered by a ``gr.Markdown`` component.
    """
    # Literal CSS braces are doubled because this is an f-string.
    return f"""
<div style="display: flex; align-items: center;">
    <span style="margin-right: 10px; font-size: 14px;">{html.escape(label)}</span>
    <div style="width: 110px; height: 5px; background-color: {secondary_color}; border-radius: 2px; overflow: hidden;">
        <div style="width: 100%; height: 100%; background-color: {primary_color}; animation: loading 1.5s linear infinite;"></div>
    </div>
</div>
<style>
@keyframes loading {{
    0% {{ transform: translateX(-100%); }}
    100% {{ transform: translateX(100%); }}
}}
</style>
"""


def _as_file_list(image_files):
    """Normalize a ``gr.File`` payload (list, single upload, or None) into a list."""
    if image_files is None:
        return []
    if isinstance(image_files, (list, tuple)):
        return list(image_files)
    # A single upload (e.g. an example row) may arrive as one value, not a list;
    # iterating a bare path string would otherwise yield single characters.
    return [image_files]


def _file_path(upload):
    """Return something ``load_image`` accepts for one ``gr.File`` upload."""
    if isinstance(upload, (str, os.PathLike)):
        return os.fspath(upload)
    # Older Gradio versions hand back tempfile wrappers that expose ``.name``.
    return getattr(upload, "name", upload)


# ---------------------------
# Model and Processor Setup - CPU version
# ---------------------------
MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"

# NOTE(review): trust_remote_code allows executing Python code shipped with the
# model repo; the model class here is the built-in Qwen2VL one — confirm it is needed.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float32,  # Using float32 for CPU compatibility
).to("cpu").eval()

# Specific prompt to extract only medicine names
MEDICINE_PROMPT = (
    "Extract ONLY the names of medications/medicines from this prescription image. "
    "Format the output as a numbered list of medicine names only, without dosages or instructions."
)


# ---------------------------
# Main Inference Function
# ---------------------------
def extract_medicines(image_files):
    """Extract medicine names from prescription images, streaming the result.

    This is a generator: Gradio renders every yielded string in the output
    component, so the user first sees a progress bar and then the model's
    answer as it grows.

    Args:
        image_files: The ``gr.File`` payload — a list of uploads, a single
            upload, or ``None``.

    Yields:
        A progress-bar HTML snippet, then the accumulated model output; or a
        human-readable message when there is no input, an image cannot be
        read, generation fails, or the model returns nothing.
    """
    uploads = _as_file_list(image_files)
    if not uploads:
        # Gradio ignores a generator's return value, so the message must be yielded.
        yield "Please upload a prescription image."
        return

    try:
        images = [load_image(_file_path(upload)) for upload in uploads]
    except (OSError, ValueError) as exc:
        yield f"Could not read the uploaded image: {exc}"
        return

    messages = [{
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in images],
            {"type": "text", "text": MEDICINE_PROMPT},
        ],
    }]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt_full],
        images=images,
        return_tensors="pt",
        padding=True,
    ).to("cpu")

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
    errors = []

    def _generate():
        # Runs in a worker thread. If generate() raises, the streamer would never
        # receive its stop signal and the consumer loop below would block forever,
        # so record the error and close the stream explicitly.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate, daemon=True)
    thread.start()

    buffer = ""
    yield progress_bar_html("Extracting Medicine Names")
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer
    thread.join()

    # Without these, the UI would keep showing the progress bar as the final state.
    if errors:
        yield f"Extraction failed: {errors[0]}"
    elif not buffer.strip():
        yield "No medicine names could be extracted from the image."


# ---------------------------
# Gradio Interface
# ---------------------------
EXAMPLE_FILES = [
    "examples/prescription1.jpg",
    "examples/prescription2.jpg",
]

with gr.Blocks() as demo:
    gr.Markdown("# Medicine Name Extractor")
    gr.Markdown("Upload prescription images to extract medicine names")

    with gr.Row():
        with gr.Column():
            image_input = gr.File(
                label="Upload Prescription Image(s)",
                file_count="multiple",
                file_types=["image"],
            )
            extract_btn = gr.Button("Extract Medicine Names", variant="primary")
        with gr.Column():
            output = gr.Markdown(label="Extracted Medicine Names")

    extract_btn.click(
        fn=extract_medicines,
        inputs=image_input,
        outputs=output,
    )

    # With cache_examples=True the model is run on every example at startup, so a
    # missing example file would make startup fail; register only files that exist.
    available_examples = [[path] for path in EXAMPLE_FILES if os.path.isfile(path)]
    if available_examples:
        gr.Examples(
            examples=available_examples,
            inputs=image_input,
            outputs=output,
            fn=extract_medicines,
            cache_examples=True,
        )

    gr.Markdown("""
### Notes:
- This app is optimized to run on CPU
- Upload clear images of prescriptions for best results
- Only medicine names will be extracted
""")

if __name__ == "__main__":
    demo.queue()
    demo.launch(debug=True)