fix figures retrieval
Browse files- app.py +38 -16
- climateqa/engine/chains/retrieve_documents.py +6 -3
- climateqa/engine/graph.py +2 -2
- front/utils.py +20 -16
app.py
CHANGED
|
@@ -113,7 +113,7 @@ vectorstore = get_pinecone_vectorstore(embeddings_function, index_name = os.gete
|
|
| 113 |
vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name = os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
|
| 114 |
|
| 115 |
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
|
| 116 |
-
reranker = get_reranker("
|
| 117 |
|
| 118 |
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
|
| 119 |
|
|
@@ -142,7 +142,6 @@ async def chat(query, history, audience, sources, reports, relevant_content_sour
|
|
| 142 |
|
| 143 |
|
| 144 |
docs = []
|
| 145 |
-
used_figures=[]
|
| 146 |
related_contents = []
|
| 147 |
docs_html = ""
|
| 148 |
output_query = ""
|
|
@@ -165,7 +164,7 @@ async def chat(query, history, audience, sources, reports, relevant_content_sour
|
|
| 165 |
if "langgraph_node" in event["metadata"]:
|
| 166 |
node = event["metadata"]["langgraph_node"]
|
| 167 |
|
| 168 |
-
if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" :# when documents are retrieved
|
| 169 |
docs, docs_html, history, used_documents, related_contents = handle_retrieved_documents(event, history, used_documents)
|
| 170 |
|
| 171 |
elif event["event"] == "on_chain_end" and node == "categorize_intent" and event["name"] == "_write": # when the query is transformed
|
|
@@ -321,10 +320,19 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
| 321 |
|
| 322 |
|
| 323 |
with gr.Row(elem_id = "input-message"):
|
| 324 |
-
textbox=gr.Textbox(
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
|
| 329 |
|
| 330 |
|
|
@@ -417,7 +425,9 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
| 417 |
with gr.Tabs(elem_id = "group-subtabs") as tabs_recommended_content:
|
| 418 |
|
| 419 |
with gr.Tab("Figures",elem_id = "tab-figures",id = 3) as tab_figures:
|
| 420 |
-
sources_raw = gr.State()
|
|
|
|
|
|
|
| 421 |
|
| 422 |
with Modal(visible=False, elem_id="modal_figure_galery") as figure_modal:
|
| 423 |
gallery_component = gr.Gallery(object_fit='scale-down',elem_id="gallery-component", height="80vh")
|
|
@@ -475,9 +485,9 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
| 475 |
)
|
| 476 |
|
| 477 |
dropdown_external_sources = gr.CheckboxGroup(
|
| 478 |
-
["IPCC
|
| 479 |
label="Select database to search for relevant content",
|
| 480 |
-
value=["IPCC
|
| 481 |
interactive=True,
|
| 482 |
)
|
| 483 |
|
|
@@ -633,15 +643,25 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
| 633 |
return gr.update(label = recommended_content_notif_label), gr.update(label = sources_notif_label), gr.update(label = figures_notif_label), gr.update(label = graphs_notif_label), gr.update(label = papers_notif_label)
|
| 634 |
|
| 635 |
(textbox
|
| 636 |
-
.submit(start_chat, [textbox,chatbot, search_only],
|
| 637 |
-
|
| 638 |
-
|
| 639 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 640 |
)
|
| 641 |
|
|
|
|
|
|
|
| 642 |
(examples_hidden
|
| 643 |
.change(start_chat, [examples_hidden,chatbot, search_only], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
|
| 644 |
-
.then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, dropdown_external_sources, search_only] ,[chatbot,sources_textbox,output_query,output_language,
|
| 645 |
.then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
|
| 646 |
# .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_sources, tab_figures, tab_graphs, tab_papers] )
|
| 647 |
)
|
|
@@ -654,7 +674,7 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
| 654 |
return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
|
| 655 |
|
| 656 |
|
| 657 |
-
|
| 658 |
|
| 659 |
# update sources numbers
|
| 660 |
sources_textbox.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
|
|
@@ -674,4 +694,6 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
| 674 |
|
| 675 |
demo.queue()
|
| 676 |
|
|
|
|
|
|
|
| 677 |
demo.launch(ssr_mode=False)
|
|
|
|
| 113 |
vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name = os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
|
| 114 |
|
| 115 |
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
|
| 116 |
+
reranker = get_reranker("large")
|
| 117 |
|
| 118 |
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
|
| 119 |
|
|
|
|
| 142 |
|
| 143 |
|
| 144 |
docs = []
|
|
|
|
| 145 |
related_contents = []
|
| 146 |
docs_html = ""
|
| 147 |
output_query = ""
|
|
|
|
| 164 |
if "langgraph_node" in event["metadata"]:
|
| 165 |
node = event["metadata"]["langgraph_node"]
|
| 166 |
|
| 167 |
+
if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" and event["data"]["output"] != None:# when documents are retrieved
|
| 168 |
docs, docs_html, history, used_documents, related_contents = handle_retrieved_documents(event, history, used_documents)
|
| 169 |
|
| 170 |
elif event["event"] == "on_chain_end" and node == "categorize_intent" and event["name"] == "_write": # when the query is transformed
|
|
|
|
| 320 |
|
| 321 |
|
| 322 |
with gr.Row(elem_id = "input-message"):
|
| 323 |
+
textbox = gr.Textbox(
|
| 324 |
+
placeholder="Ask me anything here!",
|
| 325 |
+
show_label=False,
|
| 326 |
+
scale=12,
|
| 327 |
+
lines=1,
|
| 328 |
+
interactive=True,
|
| 329 |
+
elem_id="input-textbox"
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
config_button = gr.Button(
|
| 333 |
+
"",
|
| 334 |
+
elem_id="config-button"
|
| 335 |
+
)
|
| 336 |
|
| 337 |
|
| 338 |
|
|
|
|
| 425 |
with gr.Tabs(elem_id = "group-subtabs") as tabs_recommended_content:
|
| 426 |
|
| 427 |
with gr.Tab("Figures",elem_id = "tab-figures",id = 3) as tab_figures:
|
| 428 |
+
sources_raw = gr.State([])
|
| 429 |
+
new_figures = gr.State([])
|
| 430 |
+
used_figures = gr.State([])
|
| 431 |
|
| 432 |
with Modal(visible=False, elem_id="modal_figure_galery") as figure_modal:
|
| 433 |
gallery_component = gr.Gallery(object_fit='scale-down',elem_id="gallery-component", height="80vh")
|
|
|
|
| 485 |
)
|
| 486 |
|
| 487 |
dropdown_external_sources = gr.CheckboxGroup(
|
| 488 |
+
["Figures (IPCC/IPBES)","Papers (OpenAlex)", "Graphs (OurWorldInData)"],
|
| 489 |
label="Select database to search for relevant content",
|
| 490 |
+
value=["Figures (IPCC/IPBES)"],
|
| 491 |
interactive=True,
|
| 492 |
)
|
| 493 |
|
|
|
|
| 643 |
return gr.update(label = recommended_content_notif_label), gr.update(label = sources_notif_label), gr.update(label = figures_notif_label), gr.update(label = graphs_notif_label), gr.update(label = papers_notif_label)
|
| 644 |
|
| 645 |
(textbox
|
| 646 |
+
.submit(start_chat, [textbox, chatbot, search_only],
|
| 647 |
+
[textbox, tabs, chatbot],
|
| 648 |
+
queue=False,
|
| 649 |
+
api_name="start_chat_textbox")
|
| 650 |
+
.then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources,
|
| 651 |
+
dropdown_reports, dropdown_external_sources, search_only],
|
| 652 |
+
[chatbot, sources_textbox, output_query, output_language,
|
| 653 |
+
new_figures, current_graphs],
|
| 654 |
+
concurrency_limit=8,
|
| 655 |
+
api_name="chat_textbox")
|
| 656 |
+
.then(finish_chat, None, [textbox],
|
| 657 |
+
api_name="finish_chat_textbox")
|
| 658 |
)
|
| 659 |
|
| 660 |
+
|
| 661 |
+
|
| 662 |
(examples_hidden
|
| 663 |
.change(start_chat, [examples_hidden,chatbot, search_only], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
|
| 664 |
+
.then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, dropdown_external_sources, search_only] ,[chatbot,sources_textbox,output_query,output_language, new_figures, current_graphs],concurrency_limit = 8,api_name = "chat_textbox")
|
| 665 |
.then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
|
| 666 |
# .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_sources, tab_figures, tab_graphs, tab_papers] )
|
| 667 |
)
|
|
|
|
| 674 |
return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
|
| 675 |
|
| 676 |
|
| 677 |
+
new_figures.change(process_figures, inputs=[sources_raw, new_figures], outputs=[sources_raw, figures_cards, gallery_component])
|
| 678 |
|
| 679 |
# update sources numbers
|
| 680 |
sources_textbox.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
|
|
|
|
| 694 |
|
| 695 |
demo.queue()
|
| 696 |
|
| 697 |
+
|
| 698 |
+
|
| 699 |
demo.launch(ssr_mode=False)
|
climateqa/engine/chains/retrieve_documents.py
CHANGED
|
@@ -87,7 +87,7 @@ def _get_k_images_by_question(n_questions):
|
|
| 87 |
elif n_questions == 2:
|
| 88 |
return 5
|
| 89 |
elif n_questions == 3:
|
| 90 |
-
return
|
| 91 |
else:
|
| 92 |
return 1
|
| 93 |
|
|
@@ -98,7 +98,10 @@ def _add_metadata_and_score(docs: List) -> Document:
|
|
| 98 |
doc.page_content = doc.page_content.replace("\r\n"," ")
|
| 99 |
doc.metadata["similarity_score"] = score
|
| 100 |
doc.metadata["content"] = doc.page_content
|
| 101 |
-
doc.metadata["page_number"]
|
|
|
|
|
|
|
|
|
|
| 102 |
# doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
|
| 103 |
docs_with_metadata.append(doc)
|
| 104 |
return docs_with_metadata
|
|
@@ -222,7 +225,7 @@ async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_qu
|
|
| 222 |
else:
|
| 223 |
related_content = []
|
| 224 |
|
| 225 |
-
search_figures = "IPCC
|
| 226 |
search_only = state["search_only"]
|
| 227 |
|
| 228 |
# Get the current question
|
|
|
|
| 87 |
elif n_questions == 2:
|
| 88 |
return 5
|
| 89 |
elif n_questions == 3:
|
| 90 |
+
return 3
|
| 91 |
else:
|
| 92 |
return 1
|
| 93 |
|
|
|
|
| 98 |
doc.page_content = doc.page_content.replace("\r\n"," ")
|
| 99 |
doc.metadata["similarity_score"] = score
|
| 100 |
doc.metadata["content"] = doc.page_content
|
| 101 |
+
if doc.metadata["page_number"] != "N/A":
|
| 102 |
+
doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
|
| 103 |
+
else:
|
| 104 |
+
doc.metadata["page_number"] = 1
|
| 105 |
# doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
|
| 106 |
docs_with_metadata.append(doc)
|
| 107 |
return docs_with_metadata
|
|
|
|
| 225 |
else:
|
| 226 |
related_content = []
|
| 227 |
|
| 228 |
+
search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources"]
|
| 229 |
search_only = state["search_only"]
|
| 230 |
|
| 231 |
# Get the current question
|
climateqa/engine/graph.py
CHANGED
|
@@ -36,7 +36,7 @@ class GraphState(TypedDict):
|
|
| 36 |
answer: str
|
| 37 |
audience: str = "experts"
|
| 38 |
sources_input: List[str] = ["IPCC","IPBES"]
|
| 39 |
-
relevant_content_sources: List[str] = ["IPCC
|
| 40 |
sources_auto: bool = True
|
| 41 |
min_year: int = 1960
|
| 42 |
max_year: int = None
|
|
@@ -82,7 +82,7 @@ def route_based_on_relevant_docs(state,threshold_docs=0.2):
|
|
| 82 |
return "answer_rag_no_docs"
|
| 83 |
|
| 84 |
def route_retrieve_documents(state):
|
| 85 |
-
if state["search_only"] :
|
| 86 |
return END
|
| 87 |
elif len(state["remaining_questions"]) > 0:
|
| 88 |
return "retrieve_documents"
|
|
|
|
| 36 |
answer: str
|
| 37 |
audience: str = "experts"
|
| 38 |
sources_input: List[str] = ["IPCC","IPBES"]
|
| 39 |
+
relevant_content_sources: List[str] = ["Figures (IPCC/IPBES)"]
|
| 40 |
sources_auto: bool = True
|
| 41 |
min_year: int = 1960
|
| 42 |
max_year: int = None
|
|
|
|
| 82 |
return "answer_rag_no_docs"
|
| 83 |
|
| 84 |
def route_retrieve_documents(state):
|
| 85 |
+
if len(state["remaining_questions"]) == 0 and state["search_only"] :
|
| 86 |
return END
|
| 87 |
elif len(state["remaining_questions"]) > 0:
|
| 88 |
return "retrieve_documents"
|
front/utils.py
CHANGED
|
@@ -39,25 +39,29 @@ def parse_output_llm_with_sources(output:str)->str:
|
|
| 39 |
content_parts = "".join(parts)
|
| 40 |
return content_parts
|
| 41 |
|
| 42 |
-
def process_figures(docs:list)->tuple:
|
| 43 |
-
|
| 44 |
-
|
| 45 |
figures = '<div class="figures-container"><p></p> </div>'
|
|
|
|
|
|
|
|
|
|
| 46 |
if docs == []:
|
| 47 |
-
return figures, gallery
|
|
|
|
|
|
|
| 48 |
docs_figures = [d for d in docs if d.metadata["chunk_type"] == "image"]
|
| 49 |
-
for
|
| 50 |
-
if doc.metadata["chunk_type"] == "image":
|
| 51 |
-
|
| 52 |
-
title = f"{doc.metadata['figure_code']} - {doc.metadata['short_name']}"
|
| 53 |
-
else:
|
| 54 |
-
title = f"{doc.metadata['short_name']}"
|
| 55 |
|
| 56 |
|
| 57 |
-
if
|
| 58 |
-
used_figures.append(
|
|
|
|
|
|
|
| 59 |
try:
|
| 60 |
-
key = f"Image {
|
| 61 |
|
| 62 |
image_path = doc.metadata["image_path"].split("documents/")[1]
|
| 63 |
img = get_image_from_azure_blob_storage(image_path)
|
|
@@ -70,12 +74,12 @@ def process_figures(docs:list)->tuple:
|
|
| 70 |
|
| 71 |
img_str = base64.b64encode(buffered.getvalue()).decode()
|
| 72 |
|
| 73 |
-
figures = figures + make_html_figure_sources(doc,
|
| 74 |
gallery.append(img)
|
| 75 |
except Exception as e:
|
| 76 |
-
print(f"Skipped adding image {
|
| 77 |
|
| 78 |
-
return figures, gallery
|
| 79 |
|
| 80 |
|
| 81 |
def generate_html_graphs(graphs:list)->str:
|
|
|
|
| 39 |
content_parts = "".join(parts)
|
| 40 |
return content_parts
|
| 41 |
|
| 42 |
+
def process_figures(docs:list, new_figures:list)->tuple:
|
| 43 |
+
docs = docs + new_figures
|
| 44 |
+
|
| 45 |
figures = '<div class="figures-container"><p></p> </div>'
|
| 46 |
+
gallery = []
|
| 47 |
+
used_figures = []
|
| 48 |
+
|
| 49 |
if docs == []:
|
| 50 |
+
return figures, gallery, used_figures
|
| 51 |
+
|
| 52 |
+
|
| 53 |
docs_figures = [d for d in docs if d.metadata["chunk_type"] == "image"]
|
| 54 |
+
for i_doc, doc in enumerate(docs_figures):
|
| 55 |
+
if doc.metadata["chunk_type"] == "image":
|
| 56 |
+
path = doc.metadata["image_path"]
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
|
| 59 |
+
if path not in used_figures:
|
| 60 |
+
used_figures.append(path)
|
| 61 |
+
figure_number = len(used_figures)
|
| 62 |
+
|
| 63 |
try:
|
| 64 |
+
key = f"Image {figure_number}"
|
| 65 |
|
| 66 |
image_path = doc.metadata["image_path"].split("documents/")[1]
|
| 67 |
img = get_image_from_azure_blob_storage(image_path)
|
|
|
|
| 74 |
|
| 75 |
img_str = base64.b64encode(buffered.getvalue()).decode()
|
| 76 |
|
| 77 |
+
figures = figures + make_html_figure_sources(doc, figure_number, img_str)
|
| 78 |
gallery.append(img)
|
| 79 |
except Exception as e:
|
| 80 |
+
print(f"Skipped adding image {figure_number} because of {e}")
|
| 81 |
|
| 82 |
+
return docs, figures, gallery
|
| 83 |
|
| 84 |
|
| 85 |
def generate_html_graphs(graphs:list)->str:
|