File size: 6,972 Bytes
7c22e3c
fc96ea6
7c22e3c
 
fc96ea6
 
7c22e3c
fc96ea6
7c22e3c
fc96ea6
 
 
 
 
 
 
 
7c22e3c
 
 
fc96ea6
7c22e3c
 
 
 
 
 
fc96ea6
7c22e3c
 
 
 
 
fc96ea6
7c22e3c
fc96ea6
7c22e3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc96ea6
 
 
7c22e3c
fc96ea6
 
 
 
7c22e3c
 
 
 
 
 
 
 
 
 
 
 
 
fc96ea6
 
7c22e3c
fc96ea6
7c22e3c
 
 
 
 
 
 
 
 
fc96ea6
 
 
7c22e3c
fc96ea6
7c22e3c
 
fc96ea6
7c22e3c
fc96ea6
7c22e3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import os
import io
import glob
import tempfile
import time
import numpy as np
import pandas as pd
import requests
import gradio as gr
from bs4 import BeautifulSoup
from PyPDF2 import PdfReader
from docx import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, pipeline
import faiss

# Configure Hugging Face cache locations.
# NOTE: transformers / sentence_transformers are imported above this point and
# may already have read some of these variables at import time, so the
# explicit cache_dir= / cache_folder= arguments used when loading models are
# what reliably pin downloads to HF_HOME.
HF_HOME = os.environ.get("HF_HOME", "/tmp/hf_cache")
os.makedirs(HF_HOME, exist_ok=True)

os.environ["HF_HOME"] = HF_HOME
os.environ["TRANSFORMERS_CACHE"] = HF_HOME
os.environ["SENTENCE_TRANSFORMERS_HOME"] = HF_HOME
os.environ["HF_DATASETS_CACHE"] = HF_HOME
os.environ["XDG_CACHE_HOME"] = HF_HOME
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

# Delete leftover download lock files (presumably from an interrupted earlier
# run) so a restart does not wait on them.
# NOTE(review): the hub appears to keep locks in per-repo subdirectories, and
# because models are loaded with cache_dir=HF_HOME the lock root may be
# HF_HOME/.locks rather than HF_HOME/hub/.locks — so both roots are searched
# recursively.
locks_dir = os.path.join(HF_HOME, "hub", ".locks")
for _lock_root in (locks_dir, os.path.join(HF_HOME, ".locks")):
    if not os.path.isdir(_lock_root):
        continue
    for p in glob.glob(os.path.join(_lock_root, "**", "*.lock"), recursive=True):
        try:
            os.remove(p)
        except OSError:
            # Best effort: the file vanished or cannot be removed; skip it.
            pass

# Hugging Face Hub id of the causal LM used to generate answers.
MODEL_ID = "MehdiHosseiniMoghadam/AVA-Mistral-7B-V2"

# Sentence-embedding model; encodes both document chunks
# (build_index_and_chunks) and user queries (retrieve).
embedder = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=HF_HOME)
# NOTE(review): `config` is not referenced anywhere in the visible code;
# confirm nothing else uses it before removing this download.
config = AutoConfig.from_pretrained(MODEL_ID, cache_dir=HF_HOME, trust_remote_code=True)
# trust_remote_code=True allows the model repository to execute its own Python
# code on load — only safe if the MODEL_ID repository is trusted.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=HF_HOME, trust_remote_code=True)
# NOTE(review): no torch_dtype / device_map is given here, so the 7B model
# presumably loads with the default dtype on CPU — confirm it fits in memory.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, cache_dir=HF_HOME, trust_remote_code=True)
# Shared text-generation pipeline with low-temperature sampling. max_length is
# a default cap on prompt + generated tokens; generation kwargs passed on each
# call (see answer_question) are expected to override it.
# NOTE(review): device_map="auto" is passed to the pipeline while `model` is
# already instantiated above, so it likely does not move the model — confirm,
# and consider passing it to from_pretrained instead.
llm = pipeline("text-generation", model=model, tokenizer=tokenizer,
               max_length=1024, do_sample=True, temperature=0.2,
               trust_remote_code=True, device_map="auto")

