davidtran999 committed on
Commit
5150cc5
·
1 Parent(s): 519b145

Fix: Copy pure_semantic_search, query_rewriter, redis_cache to backend/hue_portal/core/

Browse files
backend/hue_portal/core/pure_semantic_search.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pure Semantic Search - 100% vector search with multi-query support.
3
+
4
+ This module implements pure semantic search (no BM25) which is the recommended
5
+ approach when using Query Rewrite Strategy + BGE-M3. All top systems have moved
6
+ away from hybrid search (BM25 + Vector) to pure semantic search since Oct 2025.
7
+ """
8
+ import logging
9
+ from typing import List, Tuple, Optional, Dict, Any, Set
10
+ from concurrent.futures import ThreadPoolExecutor, as_completed
11
+ from django.db.models import QuerySet
12
+
13
+ from .embeddings import (
14
+ get_embedding_model,
15
+ generate_embedding,
16
+ cosine_similarity
17
+ )
18
+ from .embedding_utils import load_embedding
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # Minimum vector score threshold
23
+ DEFAULT_MIN_VECTOR_SCORE = 0.1
24
+
25
+
26
def get_vector_scores(
    queryset: QuerySet,
    query: str,
    top_k: int = 20
) -> List[Tuple[Any, float]]:
    """
    Score every object in *queryset* against *query* by cosine similarity.

    Extracted from hybrid_search.py so pure semantic search can reuse it.

    Args:
        queryset: Django QuerySet whose objects carry stored embeddings.
        query: Search query string.
        top_k: Nominal result count; twice as many rows are returned so
            callers can merge results across multiple rewritten queries.

    Returns:
        List of (object, vector_score) tuples sorted by score descending.
        Empty when the query is blank, no embedding model/vector is
        available, or every stored embedding has a mismatched dimension.
    """
    if not query or not query.strip():
        return []

    # Embed the query; bail out quietly if the model stack is unavailable.
    model = get_embedding_model()
    if model is None:
        return []
    query_vec = generate_embedding(query, model=model)
    if query_vec is None:
        return []

    candidates = list(queryset)
    if not candidates:
        return []

    expected_dim = len(query_vec)
    warned_mismatch = False

    ranked: List[Tuple[Any, float]] = []
    for candidate in candidates:
        stored_vec = load_embedding(candidate)
        if stored_vec is None:
            continue
        stored_dim = len(stored_vec)
        if stored_dim != expected_dim:
            # Stored vectors came from a different model — warn once, skip all.
            if not warned_mismatch:
                logger.warning(
                    f"Dimension mismatch: query={expected_dim}, stored={stored_dim}. Skipping vector search."
                )
                warned_mismatch = True
            continue
        score = cosine_similarity(query_vec, stored_vec)
        if score >= DEFAULT_MIN_VECTOR_SCORE:
            ranked.append((candidate, score))

    # Nothing comparable at all → signal "no vector results".
    if warned_mismatch and not ranked:
        return []

    ranked.sort(key=lambda pair: pair[1], reverse=True)
    # Over-fetch so multi-query merging has enough material to work with.
    return ranked[:top_k * 2]
90
+
91
+
92
def calculate_exact_match_boost(obj: Any, query: str, text_fields: List[str]) -> float:
    """
    Compute a boost score for exact keyword matches in title/name fields.

    Ensures literal matches are still prioritized in pure semantic search.

    Args:
        obj: Django model instance (any object with the named attributes).
        query: Search query string.
        text_fields: Field names to inspect; only the first two (usually
            title/name) are checked.

    Returns:
        Boost score in [0.0, 1.0].
    """
    if not query or not text_fields:
        return 0.0

    normalized = query.lower().strip()
    words = normalized.split()

    # Build 2-word (len > 3) and 3-word (len > 5) phrases from the query.
    phrases = []
    for size, min_len in ((2, 3), (3, 5)):
        for start in range(len(words) - size + 1):
            candidate = " ".join(words[start:start + size])
            if len(candidate) > min_len:
                phrases.append(candidate)

    # Individual words longer than 2 characters also count.
    significant = {w for w in words if len(w) > 2}

    total = 0.0
    # Only the primary fields (title/name) contribute to the boost.
    for field in text_fields[:2]:
        if not hasattr(obj, field):
            continue
        value = str(getattr(obj, field, "")).lower()
        if not value:
            continue

        # Phrase hits carry the most weight; an exact field match adds more.
        for phrase in phrases:
            if phrase in value:
                total += 0.5
                if value.strip() == phrase.strip():
                    total += 0.3

        # Whole-query substring match.
        if normalized in value:
            total += 0.4

        # Moderate credit for individual word hits, capped at three words.
        hits = sum(1 for w in significant if w in value)
        if hits:
            total += 0.1 * min(hits, 3)

    # Very strong matches saturate at 1.0.
    return min(total, 1.0)
153
+
154
+
155
def parallel_vector_search(
    queries: List[str],
    queryset: QuerySet,
    top_k_per_query: int = 5,
    final_top_k: int = 7,
    text_fields: Optional[List[str]] = None
) -> List[Tuple[Any, float]]:
    """
    Run a vector search per query in parallel and merge the results.

    Core of the Query Rewrite Strategy: each rewritten query is searched
    independently and an object keeps the best score it achieved under any
    query.

    Args:
        queries: Rewritten queries (typically 3-5 from Query Rewrite).
        queryset: Django QuerySet to search.
        top_k_per_query: Top K results per individual query (default: 5).
        final_top_k: Top K results after merging (default: 7).
        text_fields: Optional field names for the exact-match boost.

    Returns:
        List of (object, combined_score) tuples sorted by score descending.

    Example:
        queries = [
            "nội dung điều 12",
            "quy định điều 12",
            "điều 12 quy định về"
        ]
        results = parallel_vector_search(queries, LegalSection.objects.all())
        # Returns top 7 sections with highest combined scores
    """
    if not queries or not queries[0].strip():
        return []

    # A single query needs no fan-out machinery.
    if len(queries) == 1:
        return _single_query_search(queries[0], queryset, top_k=final_top_k, text_fields=text_fields)

    best_scores: Dict[Any, float] = {}  # object -> best score across queries

    with ThreadPoolExecutor(max_workers=min(len(queries), 5)) as pool:
        pending = {
            pool.submit(get_vector_scores, queryset, q, top_k=top_k_per_query): q
            for q in queries
        }
        for done in as_completed(pending):
            q = pending[done]
            try:
                # Merge by keeping each object's maximum score.
                for obj, score in done.result():
                    prev = best_scores.get(obj)
                    best_scores[obj] = score if prev is None else max(prev, score)
            except Exception as e:
                logger.warning(f"[PARALLEL_SEARCH] Error searching with query '{q}': {e}")

    if text_fields:
        # Boost exact matches; the first (original-most) query drives the boost.
        merged = [
            (obj, score * 0.8 + calculate_exact_match_boost(obj, queries[0], text_fields) * 0.2)
            for obj, score in best_scores.items()
        ]
    else:
        merged = list(best_scores.items())

    merged.sort(key=lambda pair: pair[1], reverse=True)
    return merged[:final_top_k]
236
+
237
+
238
def _single_query_search(
    query: str,
    queryset: QuerySet,
    top_k: int = 20,
    text_fields: Optional[List[str]] = None
) -> List[Tuple[Any, float]]:
    """
    Vector search for a single query, optionally with exact-match boost.

    Args:
        query: Search query string.
        queryset: Django QuerySet to search.
        top_k: Maximum number of results.
        text_fields: Optional field names for the exact-match boost.

    Returns:
        List of (object, score) tuples sorted by score descending.
    """
    hits = get_vector_scores(queryset, query, top_k=top_k)

    # Without boost fields the raw vector ranking is final.
    if not text_fields:
        return hits[:top_k]

    # Blend: 80% vector similarity, 20% exact keyword match.
    rescored = [
        (obj, vec_score * 0.8 + calculate_exact_match_boost(obj, query, text_fields) * 0.2)
        for obj, vec_score in hits
    ]
    rescored.sort(key=lambda pair: pair[1], reverse=True)
    return rescored[:top_k]
273
+
274
+
275
def pure_semantic_search(
    queries: List[str],
    queryset: QuerySet,
    top_k: int = 20,
    text_fields: Optional[List[str]] = None
) -> List[Any]:
    """
    Pure semantic search (100% vector, no BM25).

    Recommended strategy when combined with Query Rewrite + BGE-M3; dispatches
    to a direct search for one query or a parallel merged search for several.

    Args:
        queries: One query, or 3-5 rewritten queries from Query Rewrite.
        queryset: Django QuerySet to search.
        top_k: Maximum number of results.
        text_fields: Optional field names for the exact-match boost.

    Returns:
        List of objects sorted by score (highest first).

    Usage:
        # Single query
        results = pure_semantic_search(["mức phạt vi phạm"], queryset, top_k=20)

        # Multiple queries (from Query Rewrite)
        rewritten_queries = query_rewriter.rewrite_query("mức phạt vi phạm")
        results = pure_semantic_search(rewritten_queries, queryset, top_k=20)
    """
    if not queries:
        return []

    if len(queries) > 1:
        # Fan out over all rewritten queries and merge.
        scored = parallel_vector_search(
            queries,
            queryset,
            top_k_per_query=max(5, top_k // len(queries)),
            final_top_k=top_k,
            text_fields=text_fields,
        )
    else:
        scored = _single_query_search(queries[0], queryset, top_k=top_k, text_fields=text_fields)

    # Callers only want the ranked objects, not the scores.
    return [obj for obj, _ in scored]
322
+
backend/hue_portal/core/query_rewriter.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Query Rewriter - Rewrite user queries into 3-5 optimized legal queries.
3
+
4
+ This module implements the Query Rewrite Strategy - the "best practice" approach
5
+ used by top legal RAG systems in 2025, achieving >99.9% accuracy.
6
+ """
7
+ import os
8
+ import logging
9
+ import hashlib
10
+ import json
11
+ from typing import List, Dict, Any, Optional
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class QueryRewriter:
    """
    Rewrite user queries into 3-5 optimized legal queries for better search results.

    This is the core of the Query Rewrite Strategy - instead of using an LLM to
    suggest documents (which can hallucinate), we rewrite the query into multiple
    variations and use pure vector search to find the best documents.
    """

    # TTL (seconds) for cached rewrites. The original code referenced an
    # undefined module global ``CACHE_QUERY_REWRITE_TTL`` here, which raised
    # NameError on every cache write; 1 hour matches the cache design notes
    # in redis_cache.py ("Query rewrite results ... TTL 1 hour").
    CACHE_QUERY_REWRITE_TTL = 3600

    def __init__(self, llm_generator=None, use_cache: bool = True):
        """
        Initialize Query Rewriter.

        Args:
            llm_generator: Optional LLMGenerator instance. If None, will get from llm_integration.
            use_cache: Whether to use Redis cache for query rewrites (default: True).
        """
        if llm_generator is None:
            try:
                from hue_portal.chatbot.llm_integration import get_llm_generator
                self.llm_generator = get_llm_generator()
            except Exception as e:
                logger.warning(f"[QUERY_REWRITER] Failed to get LLM generator: {e}")
                self.llm_generator = None
        else:
            self.llm_generator = llm_generator

        # Initialize Redis cache if available; degrade gracefully otherwise.
        self.use_cache = use_cache
        self.cache = None
        if self.use_cache:
            try:
                from hue_portal.core.redis_cache import get_redis_cache
                self.cache = get_redis_cache()
                if not self.cache.is_available():
                    logger.info("[QUERY_REWRITER] Redis cache not available, caching disabled")
                    self.cache = None
            except Exception as e:
                logger.warning(f"[QUERY_REWRITER] Failed to initialize cache: {e}")
                self.cache = None

    def rewrite_query(
        self,
        user_query: str,
        context: Optional[List[Dict[str, str]]] = None,
        max_queries: int = 5,
        min_queries: int = 3
    ) -> List[str]:
        """
        Rewrite a user query into 3-5 optimized legal queries.

        Args:
            user_query: Original user query string.
            context: Optional conversation context (list of {role, content} dicts).
            max_queries: Maximum number of queries to generate (default: 5).
            min_queries: Minimum number of queries to generate (default: 3).

        Returns:
            List of rewritten queries (3-5 queries).

        Examples:
            Input: "điều 12 nói gì"
            Output: [
                "nội dung điều 12",
                "quy định điều 12",
                "điều 12 quy định về",
                "điều 12 quy định gì",
                "điều 12 quy định như thế nào"
            ]
        """
        if not user_query or not user_query.strip():
            return []

        user_query = user_query.strip()

        # Check cache first.
        if self.cache and self.cache.is_available():
            cache_key = f"query_rewrite:{self.get_cache_key(user_query, context=context)}"
            cached_queries = self.cache.get(cache_key)
            if cached_queries and isinstance(cached_queries, list):
                logger.info(f"[QUERY_REWRITER] ✅ Cache hit for query rewrite")
                return cached_queries[:max_queries]

        # Try LLM-based rewrite first.
        if self.llm_generator and self.llm_generator.is_available():
            try:
                rewritten = self._rewrite_with_llm(
                    user_query,
                    context=context,
                    max_queries=max_queries,
                    min_queries=min_queries
                )
                if rewritten and len(rewritten) >= min_queries:
                    logger.info(f"[QUERY_REWRITER] ✅ LLM rewrite: {len(rewritten)} queries")
                    final_queries = rewritten[:max_queries]

                    # Cache the result. FIX: the TTL constant is now a class
                    # attribute; previously this referenced an undefined
                    # global and raised NameError.
                    if self.cache and self.cache.is_available():
                        cache_key = f"query_rewrite:{self.get_cache_key(user_query, context=context)}"
                        self.cache.set(cache_key, final_queries, ttl_seconds=self.CACHE_QUERY_REWRITE_TTL)
                        logger.debug(f"[QUERY_REWRITER] Cached query rewrite (TTL: {self.CACHE_QUERY_REWRITE_TTL}s)")

                    return final_queries
            except Exception as e:
                logger.warning(f"[QUERY_REWRITER] LLM rewrite failed: {e}, using fallback")

        # Fallback to rule-based rewrite.
        return self._rewrite_fallback(user_query, max_queries=max_queries, min_queries=min_queries)

    def _rewrite_with_llm(
        self,
        user_query: str,
        context: Optional[List[Dict[str, str]]] = None,
        max_queries: int = 5,
        min_queries: int = 3
    ) -> List[str]:
        """
        Rewrite query using the LLM.

        Args:
            user_query: Original user query.
            context: Optional conversation context.
            max_queries: Maximum queries to generate.
            min_queries: Minimum queries to generate.

        Returns:
            List of rewritten queries (possibly empty on LLM/parse failure).
        """
        # Summarize the last few user turns as context for the prompt.
        context_text = ""
        if context:
            recent_user_messages = [
                msg.get("content", "")
                for msg in context[-3:]  # Last 3 messages
                if msg.get("role") == "user"
            ]
            if recent_user_messages:
                context_text = " ".join(recent_user_messages)

        # Build the (Vietnamese) rewriting prompt; the LLM must answer in JSON.
        prompt = (
            "Bạn là trợ lý pháp luật chuyên nghiệp. Nhiệm vụ của bạn là viết lại câu hỏi của người dùng "
            "thành {max_queries} câu hỏi chuẩn pháp lý tối ưu nhất để tìm kiếm trong cơ sở dữ liệu văn bản pháp luật.\n\n"
            "Câu hỏi gốc: \"{user_query}\"\n\n"
            "{context_section}"
            "Yêu cầu:\n"
            "1. Viết lại thành {max_queries} câu hỏi khác nhau, mỗi câu hỏi tập trung vào một khía cạnh của vấn đề\n"
            "2. Sử dụng thuật ngữ pháp lý chuẩn (ví dụ: 'quy định', 'điều', 'khoản', 'mức phạt', 'khung hình phạt')\n"
            "3. Các câu hỏi nên bao quát nhiều cách diễn đạt khác nhau của cùng một vấn đề\n"
            "4. Giữ nguyên ý nghĩa chính của câu hỏi gốc\n"
            "5. Mỗi câu hỏi nên ngắn gọn, rõ ràng (10-20 từ)\n\n"
            "Trả về JSON với dạng:\n"
            "{{\n"
            ' "queries": ["câu hỏi 1", "câu hỏi 2", "câu hỏi 3", ...]\n'
            "}}\n"
            "Chỉ in JSON, không thêm lời giải thích khác."
        ).format(
            max_queries=max_queries,
            user_query=user_query,
            context_section=(
                f"Ngữ cảnh cuộc hội thoại: {context_text}\n\n"
                if context_text else ""
            )
        )

        # Generate with the LLM; empty output means "no rewrite available".
        raw = self.llm_generator._generate_from_prompt(prompt)
        if not raw:
            return []

        # Parse the JSON response.
        parsed = self.llm_generator._extract_json_payload(raw)
        if not parsed:
            return []

        queries = parsed.get("queries") or []
        if not isinstance(queries, list):
            return []

        # Keep only non-trivial string entries.
        valid_queries = []
        for q in queries:
            if isinstance(q, str):
                q = q.strip()
                if q and len(q) > 3:  # Minimum length
                    valid_queries.append(q)

        # Top up with the original query and rule-based variations if short.
        if len(valid_queries) < min_queries:
            if user_query not in valid_queries:
                valid_queries.insert(0, user_query)

            # FIX: clamp the requested count at zero so a long valid list can
            # never produce a negative max_queries for the fallback.
            fallback_queries = self._rewrite_fallback(
                user_query,
                max_queries=max(0, max_queries - len(valid_queries)),
                min_queries=0
            )
            valid_queries.extend(fallback_queries)

        # Remove duplicates (case-insensitive) while preserving order.
        seen = set()
        unique_queries = []
        for q in valid_queries:
            q_lower = q.lower()
            if q_lower not in seen:
                seen.add(q_lower)
                unique_queries.append(q)

        return unique_queries[:max_queries]

    def _rewrite_fallback(
        self,
        user_query: str,
        max_queries: int = 5,
        min_queries: int = 3
    ) -> List[str]:
        """
        Fallback rule-based query rewriting.

        Generates query variations using simple patterns when the LLM is not
        available.

        Args:
            user_query: Original user query.
            max_queries: Maximum queries to generate.
            min_queries: Minimum queries to generate.

        Returns:
            List of rewritten queries (original query always first).
        """
        queries = [user_query]  # Always include original

        query_lower = user_query.lower()
        query_words = query_lower.split()

        # Pattern 1: add "quy định" (regulation) variants when absent.
        # FIX: the original tested the same substring twice; one check suffices.
        if "quy định" not in query_lower and len(query_words) > 1:
            queries.append(f"quy định {user_query}")
            queries.append(f"{user_query} quy định")

        # Pattern 2: add "nội dung"/"quy định" variants for article ("điều") queries.
        if "điều" in query_lower:
            for idx, word in enumerate(query_words):
                if "điều" in word:
                    # Use the word after "điều" as the article number, if any.
                    if idx + 1 < len(query_words):
                        next_word = query_words[idx + 1]
                        queries.append(f"nội dung điều {next_word}")
                        queries.append(f"quy định điều {next_word}")
                    break

        # Pattern 3: add penalty ("mức phạt") variants for fine-related queries.
        if any(kw in query_lower for kw in ["phạt", "vi phạm", "xử phạt"]):
            if "mức phạt" not in query_lower:
                queries.append(f"mức phạt {user_query}")
            if "khung hình phạt" not in query_lower:
                queries.append(f"khung hình phạt {user_query}")

        # Pattern 4: add procedure ("thủ tục") variants for procedure queries.
        if any(kw in query_lower for kw in ["thủ tục", "hồ sơ", "giấy tờ"]):
            if "thủ tục" not in query_lower:
                queries.append(f"thủ tục {user_query}")

        # Remove duplicates (case-insensitive) while preserving order.
        seen = set()
        unique_queries = []
        for q in queries:
            q_lower = q.lower()
            if q_lower not in seen:
                seen.add(q_lower)
                unique_queries.append(q)

        # Pad with a reversed-word variation if still below the minimum.
        while len(unique_queries) < min_queries:
            if len(query_words) > 1:
                reversed_query = " ".join(reversed(query_words))
                if reversed_query.lower() not in seen:
                    unique_queries.append(reversed_query)
                    seen.add(reversed_query.lower())
            else:
                break

        return unique_queries[:max_queries]

    def get_cache_key(self, user_query: str, context: Optional[List[Dict[str, str]]] = None) -> str:
        """
        Generate a deterministic cache key for a query rewrite.

        Args:
            user_query: Original user query (normalized to lowercase/stripped).
            context: Optional conversation context; only the last 3 messages
                (truncated to 100 chars each) influence the key.

        Returns:
            Hex SHA-256 digest string (64 characters).
        """
        cache_data = {
            "query": user_query.strip().lower(),
            "context": [
                {"role": msg.get("role"), "content": msg.get("content", "")[:100]}
                for msg in (context or [])[-3:]  # Last 3 messages only
            ]
        }
        cache_str = json.dumps(cache_data, sort_keys=True, ensure_ascii=False)
        return hashlib.sha256(cache_str.encode("utf-8")).hexdigest()
