import os
import io
import glob
import tempfile
import time

import numpy as np
import pandas as pd
import requests
import gradio as gr
from bs4 import BeautifulSoup
from PyPDF2 import PdfReader
from docx import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, pipeline
import faiss
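

# The imports above correspond roughly to these pip packages (names are
# approximate and unpinned): gradio, requests, beautifulsoup4, PyPDF2,
# python-docx, langchain (or langchain-text-splitters, depending on version),
# sentence-transformers, transformers, faiss-cpu (or faiss-gpu), pandas, numpy.

# Route every Hugging Face cache (hub, transformers, datasets,
# sentence-transformers) to a single writable directory; this is a common
# workaround on hosted runtimes (e.g. Hugging Face Spaces) where the default
# cache location may not be writable.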
HF_HOME = os.environ.get("HF_HOME", "/tmp/hf_cache")
os.makedirs(HF_HOME, exist_ok=True)

os.environ["HF_HOME"] = HF_HOME
os.environ["TRANSFORMERS_CACHE"] = HF_HOME
os.environ["SENTENCE_TRANSFORMERS_HOME"] = HF_HOME
os.environ["HF_DATASETS_CACHE"] = HF_HOME
os.environ["XDG_CACHE_HOME"] = HF_HOME
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

# Remove stale download locks left behind by interrupted runs.
locks_dir = os.path.join(HF_HOME, "hub", ".locks")
if os.path.isdir(locks_dir):
    for p in glob.glob(os.path.join(locks_dir, "*.lock")):
        try:
            os.remove(p)
        except OSError:
            pass
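

# Retrieval uses a compact sentence-transformers embedder; generation uses a 7B
# Mistral fine-tune, which typically needs a GPU or tens of GB of RAM to load
# in full precision. Both models are downloaded into HF_HOME on first run.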
MODEL_ID = "MehdiHosseiniMoghadam/AVA-Mistral-7B-V2"

embedder = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=HF_HOME)
config = AutoConfig.from_pretrained(MODEL_ID, cache_dir=HF_HOME, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=HF_HOME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, cache_dir=HF_HOME, trust_remote_code=True)

llm = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=1024,
    do_sample=True,
    temperature=0.2,
    trust_remote_code=True,
    device_map="auto",
)
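

# Text extraction helpers. These assume the Gradio version in use passes
# file-like objects (with .name and .read()); newer Gradio releases that pass
# plain file paths would need an explicit open() instead.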
def load_file_text(file):
    """Return the plain text of an uploaded PDF, DOCX, CSV, or TXT file ("" if unsupported)."""
    name = file.name.lower()
    if name.endswith(".pdf"):
        reader = PdfReader(file)
        text = "".join(page.extract_text() or "" for page in reader.pages)
        return text
    elif name.endswith(".docx"):
        data = file.read()
        doc = Document(io.BytesIO(data))
        return " ".join(p.text for p in doc.paragraphs)
    elif name.endswith(".csv"):
        data = file.read()
        for enc in ("utf-8", "latin-1"):
            try:
                df = pd.read_csv(io.BytesIO(data), encoding=enc)
                return " ".join(df.astype(str).values.flatten().tolist())
            except Exception:
                continue
        return ""
    elif name.endswith(".txt"):
        raw = file.read()
        for enc in ("utf-8", "latin-1"):
            try:
                return raw.decode(enc, errors="ignore")
            except Exception:
                continue
        return raw.decode("utf-8", errors="ignore")
    else:
        return ""


def fetch_web_text(url):
    """Download a web page and return its visible text, or "" on any failure."""
    try:
        headers = {"User-Agent": "Mozilla/5.0"}
        resp = requests.get(url, headers=headers, timeout=10)
        resp.raise_for_status()
        soup = BeautifulSoup(resp.text, "html.parser")
        for tag in soup(["script", "style", "noscript"]):
            tag.decompose()
        return " ".join(soup.get_text(separator=" ").split())
    except Exception:
        return ""
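

# Chunking, embedding, and retrieval: documents are split into overlapping
# chunks, embedded with the MiniLM model, and indexed in an in-memory FAISS
# IndexFlatL2 (exact L2 search).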
def chunk_docs(docs, chunk_size=1000, chunk_overlap=120):
    """Split each document's text into overlapping chunks, keeping source metadata."""
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunks = []
    for doc in docs:
        splits = splitter.split_text(doc["text"])
        for idx, chunk in enumerate(splits):
            chunks.append({
                "source": doc["source"],
                "chunk_id": f"{doc['source']}_chunk{idx}",
                "content": chunk,
            })
    return chunks


def build_index_and_chunks(docs):
    """Embed all chunks and build a FAISS L2 index; returns (index, chunks)."""
    chunks = chunk_docs(docs)
    texts = [chunk["content"] for chunk in chunks]
    if len(texts) == 0:
        return None, []
    embeddings = embedder.encode(texts, show_progress_bar=True, convert_to_numpy=True)
    embeddings = np.asarray(embeddings).astype("float32")
    dim = embeddings.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(embeddings)
    return index, chunks


def retrieve(query, index, chunks, top_k=3):
    """Return the top_k chunks nearest to the query embedding, with L2 distances as scores."""
    if index is None or len(chunks) == 0:
        return []
    q_emb = embedder.encode([query], convert_to_numpy=True)
    q_emb = np.asarray(q_emb).astype("float32")
    distances, indices = index.search(q_emb, top_k)
    results = []
    for dist, idx in zip(distances[0], indices[0]):
        if 0 <= idx < len(chunks):
            results.append({"chunk": chunks[idx], "score": float(dist)})
    return results
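

# answer_question() builds a prompt that asks the model to answer strictly from
# the retrieved chunks and to cite their chunk ids, then runs the generation
# pipeline once.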
def answer_question(query, index, chunks):
    """Retrieve context for the query and generate an answer that cites chunk ids."""
    results = retrieve(query, index, chunks)
    context_chunks = [r["chunk"] for r in results]
    context_text = "\n".join(f"[{c['chunk_id']}] {c['content']}" for c in context_chunks)
    prompt = (
        "Answer the following question using ONLY the provided context and cite the chunk ids used.\n"
        f"Question: {query}\nContext:\n{context_text}\nAnswer with citations:"
    )
    # Note: max_length counts prompt plus generated tokens, and by default the
    # text-generation pipeline returns the prompt as part of generated_text.
    generated = llm(prompt, max_length=512, num_return_sequences=1)
    sources = "\n".join(f"[{c['chunk_id']} from {c['source']}]" for c in context_chunks)
    return generated[0]["generated_text"], sources


# Shared state between the "Load & Process" step and the chat handler.
state = {"index": None, "chunks": []}
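

# Gradio callbacks: process() ingests files/URLs and rebuilds the index;
# chat_response() answers questions against the current index.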
def process(files, urls):
    """Load uploaded files and URLs, then rebuild the FAISS index and chunk store."""
    docs = []
    if files:
        for f in files:
            text = load_file_text(f)
            if text:
                docs.append({"source": f.name, "text": text})
    if urls:
        for url in urls.strip().splitlines():
            text = fetch_web_text(url.strip())
            if text:
                docs.append({"source": url.strip(), "text": text})
    if len(docs) == 0:
        return "No documents or URLs loaded."
    index, chunks = build_index_and_chunks(docs)
    state["index"], state["chunks"] = index, chunks
    return f"Loaded {len(docs)} docs, created {len(chunks)} chunks."


def chat_response(user_message, history):
    """Answer the user's question from the indexed documents and append the turn to the history."""
    if state["index"] is None or len(state["chunks"]) == 0:
        bot_message = "Please upload documents or enter URLs, then press 'Load & Process' first."
    else:
        answer, sources = answer_question(user_message, state["index"], state["chunks"])
        bot_message = answer + "\n\nSources:\n" + sources
    history = history or []
    history.append(("User: " + user_message, "Assistant: " + bot_message))
    return "", history
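

# UI layout: ingestion controls (file upload, URL box, "Load & Process" button,
# status log) on the left; the chatbot and question box on the right.
# demo.launch() serves the app locally, by default at http://127.0.0.1:7860.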
with gr.Blocks() as demo:
    gr.Markdown("# π RAG Chatbot with Mistral-7B and FAISS")

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload Files (PDF, DOCX, TXT, CSV)",
                file_types=[".pdf", ".docx", ".txt", ".csv"],
                file_count="multiple",
            )
            url_input = gr.Textbox(label="Enter URLs (one per line)", lines=4)
            process_button = gr.Button("Load & Process Documents and URLs")
            output_log = gr.Textbox(label="Status")

        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            user_input = gr.Textbox(placeholder="Ask a question about the loaded documents...", show_label=False)
            submit_btn = gr.Button("Send")

    process_button.click(process, inputs=[file_input, url_input], outputs=output_log)
    submit_btn.click(chat_response, inputs=[user_input, chatbot], outputs=[user_input, chatbot])

demo.launch()