def load_file_text(file):
    """Extract plain text from an uploaded PDF, DOCX, CSV or TXT file.

    Args:
        file: an upload from ``gr.File`` — an object whose ``.name`` is the
            path on disk (e.g. a temp-file wrapper), a plain path string, or a
            file-like object exposing ``.name`` and ``.read()``.

    Returns:
        The extracted text; "" for unsupported extensions and for CSVs that
        cannot be parsed with either UTF-8 or latin-1.
    """
    path = str(getattr(file, "name", file))
    name = path.lower()
    on_disk = os.path.isfile(path)

    def read_bytes():
        # Prefer the path: upload wrappers may already be closed or not
        # positioned at offset 0, whereas the file on disk is always readable.
        if on_disk:
            with open(path, "rb") as fh:
                return fh.read()
        return file.read()

    if name.endswith(".pdf"):
        reader = PdfReader(path if on_disk else file)
        # Newline between pages so words at page boundaries don't fuse.
        return "\n".join(page.extract_text() or "" for page in reader.pages)
    if name.endswith(".docx"):
        doc = Document(io.BytesIO(read_bytes()))
        return " ".join(p.text for p in doc.paragraphs)
    if name.endswith(".csv"):
        data = read_bytes()
        for enc in ("utf-8", "latin-1"):
            try:
                df = pd.read_csv(io.BytesIO(data), encoding=enc)
            except ValueError:
                # UnicodeDecodeError, pandas ParserError and EmptyDataError are
                # all ValueError subclasses: try the next encoding, then give up.
                continue
            return " ".join(df.astype(str).values.flatten().tolist())
        return ""
    if name.endswith(".txt"):
        raw = read_bytes()
        try:
            # utf-8-sig also strips a leading byte-order mark if present.
            return raw.decode("utf-8-sig")
        except UnicodeDecodeError:
            # latin-1 maps every byte value, so this fallback cannot fail.
            return raw.decode("latin-1")
    return ""

def fetch_web_text(url):
    """Download ``url`` and return its visible text with whitespace collapsed.

    <script>, <style> and <noscript> contents are removed before extracting
    text. Best effort: returns "" for a blank/None URL or when fetching or
    parsing fails, so one bad URL does not abort a whole batch.
    """
    url = (url or "").strip()
    if not url:
        return ""
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        resp = requests.get(url, headers=headers, timeout=10)
        resp.raise_for_status()
        # Parse raw bytes rather than resp.text: requests falls back to
        # ISO-8859-1 for text/* responses without a charset header, which
        # garbles UTF-8 pages. Only trust resp.encoding when the server
        # actually declared a charset; otherwise let BeautifulSoup detect it
        # from the document (<meta charset>, BOM).
        content_type = resp.headers.get("Content-Type", "").lower()
        declared = resp.encoding if "charset=" in content_type else None
        soup = BeautifulSoup(resp.content, "html.parser", from_encoding=declared)
        for tag in soup(["script", "style", "noscript"]):
            tag.decompose()
        return " ".join(soup.get_text(separator=" ").split())
    except Exception:
        # Deliberately broad: URLs come straight from a user textbox and any
        # network/parse failure should just skip this source.
        return ""

