"""FastAPI app serving a small extractive question-answering chat UI.

Routes:
  GET  /       -> renders templates/index.html
  POST /chat   -> runs an extractive QA model over a user-supplied context
  /static/*    -> static assets (JS, CSS)
"""

from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from transformers import pipeline

# ---------------- FastAPI + Frontend Setup ----------------

app = FastAPI()

# Serve static files (JS, CSS)
app.mount("/static", StaticFiles(directory="static"), name="static")

# HTML templates
templates = Jinja2Templates(directory="templates")


class ChatRequest(BaseModel):
    """Request payload for the /chat endpoint."""

    # context: paragraph / document text the answer is extracted from
    context: str
    # question: the user's question about that context
    question: str


@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Serve main HTML page."""
    return templates.TemplateResponse("index.html", {"request": request})


# ---------------- QA Model Setup ----------------
# This is an extractive QA model: it finds the answer span inside the context.
# It will download the model the first time the Space builds, then cache it.
qa_pipeline = pipeline(
    "question-answering",
    model="deepset/roberta-base-squad2",
    tokenizer="deepset/roberta-base-squad2",
)


@app.post("/chat")
async def chat_endpoint(payload: ChatRequest):
    """Run extractive QA over the supplied context.

    Accepts:
      - context: paragraph / document text
      - question: user's question about that context

    Returns:
      - 200: {"answer": str, "score": float}
      - 400: {"answer": <message>} when context or question is blank
      - 500: {"answer": <message>} when the model raises
    """
    context = payload.context.strip()
    question = payload.question.strip()

    # Reject blank (or whitespace-only) inputs up front.
    if not context or not question:
        return JSONResponse(
            {"answer": "Please provide both context and a question."},
            status_code=400,
        )

    try:
        result = qa_pipeline(
            {
                "context": context,
                "question": question,
            }
        )
        answer = result.get("answer", "").strip()
        score = float(result.get("score", 0.0))

        # Fallback if model fails to find anything reasonable
        if not answer:
            answer = "I couldn't find the answer in the given context."

        return {"answer": answer, "score": score}
    except Exception as e:
        # Surface model/runtime failures as a 500 instead of crashing the worker.
        return JSONResponse(
            {"answer": f"Error running QA model: {e}"},
            status_code=500,
        )