| |
| import warnings |
|
|
| |
| warnings.filterwarnings("ignore", category=FutureWarning, module="spaces") |
|
|
| import base64 |
| import os |
| import re |
| import subprocess |
| import sys |
| import threading |
| import time |
| from collections import OrderedDict |
| from io import BytesIO |
|
|
| import gradio as gr |
| import pypdfium2 as pdfium |
| import spaces |
| import torch |
| from openai import OpenAI |
| from PIL import Image |
| from transformers import ( |
| LightOnOcrForConditionalGeneration, |
| LightOnOcrProcessor, |
| TextIteratorStreamer, |
| ) |
|
|
| |
# Optional vLLM inference endpoints. When set via environment variables, the
# corresponding models are served remotely through an OpenAI-compatible API
# instead of being loaded locally.
VLLM_ENDPOINT_OCR = os.environ.get("VLLM_ENDPOINT_OCR")
VLLM_ENDPOINT_BBOX = os.environ.get("VLLM_ENDPOINT_BBOX")

# Minimum number of seconds between successive partial-text yields while
# streaming, to avoid flooding the UI with updates.
STREAM_YIELD_INTERVAL = 0.5

# Catalogue of selectable model variants.
#   model_id:      Hugging Face repository id to load.
#   has_bbox:      True if the model's output embeds bounding-box annotations.
#   vllm_endpoint: optional remote endpoint; absent/None means local inference.
MODEL_REGISTRY = {
    "LightOnOCR-2-1B (Best OCR)": {
        "model_id": "lightonai/LightOnOCR-2-1B",
        "has_bbox": False,
        "description": "Best overall OCR performance",
        "vllm_endpoint": VLLM_ENDPOINT_OCR,
    },
    "LightOnOCR-2-1B-bbox (Best Bbox)": {
        "model_id": "lightonai/LightOnOCR-2-1B-bbox",
        "has_bbox": True,
        "description": "Best bounding box detection",
        "vllm_endpoint": VLLM_ENDPOINT_BBOX,
    },
    "LightOnOCR-2-1B-base": {
        "model_id": "lightonai/LightOnOCR-2-1B-base",
        "has_bbox": False,
        "description": "Base OCR model",
    },
    "LightOnOCR-2-1B-bbox-base": {
        "model_id": "lightonai/LightOnOCR-2-1B-bbox-base",
        "has_bbox": True,
        "description": "Base bounding box model",
    },
    "LightOnOCR-2-1B-ocr-soup": {
        "model_id": "lightonai/LightOnOCR-2-1B-ocr-soup",
        "has_bbox": False,
        "description": "OCR soup variant",
    },
    "LightOnOCR-2-1B-bbox-soup": {
        "model_id": "lightonai/LightOnOCR-2-1B-bbox-soup",
        "has_bbox": True,
        "description": "Bounding box soup variant",
    },
}

# Model pre-selected in the UI dropdown.
DEFAULT_MODEL = "LightOnOCR-2-1B (Best OCR)"

# Run on GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Attention implementation / dtype per device: SDPA + bfloat16 on GPU,
# eager attention + float32 on CPU.
if device == "cuda":
    attn_implementation = "sdpa"
    dtype = torch.bfloat16
    print("Using sdpa for GPU")
else:
    attn_implementation = "eager"
    dtype = torch.float32
    print("Using eager attention for CPU")
