| | import base64 |
| | import copy |
| | from datetime import datetime |
| | import json |
| | import fire |
| | import os |
| | import pathlib |
| |
|
| | from poster.figures import extract_figures |
| | from poster.poster import ( |
| | generate_html_v2, |
| | generate_poster_v3, |
| | replace_figures_in_poster, |
| | replace_figures_size_in_poster, |
| | ) |
| |
|
| |
|
def generate_paper_poster(
    url: str,
    pdf: str,
    vendor: str = "openai",
    model: str = "gpt-4o-mini",
    text_prompt: str = "",
    figures_prompt: str = "",
    output: str = "poster.json",
):
    """Generate a poster and its HTML rendering for a paper.

    Figures and tables are extracted from the PDF (or loaded from a JSON
    cache stored next to it), a poster is generated from them, and the
    poster is then rendered to HTML. Poster/HTML generation is retried a
    few times on errors that are not known to be permanent.

    Args:
        url: URL of the PDF file; forwarded to ``extract_figures``.
        pdf: Local path of the PDF file. The cache file names are derived
            from this path.
        vendor: Model vendor forwarded to the generation helpers, default
            is openai.
        model: Name of the model to use, default is gpt-4o-mini.
        text_prompt: Text prompt template.
        figures_prompt: Figures prompt template.
        output: Output file path, default is poster.json. This function
            does not currently read or write it; the parameter is kept so
            existing callers and the command line stay compatible.

    Returns:
        A ``(poster, html)`` tuple on success, or ``(None, None)`` when
        every attempt failed with a retryable error.

    Raises:
        Exception: Any error whose message contains one of the
            non-retryable markers below is re-raised unchanged. Errors
            from figure extraction and cache file I/O are not caught.
    """
    # Strip only a trailing ".pdf". Replacing every occurrence would also
    # rewrite directory names that happen to contain ".pdf" and point the
    # cache files at a path that does not exist.
    pdf_suffix = ".pdf"
    pdf_stem = pdf[: -len(pdf_suffix)] if pdf.endswith(pdf_suffix) else pdf
    figures_cache = f"{pdf_stem}_figures.json"
    figures_cap_cache = f"{pdf_stem}_figures_cap.json"

    print("开始提取图片...")
    # NOTE(review): the caption cache is only checked for existence here; it
    # is neither read nor written by this function, so the cached branch is
    # taken only if something else has created that file.
    if os.path.exists(figures_cache) and os.path.exists(figures_cap_cache):
        print(f"使用缓存的图片: {figures_cache}")
        with open(figures_cache, "r", encoding="utf-8") as f:
            figures = json.load(f)
    else:
        figures_img = extract_figures(url, pdf, task="figure")
        figures_table = extract_figures(url, pdf, task="table")

        # Keep only detections scoring at or above the threshold.
        # Assumes both results are lists of (image, score) pairs.
        threshold = 0.75
        figures = [
            image
            for image, score in figures_img + figures_table
            if score >= threshold
        ]

        # ensure_ascii=False may emit non-ASCII characters, so the encoding
        # is pinned instead of depending on the platform default.
        with open(figures_cache, "w", encoding="utf-8") as f:
            json.dump(figures, f, ensure_ascii=False)

    # Error-message fragments that mark a failure as permanent; retrying
    # with the same input would fail the same way, so these are re-raised.
    non_retryable_markers = (
        "content management policy",
        "message larger than max",
        "exceeds the maximum length",
        "maximum context length",
        "Input is too long",
        "image exceeds 5 MB",
        "too many total text bytes",
        "Range of input length",
        "Invalid text",
    )

    print("开始生成海报...")
    max_attempts = 3
    # One initial try plus `max_attempts` retries, matching the original
    # counter logic (give up once the failure count exceeds max_attempts).
    for _ in range(max_attempts + 1):
        try:
            result = generate_poster_v3(
                vendor, model, text_prompt, figures_prompt, pdf, figures, figures
            )
            poster = result["image_based_poster"]

            # Copy taken before the figure replacement below, presumably
            # because that helper may modify its argument - TODO confirm.
            backup_poster = copy.deepcopy(poster)
            poster = replace_figures_in_poster(poster, figures)

            poster_size = replace_figures_size_in_poster(backup_poster, figures)
            print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Now generating HTML...")
            result = generate_html_v2(vendor, model, poster_size, figures)
            html = result["html_with_figures"]

            print("海报生成成功！")
            return poster, html
        except Exception as e:
            message = str(e)
            if any(marker in message for marker in non_retryable_markers):
                raise
            print(f"处理文件 {pdf} 时出错: {e}")

    return None, None
| |
|
| |
|
| |
|
# Script entry point: when this file is executed directly, hand
# generate_paper_poster to fire.Fire so it can be driven from the command line.
if __name__ == "__main__":
    fire.Fire(generate_paper_poster)
| |
|