import os
# https://stackoverflow.com/questions/76175046/how-to-add-prompt-to-langchain-conversationalretrievalchain-chat-over-docs-with
# again from:
# https://python.langchain.com/docs/integrations/providers/vectara/vectara_chat
from langchain.document_loaders import PyPDFDirectoryLoader
import pandas as pd
import langchain
from queue import Queue
from typing import Any, Optional
from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import LLMResult
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.prompts.prompt import PromptTemplate
from anyio.from_thread import start_blocking_portal  # For model callback streaming
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from dotenv import load_dotenv
import streamlit as st
import json
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
# from langchain.chat_models import ChatAnthropic
from langchain_anthropic import ChatAnthropic
from langchain.vectorstores import Chroma
import chromadb
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import OpenAI
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.document_loaders import DirectoryLoader
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.schema import Document
from langchain.memory import ConversationBufferMemory
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
import gradio as gr
from langchain.chains import ConversationalRetrievalChain
print("Started")
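
# Load all document metadata from a persisted Chroma database and collect the primary
# matched species name ('matched_specie_0') for each entry; documents without that
# metadata key are skipped.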
def get_species_list_from_db(db_name):
    embedding = OpenAIEmbeddings()
    vectordb_temp = Chroma(persist_directory=db_name,
                           embedding_function=embedding)
    species_list = []
    for meta in vectordb_temp.get()["metadatas"]:
        try:
            matched_first_species = meta['matched_specie_0']
        except KeyError:
            continue
        # Since each document is considered as a single chunk, the chunk_index is 0 for all
        species_list.append(matched_first_species)
    return species_list
# default_persist_directory = './db5'  # For deployment
default_persist_directory_insects = './vector-databases-deployed/db5-agllm-data-isu-field-insects-all-species'
default_persist_directory_weeds = './vector-databases-deployed/db5-agllm-data-isu-field-weeds-all-species'
species_list_insects = get_species_list_from_db(default_persist_directory_insects)
species_list_weeds = get_species_list_from_db(default_persist_directory_weeds)
# default_persist_directory = 'vector-databases/db5-pre-completion'  # For development
excel_filepath1 = "./agllm-data/corrected/Corrected_supplemented-insect_data-2500-sorted.xlsx"
excel_filepath2 = "./agllm-data/corrected/Corrected_supplemented-insect_data-remaining.xlsx"
model_name = 4
max_tokens = 400
system_message = {"role": "system", "content": "You are a helpful assistant."}  # TODO: double check how this plays out later.
langchain.debug = False  # TODO: DOUBLE CHECK
from langchain import globals
globals.set_debug(False)
retriever_k_value = 3
embedding = OpenAIEmbeddings()
print("Started....")
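
# Thin ChatOpenAI subclass pointed at OpenRouter's OpenAI-compatible endpoint, so
# OpenRouter-hosted models (Llama-3, Gemini) can be used through the same interface.
# The API key falls back to the OPENROUTER_API_KEY environment variable.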
class ChatOpenRouter(ChatOpenAI):
    openai_api_base: str
    openai_api_key: str
    model_name: str

    def __init__(self,
                 model_name: str,
                 openai_api_key: Optional[str] = None,
                 openai_api_base: str = "https://openrouter.ai/api/v1",
                 **kwargs):
        openai_api_key = openai_api_key or os.getenv('OPENROUTER_API_KEY')
        super().__init__(openai_api_base=openai_api_base,
                         openai_api_key=openai_api_key,
                         model_name=model_name, **kwargs)
######### todo: skipping the first step
# print(  # Single example
#     vectordb.as_retriever(k=2, search_kwargs={"filter": {"matched_specie_0": "Hypagyrtis unipunctata"}, 'k': 1}).get_relevant_documents(
#         "Checking if retriever is correctly initialized?"
#     ))
columns = ['species', 'common name', 'order', 'family',
           'genus', 'Updated role in ecosystem', 'Proof',
           'ipm strategies', 'size of insect', 'geographical spread',
           'life cycle specifics', 'pest for plant species', 'species status',
           'distribution area', 'appearance', 'identification']
df1 = pd.read_excel(excel_filepath1, usecols=columns)
df2 = pd.read_excel(excel_filepath2, usecols=columns)
all_insects_data = pd.concat([df1, df2], ignore_index=True)
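
# Build the chat prompt for a given species and mode (Farmer/Researcher). The vetted
# spreadsheet data is formatted here, but the active prompt below only injects the
# retrieved {context}; the vetted-info version of the prompt is kept commented out.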
def get_prompt_with_vetted_info_from_specie_name(search_for_specie, mode):
    def read_and_format_filtered_csv_better(insect_specie):
        filtered_data = all_insects_data[all_insects_data['species'] == insect_specie]
        formatted_data = ""
        # Format the filtered data
        for index, row in filtered_data.iterrows():
            row_data = [f"{col}: {row[col]}" for col in filtered_data.columns]
            formatted_row = "\n".join(row_data)
            formatted_data += f"{formatted_row}\n"
        return formatted_data

    # Format the vetted spreadsheet rows for the selected species
    vetted_info = read_and_format_filtered_csv_better(search_for_specie)
    if mode == "Farmer":
        language_constraint = "The language should be suited to farmers. The question is likely to be asked by a farmer in the field, and the answer should support decisions that are immediate and practical."
    elif mode == "Researcher":
        language_constraint = "The language should be suited to a researcher. The question is likely to be asked by a scientist and calls for answers that are comprehensive and aimed at exploring new knowledge or refining existing methodologies."
    else:
        print("No valid mode provided. Exiting")
        exit()
    # general_system_template = """
    # In every question you are provided information about the insect/weed. Two types of information are given: first, Vetted Information (which is the same in every question), and second, some context from external documents about an insect/weed species, along with a question from the user. Answer the question according to these two types of information.
    # ----
    # Vetted info is as follows:
    # {vetted_info}
    # ----
    # The context retrieved from documents for this particular question is as follows:
    # {context}
    # ----
    # Additional Instruction:
    # 1. Reference Constraint
    # At the end of each answer provide the source/reference for the given data in the following format:
    # \n\n[enter two new lines before writing below] References:
    # Vetted Information Used: Write what was used from the vetted information for coming up with the answer above. Write the exact part of the lines. If nothing, write 'Nothing'.
    # Documents Used: Write what was used from the documents for coming up with the answer above. If nothing, write 'Nothing'. Write the exact part of the lines and the document used.
    # 2. Information Constraint:
    # Only answer the question from the information provided; otherwise say you don't know. You have to answer in 50 words including references. Prioritize information in documents/context over vetted information. And first mention the warnings/things to be careful about.
    # 3. Language constraint:
    # {language_constraint}
    # ----
    # """.format(vetted_info=vetted_info, language_constraint=language_constraint, context="{context}")
    general_system_template = f"""
You are an AI assistant specialized in providing information about insects/weeds. Answer the user's question based on the available information or your general knowledge.
The context retrieved for this question is as follows:
{{context}}
Instructions:
1. Evaluate the relevance of the provided context to the question.
2. If the context contains relevant information, use it to answer the question.
3. If the context does not contain relevant information, use your general knowledge to answer the question.
4. Format your response as follows:
   Answer: Provide a concise answer in less than 50 words.
   Reference: If you used the provided context, cite the specific information used. If you used your general knowledge, state "Based on general knowledge".
5. Language constraint:
{language_constraint}
Question: {{question}}
"""
    general_user_template = "Question:```{question}```"
    messages_formatted = [
        SystemMessagePromptTemplate.from_template(general_system_template),
        HumanMessagePromptTemplate.from_template(general_user_template)
    ]
    qa_prompt = ChatPromptTemplate.from_messages(messages_formatted)
    # print(qa_prompt)
    return qa_prompt
qa_prompt = get_prompt_with_vetted_info_from_specie_name("Papaipema nebris", "Researcher")
# print("First prompt is initialized as: ", qa_prompt, "\n\n")
memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)  # https://github.com/langchain-ai/langchain/issues/9394#issuecomment-1683538834
if model_name == 4:
    llm_openai = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0, max_tokens=max_tokens)  # TODO: NEW MODEL VERSION AVAILABLE
else:
    llm_openai = ChatOpenAI(model_name="gpt-3.5-turbo-0125", temperature=0, max_tokens=max_tokens)
specie_selector = "Papaipema nebris"
filter = {
    "$or": [
        {"matched_specie_0": specie_selector},
        {"matched_specie_1": specie_selector},
        {"matched_specie_2": specie_selector},
    ]
}
# retriever = vectordb.as_retriever(search_kwargs={'k': retriever_k_value, 'filter': filter})
# qa_chain = ConversationalRetrievalChain.from_llm(
#     llm_openai, retriever, memory=memory, verbose=False, return_source_documents=True,
#     combine_docs_chain_kwargs={'prompt': qa_prompt}
# )
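
# Build a per-session ConversationalRetrievalChain: pick the LLM backend, create a
# Chroma retriever filtered to the selected species' metadata, and attach fresh
# conversation memory plus the mode-specific prompt.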
def initialize_qa_chain(specie_selector, application_mode, model_name="GPT-4", database_persistent_directory=default_persist_directory_insects):
    if model_name == "GPT-4":
        chosen_llm = ChatOpenAI(model_name="gpt-4-1106-preview", temperature=0, max_tokens=max_tokens)
    elif model_name == "GPT-3.5":
        chosen_llm = ChatOpenAI(model_name="gpt-3.5-turbo-0125", temperature=0, max_tokens=max_tokens)
    elif model_name == "Llama-3 70B":
        chosen_llm = ChatOpenRouter(model_name="meta-llama/llama-3-70b-instruct", temperature=0, max_tokens=max_tokens)
    elif model_name == "Llama-3 8B":
        chosen_llm = ChatOpenRouter(model_name="meta-llama/llama-3-8b-instruct", temperature=0, max_tokens=max_tokens)
    elif model_name == "Gemini-1.5 Pro":
        chosen_llm = ChatOpenRouter(model_name="google/gemini-pro-1.5", temperature=0, max_tokens=max_tokens)
    elif model_name == "Claude 3 Opus":
        chosen_llm = ChatAnthropic(model_name='claude-3-opus-20240229', temperature=0, max_tokens=max_tokens)
    else:
        print("No appropriate LLM was selected")
        exit()
    filter = {
        "$or": [
            {"matched_specie_0": specie_selector},
            {"matched_specie_1": specie_selector},
            {"matched_specie_2": specie_selector},
            {"matched_specie_3": specie_selector},
            {"matched_specie_4": specie_selector},
            {"matched_specie_5": specie_selector},
            {"matched_specie_6": specie_selector},
            {"matched_specie_7": specie_selector},
            {"matched_specie_8": specie_selector},
            {"matched_specie_9": specie_selector},
            {"matched_specie_10": specie_selector}
        ]
    }
    embedding = OpenAIEmbeddings()
    vectordb = Chroma(persist_directory=database_persistent_directory,
                      embedding_function=embedding)
    print("Initialized retriever with species metadata filtering")
    retriever = vectordb.as_retriever(search_kwargs={'k': retriever_k_value, 'filter': filter})
    memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
    qa_prompt = get_prompt_with_vetted_info_from_specie_name(specie_selector, application_mode)
    qa_chain = ConversationalRetrievalChain.from_llm(
        chosen_llm, retriever, memory=memory, verbose=False, return_source_documents=True,
        combine_docs_chain_kwargs={'prompt': qa_prompt}
    )
    return qa_chain
# result = qa_chain.invoke({"question": "where are stalk borer eggs laid?"})
# print("Got the first LLM task working: ", result)

# Application Interface:
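# Gradio Blocks UI: logo header row, chat window, domain/species/model/mode selectors,
# and a user prompt box with a submit button.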
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(
                """

"""
            )
        with gr.Column(scale=1):
            gr.Markdown(
                """

"""
            )

    # Configure UI layout
    chatbot = gr.Chatbot(height=600, label="AgLLM")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Row():
                domain_name = gr.Dropdown(
                    list(["Insects", "Weeds"]),
                    value="Insects",
                    label="Domain",
                    info="Select Domain",
                    interactive=True,
                    scale=1,
                    visible=True
                )
                # Species selection
                specie_selector = gr.Dropdown(
                    species_list_insects,
                    value=species_list_insects[0],
                    label="Species",
                    info="Select the Species",
                    interactive=True,
                    scale=1,
                    visible=True
                )
            with gr.Row():
                model_name = gr.Dropdown(
                    list(["GPT-4", "GPT-3.5", "Llama-3 70B", "Llama-3 8B", "Gemini-1.5 Pro", "Claude 3 Opus"]),
                    value="Llama-3 70B",
                    label="LLM",
                    info="Select the LLM",
                    interactive=True,
                    scale=1,
                    visible=True
                )
                application_mode = gr.Dropdown(
                    list(["Farmer", "Researcher"]),
                    value="Researcher",
                    label="Mode",
                    info="Select the Mode",
                    interactive=True,
                    scale=1,
                    visible=True
                )
        with gr.Column(scale=2):
            # User input prompt text field
            user_prompt_message = gr.Textbox(placeholder="Please add user prompt here", label="User prompt")
            with gr.Row():
                # clear = gr.Button("Clear Conversation", scale=2)
                submitBtn = gr.Button("Submit", scale=8)

    state = gr.State([])
    qa_chain_state = gr.State(value=None)
    # Handle user message
    def user(user_prompt_message, history):
        # print("HISTORY IS: ", history)  # TODO: REMOVE IT LATER
        if user_prompt_message != "":
            return history + [[user_prompt_message, None]]
        else:
            return history + [["Invalid prompt - user prompt cannot be empty", None]]
    # Chatbot logic for configuration, sending the prompts, rendering the generations, etc.
    def bot(model_name, application_mode, user_prompt_message, history, messages_history, qa_chain, domain_name):
        if qa_chain is None:
            qa_chain = init_qa_chain(species_list_insects[0], application_mode, model_name, domain_name)
        bot_message = ""
        history[-1][1] = ""  # Placeholder for the answer
        dialog = [
            {"role": "user", "content": user_prompt_message},
        ]
        messages_history += dialog
        # Queue intended for streamed character rendering (currently unused; the chain is invoked synchronously)
        q = Queue()

        # Run the chain and split the raw answer into its Answer/Reference sections
        def task(user_prompt_message):
            result = qa_chain.invoke({"question": user_prompt_message})
            answer = result["answer"]
            try:
                answer_start = answer.find("Answer:")
                reference_start = answer.find("Reference:")
                if answer_start != -1 and reference_start != -1:
                    model_answer = answer[answer_start + len("Answer:"):reference_start].strip()
                    reference = answer[reference_start + len("Reference:"):].strip()
                    formatted_response = f"Answer:\n{model_answer}\n\nReferences:\n{reference}"
                else:
                    formatted_response = answer
            except Exception:
                print("Error while parsing, so displaying the raw output")
                formatted_response = answer
            return formatted_response

        history[-1][1] = task(user_prompt_message)
        return [history, messages_history]
    # Initialize the chat history with default system message
    def init_history(messages_history):
        messages_history = []
        messages_history += [system_message]
        return messages_history

    # Clean up the user input text field
    def input_cleanup():
        return ""
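
    # Dispatch chain construction to the insect or weed vector database based on the selected domain.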
    def init_qa_chain(specie_selector, application_mode, model_name, domain_name):
        if domain_name == "Insects":
            qa_chain = initialize_qa_chain(specie_selector, application_mode, model_name, default_persist_directory_insects)
        elif domain_name == "Weeds":
            qa_chain = initialize_qa_chain(specie_selector, application_mode, model_name, default_persist_directory_weeds)
        else:
            print("No Appropriate Chain Selected")
            qa_chain = None
        return qa_chain
    specie_selector.change(
        init_qa_chain,
        inputs=[specie_selector, application_mode, model_name, domain_name],
        outputs=[qa_chain_state]
    )
    model_name.change(
        init_qa_chain,
        inputs=[specie_selector, application_mode, model_name, domain_name],
        outputs=[qa_chain_state]
    )

    #####
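    # Swap the species dropdown choices when the domain changes between Insects and Weeds.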
    def update_species_list(domain):
        if domain == "Insects":
            return gr.Dropdown(species_list_insects, value=species_list_insects[0], label="Species", info="Select the Species", interactive=True, scale=1, visible=True)
        elif domain == "Weeds":
            return gr.Dropdown(species_list_weeds, value=species_list_weeds[0], label="Species", info="Select the Species", interactive=True, scale=1, visible=True)

    domain_name.change(
        update_species_list,
        inputs=[domain_name],
        outputs=[specie_selector]
    )
    # When the user clicks Enter and the user message is submitted
    user_prompt_message.submit(
        user,
        [user_prompt_message, chatbot],
        [chatbot],
        queue=False
    ).then(
        bot,
        [model_name, application_mode, user_prompt_message, chatbot, state, qa_chain_state, domain_name],
        [chatbot, state]
    ).then(
        input_cleanup,
        [],
        [user_prompt_message],
        queue=False
    )

    # When the user clicks the submit button
    submitBtn.click(
        user,
        [user_prompt_message, chatbot],
        [chatbot],
        queue=False
    ).then(
        bot,
        [model_name, application_mode, user_prompt_message, chatbot, state, qa_chain_state, domain_name],
        [chatbot, state]
    ).then(
        input_cleanup,
        [],
        [user_prompt_message],
        queue=False
    )

    # When the user clicks the clear button
    # clear.click(lambda: None, None, chatbot, queue=False).success(init_history, [state], [state])
if __name__ == "__main__":
    # demo.launch()
    demo.queue().launch(allowed_paths=["/"], share=False, show_error=True)