|
|
|
|
class ModelManager:
    """Manages model loading with LRU caching and GPU memory management."""

    def __init__(self, max_cached=2):
        # OrderedDict keyed by model_id; the oldest (least recently used)
        # entry sits at the front and is evicted first.
        self._cache = OrderedDict()
        self._max_cached = max_cached

    def get_model(self, model_name):
        """Get model and processor, loading if necessary."""
        config = MODEL_REGISTRY.get(model_name)
        if config is None:
            raise ValueError(f"Unknown model: {model_name}")

        model_id = config["model_id"]

        # Cache hit: refresh recency and return the (model, processor) pair.
        cached = self._cache.get(model_id)
        if cached is not None:
            self._cache.move_to_end(model_id)
            print(f"Using cached model: {model_name}")
            return cached

        # Make room: evict least-recently-used entries until under the cap.
        while len(self._cache) >= self._max_cached:
            evicted_id, (evicted_model, _) = self._cache.popitem(last=False)
            print(f"Evicting model from cache: {evicted_id}")
            del evicted_model
            if device == "cuda":
                torch.cuda.empty_cache()

        print(f"Loading model: {model_name} ({model_id})...")
        model = LightOnOcrForConditionalGeneration.from_pretrained(
            model_id,
            attn_implementation=attn_implementation,
            torch_dtype=dtype,
            trust_remote_code=True,
        )
        model = model.to(device).eval()

        processor = LightOnOcrProcessor.from_pretrained(
            model_id, trust_remote_code=True
        )

        self._cache[model_id] = (model, processor)
        print(f"Model loaded successfully: {model_name}")

        return model, processor

    def get_model_info(self, model_name):
        """Get model info without loading."""
        return MODEL_REGISTRY.get(model_name)
|
|
|
|
| |
# Module-level singleton; keeps at most two models resident (LRU eviction).
# Models are loaded lazily on first request, not at import time.
model_manager = ModelManager(max_cached=2)
print("Model manager initialized. Models will be loaded on first use.")
|
|
|
|
def render_pdf_page(page, max_resolution=1540, scale=2.77):
    """Render a PDF page to PIL Image.

    The page is rendered at `scale`, uniformly shrunk if either resulting
    pixel dimension would exceed `max_resolution` (never upscaled beyond
    the base scale).
    """
    page_w, page_h = page.get_size()
    # Shrink factor keeping both rendered dimensions within max_resolution.
    shrink = min(
        1,
        max_resolution / (page_w * scale),
        max_resolution / (page_h * scale),
    )
    return page.render(scale=scale * shrink, rev_byteorder=True).to_pil()
|
|
|
|
def process_pdf(pdf_path, page_num=1):
    """Extract a specific page from PDF.

    Args:
        pdf_path: Path to the PDF file.
        page_num: 1-based page number; values outside the valid range are
            clamped instead of raising.

    Returns:
        Tuple of (rendered PIL image, total page count, actual 1-based page).
    """
    pdf = pdfium.PdfDocument(pdf_path)
    try:
        total_pages = len(pdf)
        # Clamp the requested page to [0, total_pages - 1].
        page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
        img = render_pdf_page(pdf[page_idx])
    finally:
        # Always release the document handle, even if rendering fails
        # (previously the handle leaked on any exception after open).
        pdf.close()
    return img, total_pages, page_idx + 1
|
|
|
|
def clean_output_text(text):
    """Remove chat template artifacts from output.

    Drops lines that consist solely of a role marker ("system"/"user"/
    "assistant"); if the literal word "assistant" appears anywhere, the text
    after its first occurrence takes precedence instead.
    """
    role_markers = {"system", "user", "assistant"}

    # Pass 1: drop lines that are exactly a role marker (case-insensitive).
    kept = [ln for ln in text.split("\n") if ln.strip().lower() not in role_markers]
    result = "\n".join(kept).strip()

    # Pass 2: when a literal lowercase "assistant" occurs in the raw text,
    # everything after its first occurrence wins over the line-filtered text.
    if "assistant" in text.lower():
        _, sep, tail = text.partition("assistant")
        if sep:
            result = tail.strip()

    return result
|
|
|
|
| |
# Matches bbox-model output: a markdown image reference followed by four
# comma-separated integer coordinates, e.g. "![image](image_0.png) 10,20,30,40".
BBOX_PATTERN = r"!\[image\]\((image_\d+\.png)\)\s*(\d+),(\d+),(\d+),(\d+)"


def parse_bbox_output(text):
    """Parse bbox output and return cleaned text with list of detections.

    Returns:
        (cleaned, detections): `cleaned` keeps each markdown image reference
        but strips the trailing raw coordinates; `detections` is a list of
        {"ref": filename, "coords": (x1, y1, x2, y2)} dicts with coordinates
        in the model's normalized [0, 1000] space.
    """
    detections = []
    for match in re.finditer(BBOX_PATTERN, text):
        image_ref, x1, y1, x2, y2 = match.groups()
        detections.append(
            {"ref": image_ref, "coords": (int(x1), int(y1), int(x2), int(y2))}
        )

    # Keep the image reference in the text — render_bbox_with_crops later
    # swaps it for an inline crop — but drop the raw coordinate suffix.
    # (Replacing with "" removed the reference entirely, making the later
    # placeholder substitution a no-op.)
    cleaned = re.sub(BBOX_PATTERN, r"![image](\1)", text)
    return cleaned, detections
