# Rah / app.py
# aach456's picture
# Update app.py
# 7c22e3c verified
import os
import io
import glob
import tempfile
import time
import numpy as np
import pandas as pd
import requests
import gradio as gr
from bs4 import BeautifulSoup
from PyPDF2 import PdfReader
from docx import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, pipeline
import faiss
# Point the Hugging Face / sentence-transformers caches at one writable directory.
# NOTE(review): these variables are assigned after the library imports at the top
# of the file, so any value those libraries read from the environment at import
# time will not reflect them. The explicit cache_dir / cache_folder arguments
# passed when loading models do not depend on that ordering. Consider moving
# this setup above the third-party imports.
HF_HOME = os.environ.get("HF_HOME", "/tmp/hf_cache")
os.makedirs(HF_HOME, exist_ok=True)
os.environ["HF_HOME"] = HF_HOME
os.environ["TRANSFORMERS_CACHE"] = HF_HOME
os.environ["SENTENCE_TRANSFORMERS_HOME"] = HF_HOME
os.environ["HF_DATASETS_CACHE"] = HF_HOME
os.environ["XDG_CACHE_HOME"] = HF_HOME
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

# Best-effort removal of leftover *.lock files from a previous run.
# The search is recursive so lock files kept in per-repository sub-folders
# (presumably how the hub lays them out -- TODO confirm) are also found.
locks_dir = os.path.join(HF_HOME, "hub", ".locks")
if os.path.isdir(locks_dir):
    for p in glob.glob(os.path.join(locks_dir, "**", "*.lock"), recursive=True):
        try:
            os.remove(p)
        except OSError:
            # The file vanished or is not removable; start-up continues regardless.
            pass
# Identifier of the causal language model used to generate answers.
MODEL_ID = "MehdiHosseiniMoghadam/AVA-Mistral-7B-V2"

# Sentence encoder used to embed both document chunks and user queries.
embedder = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=HF_HOME)

# Model configuration; loaded here but not referenced again in the visible code.
config = AutoConfig.from_pretrained(
    MODEL_ID,
    cache_dir=HF_HOME,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    cache_dir=HF_HOME,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    cache_dir=HF_HOME,
    trust_remote_code=True,
)

# Text-generation pipeline wrapping the already-loaded model and tokenizer.
llm = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=1024,
    do_sample=True,
    temperature=0.2,
    trust_remote_code=True,
    device_map="auto",
)
def _read_upload_bytes(file):
    """Return the raw bytes of an upload given as a path or a readable object."""
    if isinstance(file, (str, os.PathLike)):
        with open(file, "rb") as handle:
            return handle.read()
    return file.read()


def load_file_text(file):
    """Extract plain text from an uploaded PDF, DOCX, CSV or TXT file.

    Args:
        file: Either a filesystem path (``str`` / ``os.PathLike``) or a readable
            binary file object exposing a ``name`` attribute. The extension of
            the path/name (case-insensitive) selects the parser.

    Returns:
        The extracted text, or ``""`` when the extension is not supported or a
        CSV cannot be parsed with either attempted encoding.
    """
    is_path = isinstance(file, (str, os.PathLike))
    name = str(file if is_path else getattr(file, "name", "")).lower()

    if name.endswith(".pdf"):
        reader = PdfReader(file)
        # Separate pages with a newline so the last word of one page is not
        # glued to the first word of the next.
        return "\n".join(page.extract_text() or "" for page in reader.pages)

    if name.endswith(".docx"):
        doc = Document(io.BytesIO(_read_upload_bytes(file)))
        return " ".join(p.text for p in doc.paragraphs)

    if name.endswith(".csv"):
        data = _read_upload_bytes(file)
        for enc in ("utf-8", "latin-1"):
            try:
                df = pd.read_csv(io.BytesIO(data), encoding=enc)
            except ValueError:
                # Covers decode errors as well as pandas' parser / empty-data
                # errors (all ValueError subclasses); try the next encoding.
                continue
            return " ".join(df.astype(str).values.flatten().tolist())
        return ""

    if name.endswith(".txt"):
        raw = _read_upload_bytes(file)
        try:
            # Strict decoding, so invalid UTF-8 actually reaches the fallback
            # instead of having its bytes silently dropped.
            return raw.decode("utf-8")
        except UnicodeDecodeError:
            # latin-1 maps every byte to a character, so this cannot fail.
            return raw.decode("latin-1")

    return ""
