| | |
| | """ |
| | MonkeyOCR 3B Gradio App for MacBook M4 Pro with MPS Acceleration |
| | Optimized for local deployment with Apple Silicon GPU acceleration |
| | """ |
| |
|
| | import os |
| | import sys |
| | import tempfile |
| | import shutil |
| | from pathlib import Path |
| | import base64 |
| | import re |
| | import uuid |
| | import subprocess |
| | from typing import Optional, Tuple |
| |
|
| | import gradio as gr |
| | import torch |
| | from PIL import Image |
| | from pdf2image import convert_from_path |
| | from loguru import logger |
| |
|
| | |
| | from torch_patch import patch_torch_load |
| | patch_torch_load() |
| |
|
| | |
| | sys.path.append("./MonkeyOCR") |
| |
|
| | try: |
| | from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader |
| | from magic_pdf.data.dataset import PymuDocDataset, ImageDataset |
| | from magic_pdf.model.doc_analyze_by_custom_model_llm import doc_analyze_llm |
| | from magic_pdf.model.custom_model import MonkeyOCR |
| | except ImportError as e: |
| | logger.error(f"Failed to import MonkeyOCR modules: {e}") |
| | logger.info("Please ensure MonkeyOCR is properly installed") |
| | sys.exit(1) |
| |
|
| | |
# Process-wide MonkeyOCR instance, created on the first initialize_model() call.
model_instance = None


def initialize_model(config_path: str = "model_configs_mps.yaml") -> MonkeyOCR:
    """Return the shared MonkeyOCR model, creating it on first use.

    When MPS is not available the YAML config at *config_path* is updated
    in place so that its ``device`` key is ``'cpu'`` before the model is
    constructed. The file is only rewritten when that value actually has
    to change.

    Args:
        config_path: Path to the YAML model configuration handed to
            ``MonkeyOCR``.

    Returns:
        The cached ``MonkeyOCR`` instance.

    Raises:
        Exception: Whatever ``MonkeyOCR(config_path)`` or reading the
            config raises; construction errors are logged and re-raised.
    """
    global model_instance

    # Fast path: the model has already been built by an earlier call.
    if model_instance is not None:
        return model_instance

    logger.info("Initializing MonkeyOCR model with MPS acceleration...")

    if torch.backends.mps.is_available():
        logger.info("MPS is available and will be used for acceleration")
    else:
        logger.warning("MPS not available, falling back to CPU")

        import yaml
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty file; treat it as an empty mapping.
            config = yaml.safe_load(f) or {}
        if config.get('device') != 'cpu':
            config['device'] = 'cpu'
            # NOTE: this persists the CPU override into the user's config file.
            with open(config_path, 'w', encoding='utf-8') as f:
                yaml.dump(config, f)

    # NOTE(review): the original nesting of this line could not be recovered;
    # it is set unconditionally here. The variable name suggests it only
    # concerns the MPS allocator - confirm.
    os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'

    try:
        model_instance = MonkeyOCR(config_path)
    except Exception as e:
        logger.error(f"Failed to initialize model: {e}")
        raise

    logger.info("Model initialized successfully")
    return model_instance