335
+
336
+
337
def get_query_rewriter(llm_generator=None) -> QueryRewriter:
    """
    Build a QueryRewriter instance.

    Args:
        llm_generator: Optional LLMGenerator instance passed through to the
            rewriter; when None the rewriter resolves one itself.

    Returns:
        A freshly constructed QueryRewriter.
    """
    rewriter = QueryRewriter(llm_generator=llm_generator)
    return rewriter
348
+
backend/hue_portal/core/redis_cache.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Redis Cache Layer for Query Rewrite and Prefetch Results.
3
+
4
+ This module provides Redis caching for:
5
+ - Query rewrite results (1000 queries, TTL 1 hour)
6
+ - Prefetch results by document_code (TTL 30 minutes)
7
+
8
+ Supports Upstash and Railway Redis free tier.
9
+ """
10
+ import os
11
+ import logging
12
+ import json
13
+ from typing import Optional, Dict, Any, List
14
+ from datetime import timedelta
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Try to import redis
19
+ try:
20
+ import redis
21
+ REDIS_AVAILABLE = True
22
+ except ImportError:
23
+ REDIS_AVAILABLE = False
24
+ logger.warning("[REDIS] redis package not installed. Install with: pip install redis")
25
+
26
+
27
class RedisCache:
    """
    Redis cache manager for query rewrites and prefetch results.

    Supports graceful degradation if Redis is unavailable: every operation
    becomes a cheap no-op returning a neutral value (None/False/0).
    """

    def __init__(self, redis_url: Optional[str] = None):
        """
        Initialize Redis cache.

        Args:
            redis_url: Redis connection URL. If None, reads from REDIS_URL env var.
        """
        self.redis_url = redis_url or os.environ.get("REDIS_URL")
        # FIX: the annotation is a string so it is never evaluated at runtime.
        # PEP 526 evaluates annotations on attribute targets, so the original
        # `Optional[redis.Redis]` raised NameError whenever the `redis`
        # package was missing — defeating the graceful-degradation design.
        self.client: Optional["redis.Redis"] = None
        self._connected = False

        if not REDIS_AVAILABLE:
            logger.warning("[REDIS] Redis package not available, caching disabled")
            return

        if not self.redis_url:
            logger.warning("[REDIS] REDIS_URL not configured, caching disabled")
            return

        self._connect()

    def _connect(self) -> None:
        """Connect to the Redis server; on failure leave caching disabled."""
        if not REDIS_AVAILABLE or not self.redis_url:
            return

        try:
            # URL format: redis://[:password@]host[:port][/db]
            # (or rediss:// for SSL) — both Upstash and Railway use this form.
            self.client = redis.from_url(
                self.redis_url,
                decode_responses=True,  # Auto-decode strings
                socket_connect_timeout=5,
                socket_timeout=5,
                retry_on_timeout=True,
                health_check_interval=30
            )

            # Fail fast if the server is unreachable.
            self.client.ping()
            self._connected = True
            logger.info("[REDIS] ✅ Connected to Redis successfully")
        except Exception as e:
            logger.warning(f"[REDIS] Failed to connect to Redis: {e}, caching disabled")
            self.client = None
            self._connected = False

    def is_available(self) -> bool:
        """Check if Redis is available and currently responding to PING."""
        if not self._connected or not self.client:
            return False

        try:
            self.client.ping()
            return True
        except Exception:
            # Remember the failure so later calls short-circuit.
            self._connected = False
            return False

    def get(self, key: str) -> Optional[Any]:
        """
        Get a value from the cache.

        Args:
            key: Cache key.

        Returns:
            The cached value (JSON-decoded when possible, raw string
            otherwise) or None if missing/unavailable.
        """
        if not self.is_available():
            return None

        try:
            value = self.client.get(key)
            if value is None:
                return None

            # Values are stored as JSON when possible; fall back to the raw string.
            try:
                return json.loads(value)
            except (json.JSONDecodeError, TypeError):
                return value
        except Exception as e:
            logger.warning(f"[REDIS] Error getting key '{key}': {e}")
            return None

    def set(
        self,
        key: str,
        value: Any,
        ttl_seconds: Optional[int] = None
    ) -> bool:
        """
        Set a value in the cache.

        Args:
            key: Cache key.
            value: Value to cache (JSON-encoded if dict/list, stringified otherwise).
            ttl_seconds: Time to live in seconds. If None, no expiration.

        Returns:
            True if successful, False otherwise.
        """
        if not self.is_available():
            return False

        try:
            # Serialize dict/list values to JSON; everything else via str().
            if isinstance(value, (dict, list)):
                serialized = json.dumps(value, ensure_ascii=False)
            else:
                serialized = str(value)

            if ttl_seconds:
                self.client.setex(key, ttl_seconds, serialized)
            else:
                self.client.set(key, serialized)

            return True
        except Exception as e:
            logger.warning(f"[REDIS] Error setting key '{key}': {e}")
            return False

    def delete(self, key: str) -> bool:
        """
        Delete a key from the cache.

        Args:
            key: Cache key.

        Returns:
            True if the command was issued successfully, False otherwise.
        """
        if not self.is_available():
            return False

        try:
            self.client.delete(key)
            return True
        except Exception as e:
            logger.warning(f"[REDIS] Error deleting key '{key}': {e}")
            return False

    def exists(self, key: str) -> bool:
        """
        Check if a key exists in the cache.

        Args:
            key: Cache key.

        Returns:
            True if the key exists, False otherwise (including on error).
        """
        if not self.is_available():
            return False

        try:
            return self.client.exists(key) > 0
        except Exception:
            return False

    def clear_pattern(self, pattern: str) -> int:
        """
        Clear all keys matching a pattern.

        Args:
            pattern: Redis key pattern (e.g., "query_rewrite:*").

        Returns:
            Number of keys deleted.
        """
        if not self.is_available():
            return 0

        try:
            # FIX: SCAN instead of KEYS — KEYS blocks the Redis server while
            # it walks the whole keyspace, which is unsafe on shared/free-tier
            # instances; SCAN iterates incrementally.
            matched = list(self.client.scan_iter(match=pattern))
            if matched:
                return self.client.delete(*matched)
            return 0
        except Exception as e:
            logger.warning(f"[REDIS] Error clearing pattern '{pattern}': {e}")
            return 0
218
+
219
+
220
+ # Singleton instance
221
+ _redis_cache_instance: Optional[RedisCache] = None
222
+
223
+
224
def get_redis_cache(redis_url: Optional[str] = None) -> RedisCache:
    """
    Get the process-wide RedisCache singleton, creating it on first use.

    Args:
        redis_url: Optional Redis URL used only when the singleton is first
            constructed; otherwise falls back to the REDIS_URL env var.

    Returns:
        The shared RedisCache instance.
    """
    global _redis_cache_instance

    if _redis_cache_instance is not None:
        return _redis_cache_instance

    _redis_cache_instance = RedisCache(redis_url=redis_url)
    return _redis_cache_instance
240
+