def fetch_web_text(url):
    """Download a web page and return its visible text with whitespace collapsed.

    Args:
        url: Address of the page to fetch. ``None``, empty and whitespace-only
            values are treated as "nothing to fetch".

    Returns:
        The page text with ``<script>``, ``<style>`` and ``<noscript>`` content
        removed and runs of whitespace collapsed to single spaces, or ``""``
        when the URL is blank or the request/parsing fails.
    """
    if not url or not url.strip():
        return ""
    try:
        headers = {"User-Agent": "Mozilla/5.0"}
        resp = requests.get(url.strip(), headers=headers, timeout=10)
        resp.raise_for_status()
        # Hand the raw bytes to BeautifulSoup so it can detect the encoding
        # from the document itself; only pass the HTTP-level encoding along
        # when the server explicitly declared a charset.
        content_type = resp.headers.get("content-type", "").lower()
        declared_encoding = resp.encoding if "charset" in content_type else None
        soup = BeautifulSoup(resp.content, "html.parser", from_encoding=declared_encoding)
        for tag in soup(["script", "style", "noscript"]):
            tag.decompose()
        return " ".join(soup.get_text(separator=" ").split())
    except Exception:
        # Deliberate best-effort boundary: any failure means "no text" for
        # this URL, which the caller skips.
        return ""