|
|
|
|
def crop_from_bbox(source_image, bbox, padding=5):
    """Crop region from image based on normalized [0,1000] coords."""
    width, height = source_image.size

    # Map normalized [0, 1000] coordinates onto the actual pixel grid
    # (x values scale with width, y values with height).
    left, top, right, bottom = (
        int(coord * dim / 1000)
        for coord, dim in zip(bbox["coords"], (width, height, width, height))
    )

    # Pad outward, clamped to the image bounds.
    left = max(0, left - padding)
    top = max(0, top - padding)
    right = min(width, right + padding)
    bottom = min(height, bottom + padding)

    return source_image.crop((left, top, right, bottom))
|
|
|
|
def image_to_data_uri(image):
    """Convert PIL image to base64 data URI for markdown embedding."""
    with BytesIO() as buffer:
        # PNG keeps the crop lossless; the bytes are then base64-encoded.
        image.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/png;base64,{encoded}"
|
|
|
|
def extract_text_via_vllm(image, model_name, temperature=0.2, stream=False, max_tokens=2048):
    """Extract text from image using vLLM endpoint.

    Generator: yields cleaned partial text while streaming, or the final
    cleaned text exactly once when `stream` is False.

    Raises:
        ValueError: if the model is unknown or has no configured endpoint.
    """
    config = MODEL_REGISTRY.get(model_name)
    if config is None:
        raise ValueError(f"Unknown model: {model_name}")

    endpoint = config.get("vllm_endpoint")
    if endpoint is None:
        raise ValueError(f"Model {model_name} does not have a vLLM endpoint")

    # PIL images are inlined as base64 data URIs; anything else is assumed
    # to already be a URL/URI string.
    if isinstance(image, Image.Image):
        image_uri = image_to_data_uri(image)
    else:
        image_uri = image

    # The vLLM server speaks the OpenAI chat-completions protocol.
    client = OpenAI(base_url=endpoint, api_key="not-needed")

    # Request parameters shared by the streaming and one-shot paths.
    request_kwargs = dict(
        model=config["model_id"],
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_uri}},
                ],
            }
        ],
        max_tokens=max_tokens,
        temperature=temperature if temperature > 0 else 0.0,
        top_p=0.9,
    )

    if stream:
        pieces = []
        last_emit = time.time()
        for chunk in client.chat.completions.create(stream=True, **request_kwargs):
            if chunk.choices and chunk.choices[0].delta.content:
                pieces.append(chunk.choices[0].delta.content)
                # Throttle UI updates to one per STREAM_YIELD_INTERVAL seconds.
                if time.time() - last_emit > STREAM_YIELD_INTERVAL:
                    yield clean_output_text("".join(pieces))
                    last_emit = time.time()
        # Always emit the complete accumulated text at the end.
        yield clean_output_text("".join(pieces))
    else:
        response = client.chat.completions.create(stream=False, **request_kwargs)
        yield clean_output_text(response.choices[0].message.content)
|
|
|
|
def render_bbox_with_crops(raw_output, source_image):
    """Replace markdown image placeholders with actual cropped images.

    Args:
        raw_output: Model output containing `![image](image_N.png) x1,y1,x2,y2`
            annotations.
        source_image: PIL image the bboxes refer to.

    Returns:
        Markdown where each image placeholder is replaced by an inline
        base64 data-URI crop; failed crops leave the placeholder untouched.
    """
    cleaned, detections = parse_bbox_output(raw_output)

    for bbox in detections:
        try:
            cropped = crop_from_bbox(source_image, bbox)
            data_uri = image_to_data_uri(cropped)
            # Swap the placeholder reference (e.g. image_0.png) for an inline
            # data URI so the crop renders directly in the markdown view.
            # (The original called cleaned.replace(f"", f""), i.e. replaced
            # the empty string — a corrupted no-op.)
            cleaned = cleaned.replace(
                f"![image]({bbox['ref']})", f"![image]({data_uri})"
            )
        except Exception as e:
            # Best-effort: a single bad bbox must not abort the whole render.
            print(f"Error cropping bbox {bbox}: {e}")
            continue

    return cleaned
