move papers code into a separate file
- app.py +8 -80
- climateqa/engine/chains/retrieve_papers.py +95 -0
- climateqa/engine/keywords.py +3 -1
- climateqa/knowledge/openalex.py +4 -0
app.py
CHANGED
@@ -1,11 +1,9 @@
 from climateqa.engine.embeddings import get_embeddings_function
 embeddings_function = get_embeddings_function()
 
-from climateqa.knowledge.openalex import OpenAlex
 from sentence_transformers import CrossEncoder
 
 # reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
-oa = OpenAlex()
 
 import gradio as gr
 from gradio_modal import Modal
@@ -44,10 +42,9 @@ from climateqa.engine.chains.prompts import audience_prompts
 from climateqa.sample_questions import QUESTIONS
 from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES
 from climateqa.utils import get_image_from_azure_blob_storage
-from climateqa.engine.keywords import make_keywords_chain
-from climateqa.engine.chains.answer_rag import make_rag_papers_chain
 from climateqa.engine.graph import make_graph_agent
 from climateqa.engine.embeddings import get_embeddings_function
+from climateqa.engine.chains.retrieve_papers import find_papers
 
 from front.utils import serialize_docs,process_figures,make_html_df
 
@@ -249,84 +246,9 @@ def log_on_azure(file, logs, share_client):
     file_client.upload_file(logs)
 
 
-def generate_keywords(query):
-    chain = make_keywords_chain(llm)
-    keywords = chain.invoke(query)
-    keywords = " AND ".join(keywords["keywords"])
-    return keywords
 
 
 
-papers_cols_widths = {
-    "id":100,
-    "title":300,
-    "doi":100,
-    "publication_year":100,
-    "abstract":500,
-    "is_oa":50,
-}
-
-papers_cols = list(papers_cols_widths.keys())
-papers_cols_widths = list(papers_cols_widths.values())
-
-
-async def find_papers(query,after, relevant_content_sources):
-    if "OpenAlex" in relevant_content_sources:
-        summary = ""
-        keywords = generate_keywords(query)
-        df_works = oa.search(keywords,after = after)
-        df_works = df_works.dropna(subset=["abstract"])
-        df_works = oa.rerank(query,df_works,reranker)
-        df_works = df_works.sort_values("rerank_score",ascending=False)
-        docs_html = []
-        for i in range(10):
-            docs_html.append(make_html_df(df_works, i))
-        docs_html = "".join(docs_html)
-        print(docs_html)
-        G = oa.make_network(df_works)
-
-        height = "750px"
-        network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
-        network_html = network.generate_html()
-
-        network_html = network_html.replace("'", "\"")
-        css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
-        network_html = network_html + css_to_inject
-
-
-        network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
-        display-capture; encrypted-media;" sandbox="allow-modals allow-forms
-        allow-scripts allow-same-origin allow-popups
-        allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
-        allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
-
-
-        docs = df_works["content"].head(10).tolist()
-
-        df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
-        df_works["doc"] = df_works["doc"] + 1
-        df_works = df_works[papers_cols]
-
-        yield docs_html, network_html, summary
-
-        chain = make_rag_papers_chain(llm)
-        result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
-        path_answer = "/logs/StrOutputParser/streamed_output/-"
-
-        async for op in result:
-
-            op = op.ops[0]
-
-            if op['path'] == path_answer: # reformulated question
-                new_token = op['value'] # str
-                summary += new_token
-            else:
-                continue
-            yield docs_html, network_html, summary
-    else :
-        yield "","", ""
-
-
 # --------------------------------------------------------------------
 # Gradio
 # --------------------------------------------------------------------
@@ -430,7 +352,10 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
     with gr.Tab("Configuration", id = 10, ) as tab_config:
         gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!")
 
+
+
         with gr.Row():
+
             dropdown_sources = gr.CheckboxGroup(
                 ["IPCC", "IPBES","IPOS"],
                 label="Select source",
@@ -443,7 +368,7 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
                 value=["IPCC figures"],
                 interactive=True,
             )
-
+
             dropdown_reports = gr.Dropdown(
                 POSSIBLE_REPORTS,
                 label="Or select specific reports",
@@ -452,6 +377,9 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
                 interactive=True,
            )
 
+            search_only = gr.Checkbox(label="Search only without chatting", value=False, interactive=True, elem_id="checkbox-chat")
+
+
             dropdown_audience = gr.Dropdown(
                 ["Children","General public","Experts"],
                 label="Select audience",
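Since find_papers is an async generator yielding (docs_html, network_html, summary) tuples, it can be wired straight into Gradio's streaming event system. The sketch below is illustrative, not part of this commit: the component names and the after value are assumptions.

import gradio as gr

from climateqa.engine.chains.retrieve_papers import find_papers

async def stream_papers(query):
    # Gradio re-renders the three outputs on every yield from find_papers.
    async for docs_html, network_html, summary in find_papers(
        query, after=2000, relevant_content_sources=["OpenAlex"]
    ):
        yield docs_html, network_html, summary

with gr.Blocks() as demo:
    question = gr.Textbox(label="Question")
    papers_html = gr.HTML()        # list of reranked papers
    citations_network = gr.HTML()  # citation network iframe
    papers_summary = gr.Markdown() # streamed RAG summary
    question.submit(stream_papers, inputs=[question],
                    outputs=[papers_html, citations_network, papers_summary])

demo.launch()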
climateqa/engine/chains/retrieve_papers.py
ADDED
@@ -0,0 +1,95 @@
+from climateqa.engine.keywords import make_keywords_chain
+from climateqa.engine.llm import get_llm
+from climateqa.knowledge.openalex import OpenAlex
+from climateqa.engine.chains.answer_rag import make_rag_papers_chain
+from front.utils import make_html_df
+from climateqa.engine.reranker import get_reranker
+
+oa = OpenAlex()
+
+llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
+reranker = get_reranker("nano")
+
+
+papers_cols_widths = {
+    "id":100,
+    "title":300,
+    "doi":100,
+    "publication_year":100,
+    "abstract":500,
+    "is_oa":50,
+}
+
+papers_cols = list(papers_cols_widths.keys())
+papers_cols_widths = list(papers_cols_widths.values())
+
+
+
+def generate_keywords(query):
+    chain = make_keywords_chain(llm)
+    keywords = chain.invoke(query)
+    keywords = " AND ".join(keywords["keywords"])
+    return keywords
+
+
+async def find_papers(query,after, relevant_content_sources, reranker= reranker):
+    if "OpenAlex" in relevant_content_sources:
+        summary = ""
+        keywords = generate_keywords(query)
+        df_works = oa.search(keywords,after = after)
+
+        print(f"Found {len(df_works)} papers")
+
+        if not df_works.empty:
+            df_works = df_works.dropna(subset=["abstract"])
+            df_works = df_works[df_works["abstract"] != ""].reset_index(drop = True)
+            df_works = oa.rerank(query,df_works,reranker)
+            df_works = df_works.sort_values("rerank_score",ascending=False)
+            docs_html = []
+            for i in range(10):
+                docs_html.append(make_html_df(df_works, i))
+            docs_html = "".join(docs_html)
+            G = oa.make_network(df_works)
+
+            height = "750px"
+            network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
+            network_html = network.generate_html()
+
+            network_html = network_html.replace("'", "\"")
+            css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
+            network_html = network_html + css_to_inject
+
+
+            network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
+            display-capture; encrypted-media;" sandbox="allow-modals allow-forms
+            allow-scripts allow-same-origin allow-popups
+            allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
+            allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
+
+
+            docs = df_works["content"].head(10).tolist()
+
+            df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
+            df_works["doc"] = df_works["doc"] + 1
+            df_works = df_works[papers_cols]
+
+            yield docs_html, network_html, summary
+
+            chain = make_rag_papers_chain(llm)
+            result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
+            path_answer = "/logs/StrOutputParser/streamed_output/-"
+
+            async for op in result:
+
+                op = op.ops[0]
+
+                if op['path'] == path_answer: # reformulated question
+                    new_token = op['value'] # str
+                    summary += new_token
+                else:
+                    continue
+                yield docs_html, network_html, summary
+        else :
+            print("No papers found")
+    else :
+        yield "","", ""
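A quick usage sketch of the module's keyword step (hypothetical output, assuming OpenAI credentials are configured): generate_keywords collapses the chain's keyword list into a single boolean query string for OpenAlex.

from climateqa.engine.chains.retrieve_papers import generate_keywords

# The keywords chain returns a dict such as {"keywords": ["climate change", "hoax"]};
# generate_keywords joins the list into one OpenAlex query string.
print(generate_keywords("Is climate change a hoax?"))
# expected shape: "climate change AND hoax"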
climateqa/engine/keywords.py
CHANGED
@@ -11,10 +11,12 @@ class KeywordsOutput(BaseModel):
 
     keywords: list = Field(
         description="""
-        Generate 1 or 2 relevant keywords from the user query to ask a search engine for scientific research papers.
+        Generate 1 or 2 relevant keywords from the user query to ask a search engine for scientific research papers. Answer only with English keywords.
+        Do not use special characters or accents.
 
         Example:
        - "What is the impact of deep sea mining ?" -> ["deep sea mining"]
+        - "Quel est l'impact de l'exploitation minière en haute mer ?" -> ["deep sea mining"]
        - "How will El Nino be impacted by climate change" -> ["el nino"]
        - "Is climate change a hoax" -> ["Climate change","hoax"]
         """
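The intended effect of the new prompt lines is that non-English queries map to plain English keywords. A minimal sketch, assuming OpenAI credentials are configured; model output is not deterministic, so the exact keywords may vary.

from climateqa.engine.llm import get_llm
from climateqa.engine.keywords import make_keywords_chain

llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
chain = make_keywords_chain(llm)

# French input should now yield English, accent-free keywords.
result = chain.invoke("Quel est l'impact de l'exploitation minière en haute mer ?")
print(result["keywords"])  # expected shape: ["deep sea mining"]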
climateqa/knowledge/openalex.py
CHANGED
@@ -41,6 +41,10 @@ class OpenAlex():
                break
 
        df_works = pd.DataFrame(page)
+
+        if df_works.empty:
+            return df_works
+
        df_works = df_works.dropna(subset = ["title"])
        df_works["primary_location"] = df_works["primary_location"].map(replace_nan_with_empty_dict)
        df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x)).fillna("")
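This guard matters because, on an empty frame with no columns, the following dropna(subset=["title"]) would likely raise a KeyError. A minimal sketch of the new behavior, assuming network access to the OpenAlex API (the after value is an assumption):

from climateqa.knowledge.openalex import OpenAlex

oa = OpenAlex()
# A query with no hits now comes back as an empty DataFrame early,
# instead of failing in the post-processing steps.
df_works = oa.search("some deliberately obscure keyword string", after=2020)
if df_works.empty:
    print("No papers found")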