aach456 committed
Commit fc96ea6 · verified · 1 parent: 67b085a

Create app.py

Files changed (1)
  1. app.py (+270, -0)
app.py ADDED
@@ -0,0 +1,270 @@
# app.py
# Chat-style RAG app with Streamlit chat UI, FAISS retrieval, SentenceTransformers embeddings,
# and an open Mistral-7B pipeline. All caches are redirected to /tmp to avoid PermissionError.

# ---------- Writable dirs BEFORE third-party imports ----------
import os, glob, tempfile

# Streamlit internal runtime dir -> /tmp (fixes PermissionError: '/.streamlit')
ST_RT = os.environ.get("STREAMLIT_RUNTIME_DIR", "/tmp/.streamlit_runtime")
try:
    os.makedirs(ST_RT, exist_ok=True)
except Exception:
    ST_RT = tempfile.mkdtemp(prefix="st_runtime_")
os.environ["STREAMLIT_RUNTIME_DIR"] = ST_RT

# Hugging Face caches -> /tmp
HF_HOME = os.environ.get("HF_HOME", "/tmp/hf_cache")
try:
    os.makedirs(HF_HOME, exist_ok=True)
except Exception:
    HF_HOME = tempfile.mkdtemp(prefix="hf_cache_")
os.environ["HF_HOME"] = HF_HOME
os.environ["TRANSFORMERS_CACHE"] = HF_HOME  # backward-compat; deprecation warning is harmless
os.environ["SENTENCE_TRANSFORMERS_HOME"] = HF_HOME
os.environ["HF_DATASETS_CACHE"] = HF_HOME
os.environ["XDG_CACHE_HOME"] = HF_HOME
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

# Clean stale locks
locks_dir = os.path.join(HF_HOME, "hub", ".locks")
if os.path.isdir(locks_dir):
    for p in glob.glob(os.path.join(locks_dir, "*.lock")):
        try:
            os.remove(p)
        except Exception:
            pass

# ---------- Imports AFTER env is set ----------
import io
import time
import pandas as pd
import numpy as np
import requests
import streamlit as st
from bs4 import BeautifulSoup
from PyPDF2 import PdfReader
from docx import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, pipeline
import faiss

# ---------- Page ----------
st.set_page_config(page_title="Chat RAG • Open Model + URLs", layout="wide")
st.title("💬 Chat RAG with Open Model, FAISS, and Web URLs")

# ---------- Session ----------
for key, default in [
    ("messages", []),
    ("chunks", []),
    ("embedder", None),
    ("faiss_index", None),
]:
    if key not in st.session_state:
        st.session_state[key] = default

# ---------- Loaders ----------
def load_txt(file):
    raw = file.read()
    # Decode strictly first so the latin-1 fallback can actually trigger
    # (decoding with errors="ignore" never raises, which made the loop pointless).
    for enc in ("utf-8", "latin-1"):
        try:
            return [{"source": file.name, "text": raw.decode(enc)}]
        except Exception:
            continue
    return [{"source": file.name, "text": raw.decode("utf-8", errors="ignore")}]

def load_pdf(file):
    pdf = PdfReader(file)
    text = ""
    for page in pdf.pages:
        text += page.extract_text() or ""
    return [{"source": file.name, "text": text}]

def load_docx(file):
    data = file.read()
    doc = Document(io.BytesIO(data))
    text = " ".join(p.text for p in doc.paragraphs)
    return [{"source": file.name, "text": text}]

def load_csv(file):
    data = file.read()
    df = None
    for enc in ("utf-8", "latin-1"):
        try:
            df = pd.read_csv(io.BytesIO(data), encoding=enc)
            break
        except Exception:
            df = None
    if df is None:
        try:
            df = pd.read_csv(io.BytesIO(data), engine="python")
        except Exception:
            df = pd.DataFrame()
    text = " ".join(df.astype(str).values.flatten().tolist()) if not df.empty else ""
    return [{"source": file.name, "text": text}]
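
# Note: the fallback chain above is strict utf-8, then latin-1 (which accepts any
# byte sequence), then pandas' python engine, and finally an empty DataFrame.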

def load_documents(files):
    docs = []
    for file in files or []:
        name = file.name.lower()
        if name.endswith(".pdf"):
            docs += load_pdf(file)
        elif name.endswith(".docx"):
            docs += load_docx(file)
        elif name.endswith(".csv"):
            docs += load_csv(file)
        elif name.endswith(".txt"):
            docs += load_txt(file)
    return docs

# ---------- Web fetch ----------
def fetch_web_text(url, timeout=12, retries=2, backoff=1.5):
    for attempt in range(retries + 1):
        try:
            headers = {
                "User-Agent": (
                    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                    "AppleWebKit/537.36 (KHTML, like Gecko) "
                    "Chrome/124.0 Safari/537.36"
                )
            }
            resp = requests.get(url, headers=headers, timeout=timeout)
            resp.raise_for_status()
            soup = BeautifulSoup(resp.text, "html.parser")
            for tag in soup(["script", "style", "noscript"]):
                tag.decompose()
            text = " ".join(soup.get_text(separator=" ").split())
            return [{"source": url, "text": text}]
        except Exception:
            if attempt < retries:
                time.sleep(backoff ** attempt)
            else:
                return []
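
# Note: with the defaults above (retries=2, backoff=1.5), each URL is attempted up to
# three times, sleeping backoff**attempt between tries (1.0 s, then 1.5 s) before
# giving up and returning an empty list.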

# ---------- Chunking ----------
def chunk_documents(docs, chunk_size=1000, chunk_overlap=120):
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunks = []
    for doc in docs:
        splits = splitter.split_text(doc.get("text", "") or "")
        for idx, chunk in enumerate(splits):
            chunks.append({"source": doc["source"], "chunk_id": f"{doc['source']}_chunk{idx}", "content": chunk})
    return chunks
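
# Note: chunk_overlap=120 keeps about 120 characters shared between neighboring
# chunks, so text that straddles a chunk boundary can still be retrieved intact.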

# ---------- Embeddings / Index ----------
@st.cache_resource(show_spinner=False)
def load_embedder():
    return SentenceTransformer("all-MiniLM-L6-v2", cache_folder=os.environ.get("SENTENCE_TRANSFORMERS_HOME", HF_HOME))

def build_embeddings_index(chunks):
    embedder = load_embedder()
    texts = [c["content"] for c in chunks]
    if not texts:
        return embedder, None
    emb = embedder.encode(texts, show_progress_bar=True, convert_to_numpy=True)
    emb = np.asarray(emb, dtype="float32")
    idx = faiss.IndexFlatL2(emb.shape[1])  # index dimension = embedding width
    idx.add(emb)
    return embedder, idx
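
# Note: IndexFlatL2 ranks by Euclidean distance, so lower scores mean closer matches.
# If cosine similarity is preferred, one option is to L2-normalize the embeddings with
# faiss.normalize_L2 and switch to faiss.IndexFlatIP.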

def retrieve(query, embedder, index, chunks, top_k=4):
    if index is None or not chunks:
        return []
    q_emb = embedder.encode([query], convert_to_numpy=True)
    q_emb = np.asarray(q_emb, dtype="float32")
    # index.search returns (distances, indices) arrays of shape (n_queries, top_k);
    # take row 0 since there is a single query
    distances, indices = index.search(q_emb, top_k)
    out = []
    for pos, i in enumerate(indices[0]):
        if 0 <= i < len(chunks):
            out.append({"chunk": chunks[i], "score": float(distances[0][pos])})
    return out

# ---------- LLM ----------
MODEL_ID = "MehdiHosseiniMoghadam/AVA-Mistral-7B-V2"

@st.cache_resource(show_spinner=False)
def load_llm():
    cache_dir = os.environ.get("HF_HOME", HF_HOME)
    _ = AutoConfig.from_pretrained(MODEL_ID, cache_dir=cache_dir, trust_remote_code=True)
    tok = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=cache_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, cache_dir=cache_dir, trust_remote_code=True)
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        max_length=1024,
        do_sample=True,
        temperature=0.2,
        trust_remote_code=True,
        device_map="auto",
    )
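
# Note: device_map="auto" requires the accelerate package. Without a GPU, a 7B model
# in full fp32 precision needs roughly 28 GB of RAM (about half that in fp16), so a
# smaller or quantized model may be needed on constrained hardware.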

def answer_with_llm(context_chunks, query, llm):
    context_text = "\n".join(f"[{c['chunk_id']}] {c['content']}" for c in context_chunks)
    prompt = (
        "Answer the following question using ONLY the provided context and cite the chunk ids used.\n"
        f"Question: {query}\n"
        "Context:\n"
        f"{context_text}\n"
        "Answer with citations:"
    )
    # max_new_tokens bounds only the completion (max_length would also count the prompt,
    # which can exceed 512 tokens with four ~1000-character chunks); return_full_text=False
    # keeps the prompt out of the displayed answer.
    out = llm(prompt, max_new_tokens=256, num_return_sequences=1, return_full_text=False)
    # text-generation pipelines return a list of dicts, one per sequence
    return out[0]["generated_text"]

# ---------- Sidebar sources ----------
st.sidebar.header("Data sources")

uploaded_files = st.sidebar.file_uploader(
    "Upload documents (PDF, DOCX, TXT, CSV)",
    type=["pdf", "txt", "docx", "csv"],
    accept_multiple_files=True,
    help="Default per-file limit ~200MB; increase via .streamlit/config.toml if needed.",
)
with st.sidebar.expander("Upload debug"):
    info = {
        "type": type(uploaded_files).__name__,
        "num_files": (len(uploaded_files) if isinstance(uploaded_files, list) else (1 if uploaded_files else 0)),
        "names": ([f.name for f in uploaded_files] if isinstance(uploaded_files, list) else ([uploaded_files.name] if uploaded_files else [])),
    }
    st.write(info)

url_input = st.sidebar.text_area("Web URLs (one per line)", value="", height=120)

web_docs = []
if url_input.strip():
    urls = [u.strip() for u in url_input.splitlines() if u.strip()]
    # use the module-level st.spinner; the sidebar object does not expose a .spinner method
    with st.spinner("Fetching web content..."):
        for u in urls:
            web_docs += fetch_web_text(u)

file_docs = load_documents(uploaded_files) if uploaded_files else []
all_docs = file_docs + web_docs

if all_docs:
    st.success(f"{len(all_docs)} document(s) loaded from files and URLs.")
    with st.spinner("Chunking and embedding..."):
        st.session_state.chunks = chunk_documents(all_docs, chunk_size=1000, chunk_overlap=120)
        st.session_state.embedder, st.session_state.faiss_index = build_embeddings_index(st.session_state.chunks)
    st.write(f"{len(st.session_state.chunks)} chunks created and indexed.")
else:
    st.info("Add documents or URLs in the sidebar to start.")
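
# Note: Streamlit reruns the whole script on every interaction, so the block above
# re-chunks and re-embeds all sources on each rerun while any are loaded; caching on
# a content hash (e.g. with st.cache_data) would be one way to avoid the recompute.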

# ---------- Chat UI ----------
for m in st.session_state.messages:
    with st.chat_message(m["role"]):
        st.markdown(m["content"])

user_input = st.chat_input("Ask about the loaded documents...")
if user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            if st.session_state.chunks:
                llm = load_llm()
                results = retrieve(user_input, st.session_state.embedder, st.session_state.faiss_index, st.session_state.chunks, top_k=4)
                context_chunks = [r["chunk"] for r in results]
                answer = answer_with_llm(context_chunks, user_input, llm)
                st.markdown(answer)
                sources = "\n".join(f"[{r['chunk']['chunk_id']} from {r['chunk']['source']}]" for r in results) or "No sources (no matches)."
                with st.expander("Sources"):
                    st.code(sources)
            else:
                answer = "No documents indexed yet. Add files or URLs in the sidebar and try again."
                st.warning(answer)
    st.session_state.messages.append({"role": "assistant", "content": answer})

st.caption("Chat RAG • Mistral-7B (open), FAISS, SentenceTransformers, and Web URLs • Streamlit chat UI")