from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer
import nltk
import os
import uvicorn
import time
from chunker import chunk_by_token_limit
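# chunk_by_token_limit(text, max_tokens, tokenizer) comes from the local chunker
# module (not shown here); it is assumed to split the text into pieces of at most
# max_tokens tokens, presumably on sentence boundaries via NLTK's punkt tokenizer.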
# Setup NLTK directory
NLTK_DATA_DIR = "/app/nltk_data"
os.makedirs(NLTK_DATA_DIR, exist_ok=True)
nltk.data.path.append(NLTK_DATA_DIR)
print("πŸ“¦ Downloading NLTK 'punkt' tokenizer...")
nltk.download("punkt", download_dir=NLTK_DATA_DIR, quiet=True)
app = FastAPI()
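# HF_TOKEN is optional here: a token is only required if the model repository is private or gated.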
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
MODEL_NAME = "VincentMuriuki/legal-summarizer"
print(f"πŸš€ Loading summarization pipeline: {MODEL_NAME}")
start_model_load = time.time()
summarizer = pipeline("summarization", model=MODEL_NAME, token=HF_AUTH_TOKEN)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_AUTH_TOKEN)
print(f"βœ… Model loaded in {time.time() - start_model_load:.2f}s")
class SummarizeInput(BaseModel):
    text: str

class ChunkInput(BaseModel):
    text: str
    max_tokens: int = 1024
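# POST /summarize: split the input into 1024-token chunks and summarize each chunk independently.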
@app.post("/summarize")
def summarize_text(data: SummarizeInput):
print("πŸ“₯ Received summarize request.")
start = time.time()
# Chunk the input text
chunks = chunk_by_token_limit(data.text, 1024, tokenizer)
total_chunks = len(chunks)
print(f"🧩 Starting summarization over {total_chunks} chunks...")
all_summaries = []
for idx, chunk in enumerate(chunks):
try:
summary = summarizer(chunk, max_length=150, min_length=30, do_sample=False)
summary_text = summary[0]["summary_text"]
all_summaries.append(summary_text)
print(f"βœ… Summary chunk generated {idx + 1}/{total_chunks}")
except Exception as e:
print(f"❌ Error summarizing chunk {idx + 1}/{total_chunks}: {str(e)}")
all_summaries.append("[Error summarizing this chunk]")
duration = time.time() - start
print(f"🏁 All {total_chunks} summaries generated in {duration:.2f}s.")
return {
"summaries": all_summaries,
"summary_combined": " ".join(all_summaries),
"chunk_count": total_chunks,
"time_taken": f"{duration:.2f}s"
}
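# POST /chunk: return the token-limited chunks without summarizing them.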
@app.post("/chunk")
def chunk_text(data: ChunkInput):
print(f"πŸ“₯ Received chunking request with max_tokens={data.max_tokens}")
start = time.time()
chunks = chunk_by_token_limit(data.text, data.max_tokens, tokenizer)
duration = time.time() - start
print(f"πŸ”– Chunking completed in {duration:.2f}s. Total chunks: {len(chunks)}")
return {"chunks": chunks, "chunk_count": len(chunks), "time_taken": f"{duration:.2f}s"}
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
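# Example call (a sketch; assumes the server is running locally on port 7860):
#   curl -X POST http://localhost:7860/summarize \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Long legal document text..."}'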