Rerank documents and force summary for policy makers
climateqa/engine/chains/retrieve_documents.py
CHANGED
(The removed side of this hunk rendered only partially; unrecoverable removed lines are elided with "...".)

@@ -57,107 +57,135 @@ def query_retriever(question):
     """Just a dummy tool to simulate the retriever query"""
     return question
 
-        ...
-        remaining_questions = state["remaining_questions"][1:]
-
-        # ToolMessage(f"Retrieving documents for question: {current_question['question']}",tool_call_id = "retriever")
-
-        # # There are several options to get the final top k
-        # # Option 1 - Get 100 documents by question and rerank by question
-        # # Option 2 - Get 100/n documents by question and rerank the total
-        # if rerank_by_question:
-        #     k_by_question = divide_into_parts(k_final,len(questions))
-        if "documents" in state and state["documents"] is not None:
-            docs = state["documents"]
-        else:
-            docs = []
-        ...
-        index = current_question["index"]
-
-        await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
-        ...
-            sources = sources,
-            min_size = 200,
-            k_summary = k_summary,
-            k_total = k_before_reranking,
-            threshold = 0.5,
-        )
-        docs_question = await retriever.ainvoke(question,config)
-        ...
-            keywords = keywords_extraction.invoke(question)["keywords"]
-            openalex_query = " AND ".join(keywords)
-        ...
-            retriever_openalex = OpenAlexRetriever(
-                min_year = state.get("min_year",1960),
-                max_year = state.get("max_year",None),
-                k = k_before_reranking
-            )
-            docs_question = await retriever_openalex.ainvoke(openalex_query,config)
-        else:
-            raise Exception(f"Index {index} not found in the routing index")
-
-        # Rerank
-        if reranker is not None:
-            with suppress_output():
-                docs_question = rerank_docs(reranker,docs_question,question)
-        else:
-            # Add a default reranking score
-            for doc in docs_question:
-                doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
-
-        # If rerank by question we select the top documents for each question
-        if rerank_by_question:
-            docs_question = docs_question[:k_by_question]
-
-        # Add sources used in the metadata
-        for doc in docs_question:
-            doc.metadata["sources_used"] = sources
-            doc.metadata["question_used"] = question
-            doc.metadata["index_used"] = index
-
-        # Add to the list of docs
-        docs.extend(docs_question)
-
-        # Sorting the list in descending order by rerank_score
-        docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
-        new_state = {"documents":docs,"remaining_questions":remaining_questions}
-        return new_state
-
-    return retrieve_documents
+def _add_sources_used_in_metadata(docs,sources,question,index):
+    for doc in docs:
+        doc.metadata["sources_used"] = sources
+        doc.metadata["question_used"] = question
+        doc.metadata["index_used"] = index
+    return docs
+
+def _get_k_summary_by_question(n_questions):
+    if n_questions == 0:
+        return 0
+    elif n_questions == 1:
+        return 5
+    elif n_questions == 2:
+        return 3
+    elif n_questions == 3:
+        return 2
+    else:
+        return 1
+
+# The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
+# @chain
+async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
+    print("---- Retrieve documents ----")
+
+    # Get the documents from the state
+    if "documents" in state and state["documents"] is not None:
+        docs = state["documents"]
+    else:
+        docs = []
+    # Get the related_content from the state
+    if "related_content" in state and state["related_content"] is not None:
+        related_content = state["related_content"]
+    else:
+        related_content = []
+
+    # Get the current question
+    current_question = state["remaining_questions"][0]
+    remaining_questions = state["remaining_questions"][1:]
+
+    k_by_question = k_final // state["n_questions"]
+    k_summary_by_question = _get_k_summary_by_question(state["n_questions"])
+
+    sources = current_question["sources"]
+    question = current_question["question"]
+    index = current_question["index"]
+
+    await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
+
+    if index == "Vector":
+        # Search the document store using the retriever
+        # Configure high top k for further reranking step
+        retriever = ClimateQARetriever(
+            vectorstore=vectorstore,
+            sources = sources,
+            min_size = 200,
+            k_summary = k_summary_by_question,
+            k_total = k_before_reranking,
+            threshold = 0.5,
+        )
+        docs_question_dict = await retriever.ainvoke(question,config)
+
+    # elif index == "OpenAlex":
+    #     # keyword extraction
+    #     keywords_extraction = make_keywords_extraction_chain(llm)
+    #     keywords = keywords_extraction.invoke(question)["keywords"]
+    #     openalex_query = " AND ".join(keywords)
+    #     print(f"... OpenAlex query: {openalex_query}")
+    #     retriever_openalex = OpenAlexRetriever(
+    #         min_year = state.get("min_year",1960),
+    #         max_year = state.get("max_year",None),
+    #         k = k_before_reranking
+    #     )
+    #     docs_question = await retriever_openalex.ainvoke(openalex_query,config)
+
+    # else:
+    #     raise Exception(f"Index {index} not found in the routing index")
+
+    # Rerank
+    if reranker is not None:
+        with suppress_output():
+            docs_question_summary_reranked = rerank_docs(reranker,docs_question_dict["docs_summaries"],question)
+            docs_question_fulltext_reranked = rerank_docs(reranker,docs_question_dict["docs_full"],question)
+            docs_question_images_reranked = rerank_docs(reranker,docs_question_dict["docs_images"],question)
+            if rerank_by_question:
+                docs_question_summary_reranked = sorted(docs_question_summary_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
+                docs_question_fulltext_reranked = sorted(docs_question_fulltext_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
+                docs_question_images_reranked = sorted(docs_question_images_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
+    else:
+        docs_question = docs_question_dict["docs_summaries"] + docs_question_dict["docs_full"]
+        # Add a default reranking score
+        for doc in docs_question:
+            doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
+
+    docs_question = docs_question_summary_reranked + docs_question_fulltext_reranked
+    docs_question = docs_question[:k_by_question]
+    images_question = docs_question_images_reranked[:k_by_question]
+
+    if reranker is not None and rerank_by_question:
+        docs_question = sorted(docs_question, key=lambda x: x.metadata["reranking_score"], reverse=True)
+
+    # Add sources used in the metadata
+    docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index)
+    images_question = _add_sources_used_in_metadata(images_question,sources,question,index)
+
+    # Add to the list of docs
+    docs.extend(docs_question)
+    related_content.extend(images_question)
+
+    new_state = {"documents":docs, "related_contents": related_content,"remaining_questions":remaining_questions}
+    return new_state
+
+
+def make_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
+    @chain
+    async def retrieve_docs(state, config):
+        state = await retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question, k_final, k_before_reranking, k_summary)
+        return state
+
+    return retrieve_docs
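How the new budgets play out: retrieve_documents now splits k_final evenly across questions with floor division, and _get_k_summary_by_question caps how many summary (SPM) chunks are requested per question. A quick sketch of the defaults (the helper is copied from the diff above; the loop is illustrative):

def _get_k_summary_by_question(n_questions):
    # Copied from the diff: the summary budget shrinks as questions multiply
    if n_questions == 0:
        return 0
    elif n_questions == 1:
        return 5
    elif n_questions == 2:
        return 3
    elif n_questions == 3:
        return 2
    else:
        return 1

k_final = 15  # default in retrieve_documents
for n_questions in range(1, 5):
    k_by_question = k_final // n_questions  # final docs kept per question
    print(n_questions, k_by_question, _get_k_summary_by_question(n_questions))
# 1 15 5
# 2 7 3
# 3 5 2
# 4 3 1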
climateqa/engine/chains/retriever.py
CHANGED
@@ -1,126 +1,126 @@
The whole module is commented out in this change: every line of the previous implementation now carries a leading "# ". The resulting file:

# import sys
# import os
# from contextlib import contextmanager

# from ..reranker import rerank_docs
# from ...knowledge.retriever import ClimateQARetriever


# def divide_into_parts(target, parts):
#     # Base value for each part
#     base = target // parts
#     # Remainder to distribute
#     remainder = target % parts
#     # List to hold the result
#     result = []

#     for i in range(parts):
#         if i < remainder:
#             # These parts get base value + 1
#             result.append(base + 1)
#         else:
#             # The rest get the base value
#             result.append(base)

#     return result


# @contextmanager
# def suppress_output():
#     # Open a null device
#     with open(os.devnull, 'w') as devnull:
#         # Store the original stdout and stderr
#         old_stdout = sys.stdout
#         old_stderr = sys.stderr
#         # Redirect stdout and stderr to the null device
#         sys.stdout = devnull
#         sys.stderr = devnull
#         try:
#             yield
#         finally:
#             # Restore stdout and stderr
#             sys.stdout = old_stdout
#             sys.stderr = old_stderr


# def make_retriever_node(vectorstore,reranker,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):

#     def retrieve_documents(state):

#         POSSIBLE_SOURCES = ["IPCC","IPBES","IPOS"] # ,"OpenAlex"]
#         questions = state["questions"]

#         # Use sources from the user input or from the LLM detection
#         if "sources_input" not in state or state["sources_input"] is None:
#             sources_input = ["auto"]
#         else:
#             sources_input = state["sources_input"]
#         auto_mode = "auto" in sources_input

#         # There are several options to get the final top k
#         # Option 1 - Get 100 documents by question and rerank by question
#         # Option 2 - Get 100/n documents by question and rerank the total
#         if rerank_by_question:
#             k_by_question = divide_into_parts(k_final,len(questions))

#         docs = []

#         for i,q in enumerate(questions):

#             sources = q["sources"]
#             question = q["question"]

#             # If auto mode, we use the sources detected by the LLM
#             if auto_mode:
#                 sources = [x for x in sources if x in POSSIBLE_SOURCES]

#             # Otherwise, we use the config
#             else:
#                 sources = sources_input

#             # Search the document store using the retriever
#             # Configure high top k for further reranking step
#             retriever = ClimateQARetriever(
#                 vectorstore=vectorstore,
#                 sources = sources,
#                 # reports = ias_reports,
#                 min_size = 200,
#                 k_summary = k_summary,
#                 k_total = k_before_reranking,
#                 threshold = 0.5,
#             )
#             docs_question = retriever.get_relevant_documents(question)

#             # Rerank
#             if reranker is not None:
#                 with suppress_output():
#                     docs_question = rerank_docs(reranker,docs_question,question)
#             else:
#                 # Add a default reranking score
#                 for doc in docs_question:
#                     doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

#             # If rerank by question we select the top documents for each question
#             if rerank_by_question:
#                 docs_question = docs_question[:k_by_question[i]]

#             # Add sources used in the metadata
#             for doc in docs_question:
#                 doc.metadata["sources_used"] = sources

#             # Add to the list of docs
#             docs.extend(docs_question)

#         # Sorting the list in descending order by rerank_score
#         # Then select the top k
#         docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
#         docs = docs[:k_final]

#         new_state = {"documents":docs}
#         return new_state

#     return retrieve_documents
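One behavioral change hidden in this deletion: the old divide_into_parts spread the remainder of k_final over the questions, while the new node in retrieve_documents.py uses plain floor division, so a few slots of the budget can go unused. A small illustration (the first function is a compact equivalent of the commented-out helper, not a verbatim copy):

def divide_into_parts(target, parts):
    # Compact equivalent of the now-commented helper: spread the remainder
    base, remainder = divmod(target, parts)
    return [base + 1 if i < remainder else base for i in range(parts)]

print(divide_into_parts(15, 4))  # [4, 4, 4, 3] -> sums to 15
print([15 // 4] * 4)             # [3, 3, 3, 3] -> sums to 12, three slots unused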
climateqa/engine/graph.py
CHANGED
@@ -40,6 +40,7 @@ class GraphState(TypedDict):
     min_year: int = 1960
     max_year: int = None
     documents: List[Document]
+    related_contents : Dict[str,Document]
     recommended_content : List[Document]
     # graphs_returned: Dict[str,str]
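A note for readers tracing the state: retrieve_documents reads incoming images from state["related_content"] (no trailing "s") but returns them under "related_contents", the key declared here, and it accumulates a plain list even though this field is annotated Dict[str,Document].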
climateqa/knowledge/retriever.py
CHANGED
@@ -11,6 +11,18 @@ from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
 from typing import List
 from pydantic import Field
 
+def _add_metadata_and_score(docs: List) -> Document:
+    # Add score to metadata
+    docs_with_metadata = []
+    for i,(doc,score) in enumerate(docs):
+        doc.page_content = doc.page_content.replace("\r\n"," ")
+        doc.metadata["similarity_score"] = score
+        doc.metadata["content"] = doc.page_content
+        doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
+        # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
+        docs_with_metadata.append(doc)
+    return docs_with_metadata
+
 class ClimateQARetriever(BaseRetriever):
     vectorstore:VectorStore
     sources:list = ["IPCC","IPBES","IPOS"]

@@ -20,6 +32,7 @@ class ClimateQARetriever(BaseRetriever):
     k_total:int = 10
     namespace:str = "vectors",
     min_size:int = 200,
+
 
     def _get_relevant_documents(

@@ -43,6 +56,7 @@ class ClimateQARetriever(BaseRetriever):
         # Search for k_summary documents in the summaries dataset
         filters_summaries = {
             **filters,
+            "chunk_type":"text",
             "report_type": { "$in":["SPM"]},
         }

@@ -52,31 +66,36 @@ class ClimateQARetriever(BaseRetriever):
         # Search for k_total - k_summary documents in the full reports dataset
         filters_full = {
             **filters,
+            "chunk_type":"text",
             "report_type": { "$nin":["SPM"]},
         }
         k_full = self.k_total - len(docs_summaries)
         docs_full = self.vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
+
+        # Images
+        filters_image = {
+            **filters,
+            "chunk_type":"image"
+        }
+        docs_images = self.vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_full)
 
         # Concatenate documents
-        docs = docs_summaries + docs_full
+        docs = docs_summaries + docs_full + docs_images
 
         # Filter if scores are below threshold
         docs = [x for x in docs if len(x[0].page_content) > self.min_size]
         # docs = [x for x in docs if x[1] > self.threshold]
 
-        ...
-            doc.metadata["content"] = doc.page_content
-            doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
-            # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
-            results.append(doc)
-        ...
+        docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
+
+        # Filter if length are below threshold
+        docs_summaries = [x for x in docs_summaries if len(x.page_content) > self.min_size]
+        docs_full = [x for x in docs_full if len(x.page_content) > self.min_size]
 
+        return {
+            "docs_summaries" : docs_summaries,
+            "docs_full" : docs_full,
+            "docs_images" : docs_images
+        }

(The removed side of the last hunk rendered only partially; unrecoverable removed lines are elided with "...".)
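With this change _get_relevant_documents returns three buckets instead of one flat list, and the node in retrieve_documents.py reranks each bucket separately. A minimal sketch of the no-reranker fallback path consuming that shape (the Document construction and scores are illustrative; the metadata handling mirrors the diff):

from langchain_core.documents import Document

docs_question_dict = {
    "docs_summaries": [Document(page_content="SPM chunk", metadata={"similarity_score": 0.9})],
    "docs_full": [Document(page_content="full-report chunk", metadata={"similarity_score": 0.7})],
    "docs_images": [Document(page_content="image caption", metadata={"similarity_score": 0.5})],
}

# Fallback branch from retrieve_documents: without a reranker, the similarity
# score doubles as the reranking score; image chunks stay in their own bucket.
docs_question = docs_question_dict["docs_summaries"] + docs_question_dict["docs_full"]
for doc in docs_question:
    doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

k_by_question = 15 // 1  # k_final // n_questions for a single question
docs_question = docs_question[:k_by_question]
print([d.metadata["reranking_score"] for d in docs_question])  # [0.9, 0.7]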
sandbox/20241104 - CQA - StepByStep CQA.ipynb
CHANGED
The diff for this file is too large to render. See raw diff.