qminh369 commited on
Commit
a29f0e3
·
verified ·
1 Parent(s): 3c03c76

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +213 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pymilvus import MilvusClient, connections
2
+ from pymilvus import Collection, DataType, FieldSchema, CollectionSchema, utility
3
+
4
+ from sentence_transformers import SentenceTransformer
5
+ import re
6
+
7
+ import gradio as gr
8
+
9
+ import requests
10
+ import json
11
+
12
def connect_milvus(endpoint, token):
    """Open the default Milvus connection and build a MilvusClient.

    Args:
        endpoint: Cluster URI (e.g. a Zilliz Cloud endpoint).
        token: API token/credentials for the cluster.

    Returns:
        The constructed MilvusClient (callers may ignore it; the module-level
        code only relies on the "default" connection being registered).
    """
    # Bug fix: the original ignored its parameters and read the module-level
    # globals CLUSTER_ENDPOINT / TOKEN instead, making the arguments dead.
    connections.connect(
        "default",
        uri=endpoint,
        token=token,
    )

    client = MilvusClient(
        uri=endpoint,
        token=token,
    )
    return client
25
def create_vectordb(COLLECTION_NAME, EMBEDDING_DIM=1024, entities=None):
    """(Re)create the Milvus collection for law-document chunks.

    Drops any existing collection with the same name, declares the schema
    (chunk_id, chunk_ref, chunk_text, chunk_embedding), optionally inserts
    data, then builds an IVF_FLAT/COSINE index and loads the collection.

    Args:
        COLLECTION_NAME: Name of the collection to (re)create.
        EMBEDDING_DIM: Dimensionality of the chunk embeddings (default 1024,
            matching multilingual-e5-large).
        entities: Optional column-ordered insert data
            [ids, refs, texts, embeddings]. The original referenced the
            undefined names `ids, rules, chunks, embeddings` here, which
            raised NameError; passing the data explicitly fixes that.

    Returns:
        The loaded Collection handle.
    """
    # Bug fix: the original overwrote the COLLECTION_NAME parameter with a
    # hard-coded "van_ban_phap_luat", making the parameter useless.
    if utility.has_collection(COLLECTION_NAME):
        utility.drop_collection(COLLECTION_NAME)
        print('drop completed!')

    chunk_id = FieldSchema(
        name="chunk_id",
        dtype=DataType.INT64,
        is_primary=True,
        description="Chunk ID",
    )
    chunk_ref = FieldSchema(
        name="chunk_ref",
        dtype=DataType.VARCHAR,
        max_length=512,
        description="Chunk ref",
    )
    chunk_text = FieldSchema(
        name="chunk_text",
        dtype=DataType.VARCHAR,
        max_length=4096,
        description="Chunk text",
    )
    chunk_embedding = FieldSchema(
        name="chunk_embedding",
        dtype=DataType.FLOAT_VECTOR,
        dim=EMBEDDING_DIM,
        description="Chunk Embedding",
    )

    schema = CollectionSchema(
        fields=[chunk_id, chunk_ref, chunk_text, chunk_embedding],
        auto_id=False,  # chunk_id values are supplied by the caller
        description="Vector Store Chunk using multilingual-e5-large",
    )

    collection = Collection(
        name=COLLECTION_NAME,
        schema=schema,
    )

    if entities is not None:
        collection.insert(entities)
        collection.flush()  # make the inserted rows durable before indexing

    index_params = {
        "index_type": "IVF_FLAT",
        "metric_type": "COSINE",  # must match the search-time metric
        "params": {},
    }
    collection.create_index(
        field_name=chunk_embedding.name,
        index_params=index_params,
    )

    collection.load()
    return collection
87
def load_vectordb(COLLECTION_NAME):
    """Return a handle to the existing Milvus collection named COLLECTION_NAME.

    Assumes the "default" Milvus connection has already been established
    (see connect_milvus) and the collection was created previously.
    """
    return Collection(name=COLLECTION_NAME)
97
def load_model(model_name):
    """Instantiate and return a SentenceTransformer embedding model by name."""
    return SentenceTransformer(model_name)
102
def search_chunks(query, topk=5):
    """Embed `query` and retrieve the top-k most similar chunks from Milvus.

    Args:
        query: Natural-language query string.
        topk: Number of hits to return (default 5).

    Returns:
        (refs, relevant_chunks): parallel lists with the 'chunk_ref' and
        'chunk_text' fields of the retrieved hits, best match first.
    """
    search_params = {
        "metric_type": "COSINE",  # must match the index metric
        "params": {"level": 2},
    }

    collection = load_vectordb(COLLECTION_NAME)

    embed_query = model.encode(query)
    results = collection.search(
        [embed_query],
        anns_field="chunk_embedding",
        param=search_params,
        limit=topk,
        guarantee_timestamp=1,
        output_fields=['chunk_ref', 'chunk_text'],
    )

    refs, relevant_chunks = [], []

    # Bug fix: the original regex-parsed str(hit) with
    # r"'chunk_ref': '([^']*)', 'chunk_text': '([^']*)'", which silently
    # drops any hit whose text contains a quote character. Read the output
    # fields through the pymilvus Hit.entity API instead.
    for hit in results[0]:
        refs.append(hit.entity.get('chunk_ref'))
        relevant_chunks.append(hit.entity.get('chunk_text'))

    return refs, relevant_chunks