| |
|
def render_latex_table_to_image(latex_content: str, temp_dir: str) -> str:
    """Render the first LaTeX ``tabular`` in *latex_content* as an inline PNG.

    The table is wrapped in a minimal standalone document, compiled with
    ``pdflatex`` inside *temp_dir*, rasterised at 300 dpi and returned as an
    ``<img>`` tag carrying a base64 data URI.

    Args:
        latex_content: Text that may contain a ``tabular`` environment.
        temp_dir: Existing directory used for the intermediate files.

    Returns:
        * the ``<img>`` tag on success;
        * *latex_content* unchanged when it contains no ``tabular`` start;
        * the HTML-escaped content wrapped in ``<pre>`` when compilation or
          rasterisation fails.
    """
    import html  # function-scope import, like the other local imports in this file

    # The content comes from OCR of uploaded documents, so it must be escaped
    # before being embedded in HTML (tabular bodies are full of '&').
    fallback_html = f"<pre>{html.escape(str(latex_content))}</pre>"
    generated_files = []

    try:
        found = re.search(r"\\begin\{tabular\}.*?\\end\{tabular\}", latex_content, re.DOTALL)

        if found:
            table_content = found.group(0)
        elif '\\begin{tabular}' in latex_content:
            # Truncated table: close the environment so LaTeX can still compile it.
            if '\\end{tabular}' not in latex_content:
                table_content = latex_content + '\n\\end{tabular}'
            else:
                table_content = latex_content
        else:
            return latex_content

        full_latex = r"""
\documentclass{article}
\usepackage[utf8]{inputenc}
\usepackage{booktabs}
\usepackage{bm}
\usepackage{multirow}
\usepackage{array}
\usepackage{colortbl}
\usepackage[table]{xcolor}
\usepackage{amsmath}
\usepackage{amssymb}
\usepackage{graphicx}
\usepackage{geometry}
\usepackage{makecell}
\usepackage[active,tightpage]{preview}
\PreviewEnvironment{tabular}
\begin{document}
""" + table_content + r"""
\end{document}
"""

        # A short random job name keeps tables rendered in the same directory apart.
        base_path = os.path.join(temp_dir, f"table_{str(uuid.uuid4())[:8]}")
        tex_path = base_path + ".tex"
        pdf_path = base_path + ".pdf"
        png_path = base_path + ".png"
        # pdflatex also writes .aux/.log next to the PDF; clean those up as well.
        generated_files = [tex_path, pdf_path, png_path, base_path + ".aux", base_path + ".log"]

        with open(tex_path, "w", encoding="utf-8") as f:
            f.write(full_latex)

        result = subprocess.run(
            [
                "pdflatex",
                "-interaction=nonstopmode",
                # Untrusted input: never let the document spawn shell commands.
                "-no-shell-escape",
                "-output-directory", temp_dir,
                tex_path,
            ],
            timeout=20,
            capture_output=True,
            text=True,
        )

        if result.returncode != 0 or not os.path.exists(pdf_path):
            logger.warning("LaTeX compilation failed, returning original content")
            return fallback_html

        images = convert_from_path(pdf_path, dpi=300)
        images[0].save(png_path, "PNG")

        with open(png_path, "rb") as f:
            img_base64 = base64.b64encode(f.read()).decode("utf-8")

        return f'<img src="data:image/png;base64,{img_base64}" style="max-width:100%;height:auto;">'

    except Exception as e:
        logger.warning(f"LaTeX rendering error: {e}")
        return fallback_html

    finally:
        # Best-effort cleanup on every exit path; a missing file is not an error.
        for leftover in generated_files:
            try:
                os.remove(leftover)
            except OSError:
                pass
| |
|
def process_document(file_path: str) -> Tuple[str, str]:
    """Run MonkeyOCR on one document and return its markdown and layout PDF.

    Outputs are written next to the input file: ``markdown/<stem>.md``,
    ``markdown/images/`` and ``<stem>_layout.pdf``. In the returned
    markdown, LaTeX tables wrapped in ``<html>`` are replaced by rendered
    images and local image links are inlined as base64 data URIs.

    Args:
        file_path: Path to a PDF or a jpg/jpeg/png image.

    Returns:
        ``(markdown, layout_pdf_path)``. ``("", "")`` when *file_path* is
        empty. On any failure the first element is an
        ``"Error processing document: ..."`` message and the second is ``""``.
    """
    if not file_path:
        return "", ""

    try:
        model = initialize_model()

        parent_path, full_name = os.path.split(file_path)
        # splitext keeps the whole name when there is no extension, instead of
        # collapsing it to an empty stem.
        name = os.path.splitext(full_name)[0]

        local_md_dir = os.path.join(parent_path, "markdown")
        local_image_dir = os.path.join(local_md_dir, "images")
        os.makedirs(local_image_dir, exist_ok=True)
        os.makedirs(local_md_dir, exist_ok=True)

        # Image directory name as passed to dump_md (relative to the markdown dir).
        image_dir = os.path.basename(local_image_dir)
        image_writer = FileBasedDataWriter(local_image_dir)
        md_writer = FileBasedDataWriter(local_md_dir)

        data_bytes = FileBasedDataReader(parent_path).read(full_name)

        if full_name.split(".")[-1].lower() in ('jpg', 'jpeg', 'png'):
            ds = ImageDataset(data_bytes)
        else:
            ds = PymuDocDataset(data_bytes)

        logger.info("Processing document with MonkeyOCR...")
        pipe_result = _analyze_with_timeout(ds, model, image_writer)

        layout_pdf_path = os.path.join(parent_path, f"{name}_layout.pdf")
        pipe_result.draw_layout(layout_pdf_path)

        pipe_result.dump_md(md_writer, f"{name}.md", image_dir)
        md_content_ori = FileBasedDataReader(local_md_dir).read(f"{name}.md").decode("utf-8")

        md_content = _inline_markdown_assets(md_content_ori, local_md_dir)

        logger.info("Document processing completed successfully")
        return md_content, layout_pdf_path

    except Exception as e:
        logger.error(f"Error processing document: {e}")
        return f"Error processing document: {str(e)}", ""