def chunk_docs(docs, chunk_size=1000, chunk_overlap=120):
    """Split every document into overlapping chunks tagged with their origin.

    Args:
        docs: Iterable of dicts with ``"source"`` and ``"text"`` keys.
        chunk_size: Forwarded to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: Forwarded to ``RecursiveCharacterTextSplitter``.

    Returns:
        A list of dicts with ``"source"``, ``"chunk_id"``
        (``"<source>_chunk<n>"``, numbered per document from 0) and
        ``"content"`` keys, in document order.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return [
        {
            "source": doc["source"],
            "chunk_id": f"{doc['source']}_chunk{position}",
            "content": piece,
        }
        for doc in docs
        for position, piece in enumerate(text_splitter.split_text(doc["text"]))
    ]
def build_index_and_chunks(docs):
    """Chunk the documents, embed the chunks and index them with FAISS.

    Args:
        docs: Iterable of dicts with ``"source"`` and ``"text"`` keys, as
            accepted by ``chunk_docs``.

    Returns:
        ``(index, chunks)`` where ``index`` is a ``faiss.IndexFlatL2`` holding
        one float32 vector per chunk, or ``(None, [])`` when chunking yields
        nothing.
    """
    chunk_records = chunk_docs(docs)
    passages = [record["content"] for record in chunk_records]
    if not passages:
        return None, []

    vectors = embedder.encode(passages, show_progress_bar=True, convert_to_numpy=True)
    vectors = np.asarray(vectors).astype("float32")

    # Flat L2 index whose dimensionality matches the embedding width.
    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    flat_index.add(vectors)
    return flat_index, chunk_records
def retrieve(query, index, chunks, top_k=3):
    """Return the chunks whose embeddings are closest to the query.

    Args:
        query: Question text to embed.
        index: Search index built over the chunk embeddings, or ``None``.
        chunks: Chunk dicts; position ``i`` is assumed to correspond to vector
            ``i`` of ``index`` (that is how ``build_index_and_chunks`` pairs them).
        top_k: Maximum number of chunks to return.

    Returns:
        A list of ``{"chunk": <chunk dict>, "score": <float>}`` in the order
        returned by the index search. ``score`` is the distance reported by the
        index. Empty when there is no index, no chunks, or ``top_k`` is not
        positive.
    """
    if index is None or len(chunks) == 0:
        return []

    # Never ask for more neighbours than there are chunks, and never pass a
    # non-positive k to the index search.
    k = min(int(top_k), len(chunks))
    if k <= 0:
        return []

    q_emb = embedder.encode([query], convert_to_numpy=True)
    q_emb = np.asarray(q_emb).astype("float32")
    distances, indices = index.search(q_emb, k)

    # Drop any index outside the chunk list (e.g. negative padding values).
    return [
        {"chunk": chunks[idx], "score": float(dist)}
        for dist, idx in zip(distances[0], indices[0])
        if 0 <= idx < len(chunks)
    ]
def answer_question(query, index, chunks, top_k=3, max_new_tokens=512):
    """Answer a question from retrieved chunks and list the chunks used.

    Args:
        query: The user's question.
        index: Search index over the chunk embeddings (may be ``None``).
        chunks: Chunk dicts with ``"chunk_id"``, ``"source"`` and ``"content"``.
        top_k: Number of chunks to retrieve as context.
        max_new_tokens: Upper bound on generated tokens, excluding the prompt.

    Returns:
        ``(answer, sources)``: the generated answer text and a newline-separated
        list of ``[<chunk_id> from <source>]`` entries for the retrieved chunks.
    """
    results = retrieve(query, index, chunks, top_k=top_k)
    context_chunks = [r["chunk"] for r in results]
    context_text = "\n".join(f"[{c['chunk_id']}] {c['content']}" for c in context_chunks)
    prompt = (
        "Answer the following question using ONLY the provided context and cite the chunk ids used.\n"
        f"Question: {query}\nContext:\n{context_text}\nAnswer with citations:"
    )
    # max_new_tokens bounds only the continuation, so a long context cannot use
    # up the whole generation budget the way a total max_length can.
    # return_full_text=False keeps the prompt out of the returned answer.
    generated = llm(
        prompt,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        return_full_text=False,
    )
    answer = generated[0]["generated_text"].strip()
    sources = "\n".join(f"[{c['chunk_id']} from {c['source']}]" for c in context_chunks)
    return answer, sources
# Shared module-level store: process() writes the FAISS index and chunk list here,
# and chat_response() reads them when answering.
state = {"index": None, "chunks": []}
def process(files, urls):
    """Load the uploaded files and listed URLs, then rebuild the search index.

    Args:
        files: Iterable of uploads (file objects with a ``name`` attribute or
            path strings), or ``None``.
        urls: Text containing one URL per line, or ``None``. Blank lines are
            skipped.

    Returns:
        A status message. On success the module-level ``state`` is updated with
        the new index and chunks; when nothing could be loaded ``state`` is
        left untouched.
    """
    docs = []

    for f in files or []:
        text = load_file_text(f)
        if text:
            # Uploads may be file objects (use .name) or plain path strings.
            docs.append({"source": str(getattr(f, "name", f)), "text": text})

    for line in (urls or "").splitlines():
        url = line.strip()
        if not url:
            # Blank lines between URLs are not worth a network round-trip.
            continue
        text = fetch_web_text(url)
        if text:
            docs.append({"source": url, "text": text})

    if not docs:
        return "No documents or URLs loaded."

    index, chunks = build_index_and_chunks(docs)
    state["index"], state["chunks"] = index, chunks
    return f"Loaded {len(docs)} docs, created {len(chunks)} chunks."
def chat_response(user_message, history):
    """Handle one chat turn and return the cleared input plus updated history.

    Args:
        user_message: Text typed by the user.
        history: List of ``(user_text, bot_text)`` tuples shown so far, or
            ``None``.

    Returns:
        ``("", history)``: an empty string to clear the input box and the
        history with the new exchange appended. Blank messages add nothing.
    """
    history = history or []

    # Nothing to answer: do not run retrieval or the language model.
    if not user_message or not user_message.strip():
        return "", history

    if state["index"] is None or len(state["chunks"]) == 0:
        bot_message = "Please upload documents or enter URLs, then press 'Load & Process' first."
    else:
        answer, sources = answer_question(user_message, state["index"], state["chunks"])
        bot_message = answer + "\n\nSources:\n" + sources

    history.append(("User: " + user_message, "Assistant: " + bot_message))
    return "", history
# Gradio UI: a loader column (files, URLs, status) beside a chat column.
with gr.Blocks() as demo:
    gr.Markdown("# 📚 RAG Chatbot with Mistral-7B and FAISS")
    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload Files (PDF, DOCX, TXT, CSV)",
                file_types=[".pdf", ".docx", ".txt", ".csv"],
                file_count="multiple",
            )
            url_input = gr.Textbox(label="Enter URLs (one per line)", lines=4)
            process_button = gr.Button("Load & Process Documents and URLs")
            output_log = gr.Textbox(label="Status")
        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            user_input = gr.Textbox(
                placeholder="Ask a question about the loaded documents...",
                show_label=False,
            )
            submit_btn = gr.Button("Send")

    # Event wiring is done while the Blocks context is still open.
    process_button.click(process, inputs=[file_input, url_input], outputs=output_log)
    submit_btn.click(chat_response, inputs=[user_input, chatbot], outputs=[user_input, chatbot])

demo.launch()