144
def response_saola2m(PROMPT, temperature=0.7):
    """Call the FPT.AI SaoLa2M completion endpoint and return the generated text.

    Args:
        PROMPT: Full prompt string to send to the model.
        temperature: Sampling temperature (default 0.7).

    Returns:
        The text of the first completion choice.

    Raises:
        requests.HTTPError: If the API responds with a non-2xx status.
        KeyError/IndexError: If the response JSON lacks the expected
            'choices[0].text' shape.
    """
    url = "https://api.fpt.ai/nlp/llm/api/v1/completions"
    # SECURITY NOTE(review): the bearer token is hard-coded in source; it
    # should be moved to an environment variable / secret store and this
    # key rotated before the file is published.
    headers = {
        'Authorization': 'Bearer sk-8oIY6XLrokZEJMl6aopCuQ',
        'Content-Type': 'application/json',
    }

    data = {
        "model": "SaoLa2M-instruct",
        "prompt": PROMPT,
        "temperature": temperature,
        "max_tokens": 512,
    }

    response = requests.post(url, headers=headers, json=data)
    # Fail loudly on HTTP errors instead of crashing later with a confusing
    # JSONDecodeError/KeyError on an error body.
    response.raise_for_status()
    response_json = response.json()
    return response_json['choices'][0]['text']
170
# --- Module configuration and import-time side effects ---
# SECURITY NOTE(review): the Zilliz endpoint and API token are hard-coded in
# source; they should be loaded from environment variables and the token
# rotated before publishing this file.
CLUSTER_ENDPOINT = "https://in03-d63abe0e8a8f47b.api.gcp-us-west1.zillizcloud.com"
TOKEN = "6529d1f59d5e3d38d6135ec9ddf5820a9c38e3db6ca22c53b3aa2ad9c9148e29ef0f11c312bee71f1544da350f85320b598a30f3"
COLLECTION_NAME = "van_ban_phap_luat"  # Vietnamese legal-documents collection
MODEL_NAME = 'qminh369/datn-dense_embedding'  # sentence-transformers embedding model
# Sample Vietnamese queries kept for manual testing; the trailing number is
# presumably the expected chunk id — TODO confirm.
#query = "Cơ quan quản lý giáo dục và đào tạo phải làm gì để đảm bảo pháp luật về giao thông đường bộ được đưa vào chương trình giảng dạy?" # 7
#query = "Trách nhiệm và quản lý hoạt động giao thông đường bộ được phân công và phân cấp như thế nào?" # 4
#query = "Những trường hợp cho phép chở người trên xe ô tô chở hàng?" # 21

# Connect to Milvus and load the embedding model once at import time.
connect_milvus(CLUSTER_ENDPOINT, TOKEN)
#load_vectordb(COLLECTION_NAME)
model = load_model(MODEL_NAME)
182
def answer(question):
    """Answer `question` via RAG: retrieve chunks, prompt the LLM, cite source.

    Args:
        question: User question (Vietnamese).

    Returns:
        The model's answer, stripped, with a citation line for the top
        retrieved chunk appended when any chunk was found.
    """
    refs, relevant_chunks = search_chunks(question)

    INSTRUCTION = "Hãy trả lời câu hỏi sau dựa trên thông tin được cung cấp. Nếu thông tin được cung cấp không liên quan dến câu hỏi thì trả về câu trả lời 'Không có thông tin'"
    INPUT_TEXT = "\n".join(relevant_chunks)

    PROMPT = f"<s>[INST] Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{INSTRUCTION}\n\n{question}\n\n### Input:\n{INPUT_TEXT}\n\n[/INST]### Response:\n"
    print(PROMPT)  # debug: log the full prompt sent to the LLM

    response = response_saola2m(PROMPT, temperature=0.7)

    # Bug fix: the original indexed refs[0] unconditionally, raising
    # IndexError whenever retrieval returned no hits.
    if refs:
        response = response + "\n" + "Trích dẫn từ: " + refs[0]

    return response.strip()
200
def chatbot(question, history=None):
    """Gradio callback: answer `question` and append the turn to chat history.

    Args:
        question: Latest user message.
        history: List of (question, answer) tuples; a fresh list is created
            when omitted or None.

    Returns:
        (history, history): one copy for the 'chatbot' display output and one
        for the 'state' output.
    """
    # Bug fix: the original used a mutable default (history=[]), which is a
    # single list shared across calls and leaks turns between sessions.
    if history is None:
        history = []
    output = answer(question)
    history.append((question, output))
    return history, history
205
# Build the Gradio UI: a text box plus hidden conversation state in, the
# chatbot display plus updated state out.
demo = gr.Interface(
    fn=chatbot,
    inputs=["text", "state"],
    outputs=["chatbot", "state"])

# share=True opens a public gradio.live tunnel to this app.
demo.queue().launch(share=True)
211
+
212
+
213
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ sentence-transformers
+ elasticsearch
+ grpcio==1.60.0
+ pymilvus
+ gradio
+ requests