def _analyze_with_timeout(ds, model, image_writer, timeout_seconds: int = 300):
    """Run analysis + OCR on a worker thread; raise TimeoutError if it overruns."""
    import time
    from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError

    def process_with_model():
        overall_start_time = time.time()

        analysis_start_time = time.time()
        logger.info("Starting document analysis...")
        infer_result = ds.apply(doc_analyze_llm, MonkeyOCR_model=model)
        logger.info(f"PROFILE: Document analysis (doc_analyze_llm) took {time.time() - analysis_start_time:.2f}s")

        ocr_start_time = time.time()
        logger.info("Starting OCR and layout processing...")
        pipe_result = infer_result.pipe_ocr_mode(image_writer, MonkeyOCR_model=model)
        logger.info(f"PROFILE: OCR/Layout (pipe_ocr_mode) took {time.time() - ocr_start_time:.2f}s")

        logger.info(f"PROFILE: Total model processing took {time.time() - overall_start_time:.2f}s")
        return pipe_result

    # Deliberately not a `with` block: leaving a `with ThreadPoolExecutor`
    # calls shutdown(wait=True), which would block until the worker finishes
    # and make the timeout below meaningless.
    executor = ThreadPoolExecutor(max_workers=1)
    try:
        return executor.submit(process_with_model).result(timeout=timeout_seconds)
    except FutureTimeoutError:
        logger.error("Processing timed out after 5 minutes")
        raise TimeoutError("Document processing timed out. Please try with a smaller document or simpler layout.")
    finally:
        # wait=False returns immediately. A timed-out worker cannot be killed
        # and keeps running in the background until it finishes on its own.
        executor.shutdown(wait=False)


def _inline_markdown_assets(md_content: str, local_md_dir: str) -> str:
    """Replace LaTeX tables and local image links in *md_content* with inline HTML."""

    def replace_html_latex_table(match):
        html_content = match.group(1)
        if '\\begin{tabular}' in html_content:
            return render_latex_table_to_image(html_content, temp_dir)
        return match.group(0)

    def replace_image_with_base64(match):
        img_path = match.group(1)
        if os.path.isabs(img_path):
            full_img_path = img_path
        else:
            full_img_path = os.path.join(local_md_dir, img_path)

        try:
            if not os.path.exists(full_img_path):
                return match.group(0)
            with open(full_img_path, "rb") as f:
                img_base64 = base64.b64encode(f.read()).decode("utf-8")
            ext = os.path.splitext(full_img_path)[1].lower()
            mime_type = "image/jpeg" if ext in ['.jpg', '.jpeg'] else f"image/{ext[1:]}"
            return f'<img src="data:{mime_type};base64,{img_base64}" style="max-width:100%;height:auto;">'
        except Exception:
            # Best effort: an unreadable image keeps its original markdown link.
            return match.group(0)

    # Scratch directory for the LaTeX renderer; removed on every exit path.
    temp_dir = tempfile.mkdtemp()
    try:
        md_content = re.sub(r'<html>(.*?)</html>', replace_html_latex_table, md_content, flags=re.DOTALL)
        return re.sub(r'!\[.*?\]\(([^)]+)\)', replace_image_with_base64, md_content)
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
| |
|
def parse_document(file) -> Tuple[str, Optional[str]]:
    """Gradio callback: parse the uploaded document.

    Args:
        file: The upload as delivered by ``gr.File``. With
            ``type="filepath"`` (as configured in this app) that is a plain
            path string; objects exposing the path as ``.name`` are accepted
            too. ``None`` means nothing was uploaded.

    Returns:
        ``(markdown, layout_pdf_path)``. The path is ``None`` when no layout
        PDF exists on disk. Problems are reported through the markdown
        element rather than raised.
    """
    if file is None:
        return "Please upload a document first.", None

    try:
        # A path string has no `.name`; reading it unconditionally made every
        # upload fail with AttributeError.
        if isinstance(file, (str, os.PathLike)):
            file_path = os.fspath(file)
        else:
            file_path = file.name

        markdown_content, layout_pdf_path = process_document(file_path)

        if not markdown_content:
            return "Failed to process document.", None

        if layout_pdf_path and os.path.exists(layout_pdf_path):
            return markdown_content, layout_pdf_path
        return markdown_content, None

    except Exception as e:
        logger.error(f"Error in parse_document: {e}")
        return f"Error: {str(e)}", None
