sharath88 committed on
Commit 67bcec0 · 1 Parent(s): 6375afa

Switch from OpenAI/Mistral to HuggingFace QA model


Replaces the previous OpenAI/Mistral-based chat endpoint with an extractive question-answering pipeline using HuggingFace Transformers and the 'deepset/roberta-base-squad2' model. Updates requirements to use 'transformers' and 'torch', removing 'openai'. The API now returns both the answer and a confidence score.
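For a concrete sense of the new contract, a client call against the updated /chat endpoint would look roughly like the sketch below (the URL, port, context, and question are illustrative placeholders, not part of the repository):

import requests  # client-side sketch only; `requests` is not a dependency of the Space

# Assumes the app is running locally, e.g. via `uvicorn main:app` (placeholder URL/port)
resp = requests.post(
    "http://localhost:8000/chat",
    json={
        "context": "The Eiffel Tower is located in Paris, France.",
        "question": "Where is the Eiffel Tower?",
    },
)

# The endpoint now returns the extracted span plus the model's confidence,
# e.g. {"answer": "Paris, France", "score": 0.93} (illustrative values)
print(resp.json())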

Files changed (2)
  1. main.py +31 -56
  2. requirements.txt +2 -1
main.py CHANGED
@@ -1,32 +1,11 @@
-import os
-
 from fastapi import FastAPI, Request
 from fastapi.responses import HTMLResponse, JSONResponse
 from fastapi.staticfiles import StaticFiles
 from fastapi.templating import Jinja2Templates
 from pydantic import BaseModel
-from openai import OpenAI
-
-
-# ---------- Hugging Face router / Mistral config ----------
-
-HF_TOKEN = os.getenv("HF_TOKEN")
-if HF_TOKEN is None:
-    raise RuntimeError(
-        "HF_TOKEN environment variable is not set. "
-        "Go to your Space → Settings → Variables and add HF_TOKEN=<your hf_... token>."
-    )
-
-# Use HF router with OpenAI-compatible client
-client = OpenAI(
-    base_url="https://router.huggingface.co/v1",
-    api_key=HF_TOKEN,
-)
-
-MODEL_ID = "mistralai/Mistral-Nemo-Instruct-2407"
-
+from transformers import pipeline
 
-# ---------- FastAPI app setup ----------
+# ---------------- FastAPI + Frontend Setup ----------------
 
 app = FastAPI()
 
@@ -44,19 +23,30 @@ class ChatRequest(BaseModel):
 
 @app.get("/", response_class=HTMLResponse)
 async def read_root(request: Request):
-    """Serve the main HTML page."""
+    """Serve main HTML page."""
     return templates.TemplateResponse("index.html", {"request": request})
 
 
+# ---------------- QA Model Setup ----------------
+
+# This is an extractive QA model: it finds the answer span inside the context.
+# It will download the model the first time the Space builds, then cache it.
+qa_pipeline = pipeline(
+    "question-answering",
+    model="deepset/roberta-base-squad2",
+    tokenizer="deepset/roberta-base-squad2",
+)
+
+
 @app.post("/chat")
 async def chat_endpoint(payload: ChatRequest):
     """
     Accepts:
-    - context: long paragraph / document text
-    - question: user question about that context
+    - context: paragraph / document text
+    - question: user's question about that context
 
     Returns:
-    - { "answer": "<model reply>" }
+    - { "answer": "<short answer>", "score": float }
     """
     context = payload.context.strip()
     question = payload.question.strip()
@@ -67,40 +57,25 @@ async def chat_endpoint(payload: ChatRequest):
             status_code=400,
         )
 
-    # Build chat-style messages for Mistral via HF router
-    messages = [
-        {
-            "role": "system",
-            "content": (
-                "You are a helpful assistant that answers questions ONLY using the "
-                "given context. If the answer is not in the context, say you don't "
-                "know and do NOT make up information."
-            ),
-        },
-        {
-            "role": "user",
-            "content": (
-                f"Context:\n{context}\n\n"
-                f"Question:\n{question}\n\n"
-                "Answer concisely based only on the context."
-            ),
-        },
-    ]
-
     try:
-        completion = client.chat.completions.create(
-            model=MODEL_ID,
-            messages=messages,
-            max_tokens=256,
-            temperature=0.2,
+        result = qa_pipeline(
+            {
+                "context": context,
+                "question": question,
+            }
         )
 
-        answer = completion.choices[0].message.content.strip()
-        return {"answer": answer}
+        answer = result.get("answer", "").strip()
+        score = float(result.get("score", 0.0))
+
+        # Fallback if model fails to find anything reasonable
+        if not answer:
+            answer = "I couldn't find the answer in the given context."
+
+        return {"answer": answer, "score": score}
 
     except Exception as e:
-        # Return the error message to the frontend so you can see what's wrong
         return JSONResponse(
-            {"answer": f"Error calling model: {e}"},
+            {"answer": f"Error running QA model: {e}"},
             status_code=500,
         )
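To make the extractive behaviour concrete, here is a minimal sketch of what the pipeline added above returns when called directly, outside FastAPI; the inputs are illustrative, and the start/end fields are character offsets of the answer span within the context:

from transformers import pipeline

# Same model as in main.py; downloads on first use, then served from the local cache
qa = pipeline("question-answering", model="deepset/roberta-base-squad2")

out = qa(
    question="Where is the Eiffel Tower?",                    # illustrative input
    context="The Eiffel Tower is located in Paris, France.",  # illustrative input
)

# `out` is a dict such as {'score': 0.97, 'start': 31, 'end': 44, 'answer': 'Paris, France'}
# (values illustrative). The /chat endpoint forwards only 'answer' and 'score' to the client.
print(out["answer"], out["score"])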
requirements.txt CHANGED
@@ -1,4 +1,5 @@
 fastapi
 uvicorn[standard]
 jinja2
-openai>=1.50.0
+transformers
+torch