from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
import nltk
import os
import uvicorn
import time

from chunker import chunk_by_token_limit

# Set up the NLTK data directory and download the sentence tokenizer
NLTK_DATA_DIR = "/app/nltk_data"
os.makedirs(NLTK_DATA_DIR, exist_ok=True)
nltk.data.path.append(NLTK_DATA_DIR)

print("📦 Downloading NLTK 'punkt' tokenizer...")
nltk.download("punkt", download_dir=NLTK_DATA_DIR, quiet=True)

app = FastAPI()

HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
MODEL_NAME = "VincentMuriuki/legal-summarizer"

print(f"🚀 Loading summarization pipeline: {MODEL_NAME}")
start_model_load = time.time()
summarizer = pipeline("summarization", model=MODEL_NAME, token=HF_AUTH_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_AUTH_TOKEN)
print(f"✅ Model loaded in {time.time() - start_model_load:.2f}s")


class SummarizeInput(BaseModel):
    text: str


class ChunkInput(BaseModel):
    text: str
    max_tokens: int = 1024


@app.post("/summarize")  # route path assumed; the decorator is needed to expose the handler
def summarize_text(data: SummarizeInput):
    print("📥 Received summarize request.")
    start = time.time()

    # Chunk the input text so each piece fits the model's context window
    chunks = chunk_by_token_limit(data.text, 1024, tokenizer)
    total_chunks = len(chunks)
    print(f"🧩 Starting summarization over {total_chunks} chunks...")

    all_summaries = []
    for idx, chunk in enumerate(chunks):
        try:
            summary = summarizer(chunk, max_length=150, min_length=30, do_sample=False)
            summary_text = summary[0]["summary_text"]
            all_summaries.append(summary_text)
            print(f"✅ Summary generated for chunk {idx + 1}/{total_chunks}")
        except Exception as e:
            print(f"❌ Error summarizing chunk {idx + 1}/{total_chunks}: {str(e)}")
            all_summaries.append("[Error summarizing this chunk]")

    duration = time.time() - start
    print(f"🏁 All {total_chunks} summaries generated in {duration:.2f}s.")

    return {
        "summaries": all_summaries,
        "summary_combined": " ".join(all_summaries),
        "chunk_count": total_chunks,
        "time_taken": f"{duration:.2f}s",
    }


@app.post("/chunk")  # route path assumed; the decorator is needed to expose the handler
def chunk_text(data: ChunkInput):
    print(f"📥 Received chunking request with max_tokens={data.max_tokens}")
    start = time.time()

    chunks = chunk_by_token_limit(data.text, data.max_tokens, tokenizer)
    duration = time.time() - start

    print(f"🏁 Chunking completed in {duration:.2f}s. Total chunks: {len(chunks)}")
    return {"chunks": chunks, "chunk_count": len(chunks), "time_taken": f"{duration:.2f}s"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
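

# chunker.py is imported above but not included in this file. The sketch below is a
# guess at what chunk_by_token_limit might look like, assuming it splits the text into
# NLTK sentences (which is why 'punkt' is downloaded) and packs them greedily up to the
# token limit measured with the model's tokenizer. The name and signature match the call
# sites above; the body is illustrative, not the Space's actual implementation.
from nltk.tokenize import sent_tokenize


def chunk_by_token_limit(text: str, max_tokens: int, tokenizer) -> list[str]:
    """Greedily pack sentences into chunks of at most max_tokens tokens each."""
    chunks, current, current_len = [], [], 0
    for sentence in sent_tokenize(text):
        sentence_len = len(tokenizer.encode(sentence, add_special_tokens=False))
        # Close the current chunk when adding this sentence would exceed the limit.
        if current and current_len + sentence_len > max_tokens:
            chunks.append(" ".join(current))
            current, current_len = [], 0
        current.append(sentence)
        current_len += sentence_len
    if current:
        chunks.append(" ".join(current))
    return chunks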
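

# Example request against the running service, assuming it listens on port 7860 (as in
# uvicorn.run above) and that the handlers are exposed at the /summarize and /chunk
# paths used in the decorators. The response keys match the dicts returned above.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/summarize",
#       json={"text": "This Agreement is entered into by and between ..."},
#   )
#   print(resp.json()["summary_combined"])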