| |
|
def create_gradio_interface():
    """Build the Gradio Blocks UI and wire the parse button to the OCR pipeline.

    Layout: an upload column (file picker, parse button, tips) beside a
    results column (rendered markdown plus a layout-PDF download that stays
    hidden until a PDF is available).

    Returns:
        The configured ``gr.Blocks`` app, not yet launched.
    """
    css = """
    .gradio-container {
        max-width: 1200px !important;
    }
    .markdown-content {
        max-height: 600px;
        overflow-y: auto;
        border: 1px solid #ddd;
        padding: 10px;
        border-radius: 5px;
    }
    """

    def parse_and_reveal(file):
        """Parse the upload and show the layout PDF only when one was produced."""
        markdown_content, layout_pdf_path = parse_document(file)
        # Visibility is decided in the same response. The previous version
        # toggled it from a `.change` listener that wrote back to the very
        # component it was listening on.
        return markdown_content, gr.update(
            value=layout_pdf_path,
            visible=layout_pdf_path is not None,
        )

    with gr.Blocks(
        title="MonkeyOCR 3B - Local MPS Demo",
        css=css,
        theme=gr.themes.Soft()
    ) as demo:

        gr.Markdown("""
        # 🐵 MonkeyOCR 3B - Local Demo (Apple Silicon MPS)

        **Optimized for MacBook M4 Pro with 48GB RAM**

        Upload a PDF or image document to extract structured content with state-of-the-art accuracy.
        The model runs locally using Apple's Metal Performance Shaders for GPU acceleration.

        **Supported formats:** PDF, PNG, JPG, JPEG
        """)

        with gr.Row():
            with gr.Column(scale=1):
                file_input = gr.File(
                    label="Upload Document",
                    file_types=[".pdf", ".png", ".jpg", ".jpeg"],
                    type="filepath"
                )

                parse_btn = gr.Button(
                    "Parse Document",
                    variant="primary",
                    size="lg"
                )

                gr.Markdown("""
                **Tips:**
                - Larger documents may take a few minutes to process
                - The model excels at formulas, tables, and complex layouts
                - Processing speed: ~0.84 pages/second on M4 Pro
                """)

            with gr.Column(scale=2):
                markdown_output = gr.Markdown(
                    label="Extracted Content",
                    elem_classes=["markdown-content"]
                )

                layout_pdf_output = gr.File(
                    label="Layout Analysis (PDF)",
                    visible=False
                )

        parse_btn.click(
            fn=parse_and_reveal,
            inputs=[file_input],
            outputs=[markdown_output, layout_pdf_output],
            show_progress=True
        )

    return demo
| |
|
def main(server_name: str = "0.0.0.0", server_port: int = 7861) -> None:
    """Log MPS availability, build the UI and launch the Gradio server.

    Args:
        server_name: Address to bind. The default ``"0.0.0.0"`` listens on
            every network interface, so the app is reachable from other
            machines on the network; pass ``"127.0.0.1"`` to keep it
            local-only.
        server_port: TCP port to serve on.
    """
    logger.info("Starting MonkeyOCR 3B Gradio App...")

    if torch.backends.mps.is_available():
        logger.info("MPS is available. GPU acceleration enabled.")
    else:
        logger.warning("MPS not available. The app will run on CPU which may be slower.")

    demo = create_gradio_interface()

    demo.launch(
        server_name=server_name,
        server_port=server_port,
        share=False,
        show_error=True,
        quiet=False,
    )


if __name__ == "__main__":
    main()