from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
import nltk
import os
import uvicorn
import time

from chunker import chunk_by_token_limit

# Set up the NLTK data directory
NLTK_DATA_DIR = "/app/nltk_data"
os.makedirs(NLTK_DATA_DIR, exist_ok=True)
nltk.data.path.append(NLTK_DATA_DIR)

print("📦 Downloading NLTK 'punkt' tokenizer...")
nltk.download("punkt", download_dir=NLTK_DATA_DIR, quiet=True)

app = FastAPI()

HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
MODEL_NAME = "VincentMuriuki/legal-summarizer"

print(f"🚀 Loading summarization pipeline: {MODEL_NAME}")
start_model_load = time.time()
summarizer = pipeline("summarization", model=MODEL_NAME, token=HF_AUTH_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_AUTH_TOKEN)
print(f"✅ Model loaded in {time.time() - start_model_load:.2f}s")


class SummarizeInput(BaseModel):
    text: str


class ChunkInput(BaseModel):
    text: str
    max_tokens: int = 1024


@app.post("/summarize")
def summarize_text(data: SummarizeInput):
    print("📥 Received summarize request.")
    start = time.time()

    # Chunk the input text
    chunks = chunk_by_token_limit(data.text, 1024, tokenizer)
    total_chunks = len(chunks)
    print(f"🧩 Starting summarization over {total_chunks} chunks...")

    all_summaries = []
    for idx, chunk in enumerate(chunks):
        try:
            summary = summarizer(chunk, max_length=150, min_length=30, do_sample=False)
            summary_text = summary[0]["summary_text"]
            all_summaries.append(summary_text)
            print(f"✅ Summary chunk generated {idx + 1}/{total_chunks}")
        except Exception as e:
            print(f"❌ Error summarizing chunk {idx + 1}/{total_chunks}: {str(e)}")
            all_summaries.append("[Error summarizing this chunk]")

    duration = time.time() - start
    print(f"🏁 All {total_chunks} summaries generated in {duration:.2f}s.")

    return {
        "summaries": all_summaries,
        "summary_combined": " ".join(all_summaries),
        "chunk_count": total_chunks,
        "time_taken": f"{duration:.2f}s",
    }


@app.post("/chunk")
def chunk_text(data: ChunkInput):
    print(f"📥 Received chunking request with max_tokens={data.max_tokens}")
    start = time.time()

    chunks = chunk_by_token_limit(data.text, data.max_tokens, tokenizer)
    duration = time.time() - start
    print(f"🔖 Chunking completed in {duration:.2f}s. Total chunks: {len(chunks)}")

    return {"chunks": chunks, "chunk_count": len(chunks), "time_taken": f"{duration:.2f}s"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
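
# NOTE: the actual chunking logic lives in chunker.py, which is not shown here.
# The sketch below is an assumption about what chunk_by_token_limit might do,
# inferred only from how it is called above (text, max_tokens, tokenizer) and
# from the 'punkt' download: split the text into sentences with NLTK and pack
# them greedily until the tokenizer's token count would exceed max_tokens.
# It is deliberately named _chunk_by_token_limit_sketch so it does not shadow
# the imported function.
def _chunk_by_token_limit_sketch(text: str, max_tokens: int, tokenizer) -> list:
    sentences = nltk.sent_tokenize(text)
    chunks, current = [], ""
    for sentence in sentences:
        candidate = f"{current} {sentence}".strip()
        # Count tokens without special tokens so the limit applies to content only.
        if len(tokenizer.encode(candidate, add_special_tokens=False)) <= max_tokens:
            current = candidate
        else:
            if current:
                chunks.append(current)
            # A single sentence longer than max_tokens becomes its own chunk.
            current = sentence
    if current:
        chunks.append(current)
    return chunks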