def chunk_docs(docs, chunk_size=1000, chunk_overlap=120):
    """Split each document into overlapping text chunks.

    Args:
        docs: iterable of dicts carrying "source" and "text" keys.
        chunk_size: chunk size forwarded to RecursiveCharacterTextSplitter.
        chunk_overlap: overlap between consecutive chunks, also forwarded.

    Returns:
        A flat list of {"source", "chunk_id", "content"} dicts, ordered by
        document and then by position within the document. chunk_id is
        "<source>_chunk<n>" with n counting from 0 per document.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return [
        {
            "source": doc["source"],
            "chunk_id": f"{doc['source']}_chunk{position}",
            "content": piece,
        }
        for doc in docs
        for position, piece in enumerate(text_splitter.split_text(doc["text"]))
    ]

def build_index_and_chunks(docs):
    """Chunk ``docs``, embed every chunk and load the vectors into FAISS.

    Returns:
        ``(index, chunks)`` where ``index`` is a ``faiss.IndexFlatL2`` whose
        row i holds the embedding of ``chunks[i]``; ``(None, [])`` when
        chunking produced nothing to index.
    """
    chunks = chunk_docs(docs)
    if not chunks:
        return None, []
    vectors = embedder.encode(
        [chunk["content"] for chunk in chunks],
        show_progress_bar=True,
        convert_to_numpy=True,
    )
    # FAISS requires contiguous float32 input.
    vectors = np.asarray(vectors).astype("float32")
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index, chunks

def retrieve(query, index, chunks, top_k=3):
    """Return up to ``top_k`` chunks nearest to ``query`` in embedding space.

    Each hit is ``{"chunk": <chunk dict>, "score": <distance as float>}`` in
    the order FAISS reports them; for the L2 index built by
    build_index_and_chunks a lower score means a closer match. Returns ``[]``
    when no index or no chunks are available.
    """
    if index is None or len(chunks) == 0:
        return []
    query_vec = np.asarray(embedder.encode([query], convert_to_numpy=True)).astype("float32")
    dists, ids = index.search(query_vec, top_k)
    n_chunks = len(chunks)
    # FAISS pads missing neighbours with id -1 when top_k exceeds the number
    # of indexed vectors, so keep only ids that map to a real chunk.
    return [
        {"chunk": chunks[i], "score": float(d)}
        for d, i in zip(dists[0], ids[0])
        if 0 <= i < n_chunks
    ]

def answer_question(query, index, chunks, top_k=3, max_new_tokens=512):
    """Answer ``query`` with the LLM, grounded on the ``top_k`` retrieved chunks.

    Args:
        query: the user's question.
        index: FAISS index from build_index_and_chunks (or None).
        chunks: chunk dicts aligned with ``index``.
        top_k: number of chunks to place in the prompt as context.
        max_new_tokens: upper bound on tokens generated for the answer.

    Returns:
        ``(answer, sources)``: the generated answer text (prompt excluded) and
        a newline-separated list of "[chunk_id from source]" citations.
    """
    results = retrieve(query, index, chunks, top_k=top_k)
    context_chunks = [r["chunk"] for r in results]
    context_text = "\n".join(f"[{c['chunk_id']}] {c['content']}" for c in context_chunks)
    prompt = (
        "Answer the following question using ONLY the provided context and cite the chunk ids used.\n"
        f"Question: {query}\nContext:\n{context_text}\nAnswer with citations:"
    )
    # max_new_tokens bounds only the generated continuation; max_length would
    # also count the prompt, which with three ~1000-character chunks can use
    # up the whole budget and leave no room for an answer.
    # return_full_text=False stops the pipeline from echoing the prompt back.
    generated = llm(
        prompt,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        return_full_text=False,
    )
    answer = generated[0]["generated_text"].strip()
    sources = "\n".join(f"[{c['chunk_id']} from {c['source']}]" for c in context_chunks)
    return answer, sources

# Retrieval state for the app: process() replaces both entries after indexing
# and chat_response() reads them.
# NOTE(review): being module-level, this is presumably shared by every browser
# session of the app — confirm that is intended.
state = dict(index=None, chunks=[])

def process(files, urls):
    """Load uploaded files and web pages, then rebuild the retrieval index.

    Args:
        files: list of uploads from ``gr.File`` (objects exposing ``.name`` as
            a path, or plain path strings), or None.
        urls: newline-separated URLs from the textbox, or None/"". Blank
            lines are ignored.

    Returns:
        A status message for the UI. When at least one source yielded text,
        ``state["index"]`` and ``state["chunks"]`` are replaced.
    """
    docs = []
    for f in files or []:
        text = load_file_text(f)
        if text:
            # Label by base name: uploads typically arrive as long temp-file
            # paths, which would bloat chunk ids and citations in the prompt.
            source = os.path.basename(str(getattr(f, "name", f)))
            docs.append({"source": source, "text": text})
    for line in (urls or "").splitlines():
        url = line.strip()
        if not url:
            continue  # don't issue a request for an empty line
        text = fetch_web_text(url)
        if text:
            docs.append({"source": url, "text": text})
    if not docs:
        return "No documents or URLs loaded."
    index, chunks = build_index_and_chunks(docs)
    state["index"], state["chunks"] = index, chunks
    return f"Loaded {len(docs)} docs, created {len(chunks)} chunks."

def chat_response(user_message, history):
    """Handle one chat turn for the Gradio UI.

    Args:
        user_message: text typed by the user.
        history: list of (user, assistant) message tuples, or None.

    Returns:
        ``("", history)`` — an empty string to clear the input box and the
        updated history. Empty/whitespace-only messages leave history as is.
    """
    history = history or []
    if not user_message or not user_message.strip():
        # Nothing was asked: skip retrieval and LLM generation entirely.
        return "", history
    if state["index"] is None or len(state["chunks"]) == 0:
        bot_message = "Please upload documents or enter URLs, then press 'Load & Process' first."
    else:
        answer, sources = answer_question(user_message, state["index"], state["chunks"])
        bot_message = answer + "\n\nSources:\n" + sources
    history.append(("User: " + user_message, "Assistant: " + bot_message))
    return "", history

# Two-column UI: document loading on the left, chat on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 📚 RAG Chatbot with Mistral-7B and FAISS")

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(label="Upload Files (PDF, DOCX, TXT, CSV)", file_types=[".pdf", ".docx", ".txt", ".csv"], file_count="multiple")
            url_input = gr.Textbox(label="Enter URLs (one per line)", lines=4)
            process_button = gr.Button("Load & Process Documents and URLs")
            output_log = gr.Textbox(label="Status")

        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            user_input = gr.Textbox(placeholder="Ask a question about the loaded documents...", show_label=False)
            submit_btn = gr.Button("Send")

    process_button.click(process, inputs=[file_input, url_input], outputs=output_log)
    # Send a question either with the button or by pressing Enter in the box.
    submit_btn.click(chat_response, inputs=[user_input, chatbot], outputs=[user_input, chatbot])
    user_input.submit(chat_response, inputs=[user_input, chatbot], outputs=[user_input, chatbot])

demo.launch()