|
|
|
|
@spaces.GPU
def extract_text_from_image(image, model_name, temperature=0.2, stream=False, max_tokens=2048):
    """Extract text from image using LightOnOCR model.

    Generator: yields progressively cleaned text when `stream` is True,
    otherwise yields the final cleaned text once.

    Args:
        image: Input image (PIL image or value the chat template accepts).
        model_name: Key into MODEL_REGISTRY.
        temperature: Sampling temperature; <= 0 disables sampling.
        stream: Whether to stream partial output via a background thread.
        max_tokens: Maximum number of new tokens to generate.
    """
    # Models backed by a remote vLLM endpoint are delegated entirely.
    config = MODEL_REGISTRY.get(model_name, {})
    if config.get("vllm_endpoint"):
        yield from extract_text_via_vllm(image, model_name, temperature, stream, max_tokens)
        return

    # Local path: load (or fetch from cache) the HF model + processor.
    model, processor = model_manager.get_model(model_name)

    # Single-turn chat with just the image; the template adds the prompt.
    chat = [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": image},
            ],
        }
    ]

    inputs = processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

    # Move tensors to the target device. Floating-point tensors are also cast
    # to the model dtype (e.g. bfloat16); integer tensors (token ids, masks)
    # keep their dtype; non-tensors pass through unchanged.
    inputs = {
        k: v.to(device=device, dtype=dtype)
        if isinstance(v, torch.Tensor)
        and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
        else v.to(device)
        if isinstance(v, torch.Tensor)
        else v
        for k, v in inputs.items()
    }

    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature if temperature > 0 else 0.0,
        top_p=0.9,
        # NOTE(review): HF generate typically expects top_k > 0; confirm that
        # top_k=0 is accepted/means "disabled" for this model's generation config.
        top_k=0,
        use_cache=True,
        do_sample=temperature > 0,
    )

    if stream:
        # Stream decoded tokens from a background generation thread.
        streamer = TextIteratorStreamer(
            processor.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs["streamer"] = streamer

        thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Accumulate and re-yield at most once per STREAM_YIELD_INTERVAL
        # seconds to avoid flooding the UI with updates.
        full_text = ""
        last_yield_time = time.time()
        for new_text in streamer:
            full_text += new_text
            if time.time() - last_yield_time > STREAM_YIELD_INTERVAL:
                yield clean_output_text(full_text)
                last_yield_time = time.time()

        thread.join()
        # Final yield with the complete text.
        yield clean_output_text(full_text)
    else:
        with torch.no_grad():
            outputs = model.generate(**generation_kwargs)

        # Decode the full sequence; clean_output_text strips the
        # chat-template role markers afterwards.
        output_text = processor.decode(outputs[0], skip_special_tokens=True)

        cleaned_text = clean_output_text(output_text)

        yield cleaned_text
|
|
|
|
def process_input(file_input, model_name, temperature, page_num, enable_streaming, max_output_tokens):
    """Process uploaded file (image or PDF) and extract text with optional streaming.

    Generator yielding 5-tuples:
        (rendered markdown, raw text, processing info, source image, slider update).
    """
    if file_input is None:
        yield "Please upload an image or PDF first.", "", "", None, gr.update()
        return

    # gr.File may hand back either a path string or a tempfile-like object.
    path = file_input if isinstance(file_input, str) else file_input.name

    if path.lower().endswith(".pdf"):
        try:
            source_image, total_pages, actual_page = process_pdf(path, int(page_num))
        except Exception as e:
            yield f"Error processing PDF: {str(e)}", "", "", None, gr.update()
            return
        info = f"Processing page {actual_page} of {total_pages}"
    else:
        try:
            source_image = Image.open(path)
        except Exception as e:
            yield f"Error opening image: {str(e)}", "", "", None, gr.update()
            return
        info = "Processing image"

    # Bbox-capable models get their placeholders rendered as inline crops.
    show_bboxes = MODEL_REGISTRY.get(model_name, {}).get("has_bbox", False)

    try:
        for extracted in extract_text_from_image(
            source_image,
            model_name,
            temperature,
            stream=enable_streaming,
            max_tokens=max_output_tokens,
        ):
            rendered = (
                render_bbox_with_crops(extracted, source_image)
                if show_bboxes
                else extracted
            )
            yield rendered, extracted, info, source_image, gr.update()
    except Exception as e:
        message = f"Error during text extraction: {str(e)}"
        yield message, message, info, source_image, gr.update()
|
|
|
|
def update_slider_and_preview(file_input):
    """Update page slider and preview image based on uploaded file.

    Returns:
        (gr.update for the page slider, preview PIL image or None).
    """
    if file_input is None:
        return gr.update(maximum=20, value=1), None

    file_path = file_input if isinstance(file_input, str) else file_input.name

    if file_path.lower().endswith(".pdf"):
        pdf = None
        try:
            pdf = pdfium.PdfDocument(file_path)
            total_pages = len(pdf)
            # Render only the first page, at a modest scale, for the preview.
            preview_image = pdf[0].render(scale=2).to_pil()
            return gr.update(maximum=total_pages, value=1), preview_image
        except Exception as e:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # propagate; log instead of silently swallowing the failure.
            print(f"Error generating PDF preview: {e}")
            return gr.update(maximum=20, value=1), None
        finally:
            # Release the document handle even if rendering failed
            # (previously leaked on any exception after open).
            if pdf is not None:
                pdf.close()
    else:
        try:
            preview_image = Image.open(file_path)
            return gr.update(maximum=1, value=1), preview_image
        except Exception as e:
            print(f"Error generating image preview: {e}")
            return gr.update(maximum=1, value=1), None
|
|
|
|
| |
def get_model_info_text(model_name):
    """Return formatted model info string."""
    info = MODEL_REGISTRY.get(model_name, {})
    # Human-readable bbox-capability label for the markdown panel.
    if info.get("has_bbox", False):
        bbox_label = "Yes - will show cropped regions inline"
    else:
        bbox_label = "No"
    description = info.get("description", "N/A")
    return (
        f"**Description:** {description}\n"
        f"**Bounding Box Detection:** {bbox_label}"
    )
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Gradio UI. Built at import time so `demo` exists for launching below.
# ---------------------------------------------------------------------------
with gr.Blocks(title="LightOnOCR-2 Multi-Model OCR") as demo:
    # Header banner with usage instructions and runtime configuration.
    gr.Markdown(f"""
# LightOnOCR-2 — Efficient 1B VLM for OCR

State-of-the-art OCR on OlmOCR-Bench, ~9× smaller and faster than competitors. Handles tables, forms, math, multi-column layouts.

⚡ **3.3× faster** than Chandra, **1.7× faster** than OlmOCR | 💸 **<$0.01/1k pages** | 🧠 End-to-end differentiable | 📍 Bbox variants for image detection

📄 [Paper](https://huggingface.co/papers/lightonocr-2) | 📝 [Blog](https://huggingface.co/blog/lightonai/lightonocr-2) | 📊 [Dataset](https://huggingface.co/datasets/lightonai/LightOnOCR-mix-0126) | 📓 [Finetuning](https://colab.research.google.com/drive/1WjbsFJZ4vOAAlKtcCauFLn_evo5UBRNa?usp=sharing)

---

**How to use:** Select a model → Upload image/PDF → Click "Extract Text" | **Device:** {device.upper()} | **Attention:** {attn_implementation}
""")

    with gr.Row():
        # Left column: model selection, upload, preview, and generation knobs.
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=list(MODEL_REGISTRY.keys()),
                value=DEFAULT_MODEL,
                label="Model",
                info="Select OCR model variant",
            )
            model_info = gr.Markdown(
                value=get_model_info_text(DEFAULT_MODEL), label="Model Info"
            )
            file_input = gr.File(
                label="Upload Image or PDF",
                file_types=[".pdf", ".png", ".jpg", ".jpeg"],
                type="filepath",
            )
            rendered_image = gr.Image(
                label="Preview", type="pil", height=400, interactive=False
            )
            # Slider maximum is adjusted dynamically in update_slider_and_preview.
            num_pages = gr.Slider(
                minimum=1,
                maximum=20,
                value=1,
                step=1,
                label="PDF: Page Number",
                info="Select which page to extract",
            )
            page_info = gr.Textbox(label="Processing Info", value="", interactive=False)
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.2,
                step=0.05,
                label="Temperature",
                info="0.0 = deterministic, Higher = more varied",
            )
            enable_streaming = gr.Checkbox(
                label="Enable Streaming",
                value=True,
                info="Show text progressively as it's generated",
            )
            max_output_tokens = gr.Slider(
                minimum=256,
                maximum=8192,
                value=2048,
                step=256,
                label="Max Output Tokens",
                info="Maximum number of tokens to generate",
            )
            submit_btn = gr.Button("Extract Text", variant="primary")
            clear_btn = gr.Button("Clear", variant="secondary")

        # Right column: rendered (markdown + LaTeX) extraction result.
        with gr.Column(scale=2):
            output_text = gr.Markdown(
                label="📄 Extracted Text (Rendered)",
                value="*Extracted text will appear here...*",
                latex_delimiters=[
                    {"left": "$$", "right": "$$", "display": True},
                    {"left": "$", "right": "$", "display": False},
                ],
            )

    # Bundled sample documents; clicking one loads it into file_input.
    EXAMPLE_IMAGES = [
        "examples/example_1.png",
        "examples/example_2.png",
        "examples/example_3.png",
        "examples/example_4.png",
        "examples/example_5.png",
        "examples/example_6.png",
        "examples/example_7.png",
        "examples/example_8.png",
        "examples/example_9.png",
    ]

    with gr.Accordion("📁 Example Documents (click an image to load)", open=True):
        example_gallery = gr.Gallery(
            value=EXAMPLE_IMAGES,
            columns=5,
            rows=2,
            height="auto",
            object_fit="contain",
            show_label=False,
            allow_preview=False,
        )

    def load_example_image(evt: gr.SelectData):
        """Load selected example image into file input."""
        # evt.index is the gallery position, matching EXAMPLE_IMAGES order.
        return EXAMPLE_IMAGES[evt.index]

    example_gallery.select(
        fn=load_example_image,
        outputs=[file_input],
    )

    # Raw (unrendered) model output, useful for copying the markdown source.
    with gr.Row():
        with gr.Column():
            raw_output = gr.Textbox(
                label="Raw Markdown Output",
                placeholder="Raw text will appear here...",
                lines=20,
                max_lines=30,
            )

    # Wire events. process_input is a generator, enabling streamed updates.
    submit_btn.click(
        fn=process_input,
        inputs=[file_input, model_selector, temperature, num_pages, enable_streaming, max_output_tokens],
        outputs=[output_text, raw_output, page_info, rendered_image, num_pages],
    )

    # Re-range the page slider and show a preview whenever the upload changes.
    file_input.change(
        fn=update_slider_and_preview,
        inputs=[file_input],
        outputs=[num_pages, rendered_image],
    )

    model_selector.change(
        fn=get_model_info_text, inputs=[model_selector], outputs=[model_info]
    )

    # Reset every control back to its initial state.
    clear_btn.click(
        fn=lambda: (
            None,
            DEFAULT_MODEL,
            get_model_info_text(DEFAULT_MODEL),
            "*Extracted text will appear here...*",
            "",
            "",
            None,
            1,
            2048,
        ),
        outputs=[
            file_input,
            model_selector,
            model_info,
            output_text,
            raw_output,
            page_info,
            rendered_image,
            num_pages,
            max_output_tokens,
        ],
    )
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): gr.Blocks.launch() may not accept a `theme` kwarg in all
    # Gradio versions — theme is normally passed to gr.Blocks(...) — confirm
    # against the installed Gradio version.
    demo.launch(theme=gr.themes.Soft(), ssr_mode=False, share = True)
|
|