# Hugging Face Spaces page header captured during extraction: "Spaces: Sleeping"
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain.memory import ConversationBufferMemory | |
| from langchain.prompts import PromptTemplate | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from langchain.llms.base import LLM | |
| from groq import Groq | |
| from typing import Any, List, Optional, Dict | |
| from pydantic import Field, BaseModel | |
| import os | |
class GroqLLM(LLM, BaseModel):
    """LangChain-compatible LLM wrapper around the Groq chat-completions API.

    Sends a single-turn user message per `_call` and returns the assistant
    reply text. The Groq SDK client is created once at construction time.
    """

    groq_api_key: str = Field(..., description="Groq API Key")
    model_name: str = Field(default="llama-3.3-70b-versatile", description="Model name to use")
    # SDK client; populated in __init__ after pydantic validation has run.
    client: Optional[Any] = None

    def __init__(self, **data):
        super().__init__(**data)
        # Create the client here (not as a field default) so it uses the
        # validated API key.
        self.client = Groq(api_key=self.groq_api_key)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for serialization and telemetry.

        NOTE: the base ``LLM`` class declares this as a property; defining it
        as a plain method (as before) returns a bound method to base-class
        code that reads ``self._llm_type``.
        """
        return "groq"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        """Send *prompt* to Groq and return the reply text.

        Args:
            prompt: Fully rendered prompt string supplied by LangChain.
            stop: Optional stop sequences. Bug fix: these were previously
                accepted but silently ignored; they are now forwarded to the API.
            **kwargs: Extra parameters passed through to
                ``chat.completions.create`` (e.g. temperature).

        Returns:
            The assistant message content of the first completion choice.
        """
        if stop is not None:
            kwargs["stop"] = stop
        completion = self.client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=self.model_name,
            **kwargs,
        )
        return completion.choices[0].message.content

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters (declared as a property in the base class)."""
        return {"model_name": self.model_name}
class AutismResearchBot:
    """Conversational RAG assistant for autism assessment and therapy guidance.

    Wires a Groq-backed LLM, a local FAISS index of research documents
    (embedded with a PubMedBERT model), and conversation memory into a
    LangChain ``ConversationalRetrievalChain``.
    """

    def __init__(self, groq_api_key: str, index_path: str = "./"):
        """Build the bot and its retrieval chain.

        Args:
            groq_api_key: API key for the Groq chat-completions service.
            index_path: Directory containing the saved FAISS index.
                Bug fix: this argument was previously ignored and the index
                was always loaded from "./"; the default now matches that
                old behavior while explicit paths are honored.
        """
        # Initialize the Groq LLM
        self.llm = GroqLLM(
            groq_api_key=groq_api_key,
            model_name="llama-3.3-70b-versatile",  # adjust the model as needed
        )

        # Embeddings must match the model used when the index was built.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="pritamdeka/S-PubMedBert-MS-MARCO",
            model_kwargs={'device': 'cpu'},
        )
        # NOTE: allow_dangerous_deserialization opts into pickle loading —
        # only acceptable because the index is a local, trusted artifact.
        self.db = FAISS.load_local(
            index_path, self.embeddings, allow_dangerous_deserialization=True
        )

        # Conversation memory; output_key is required because the chain
        # returns multiple keys (answer + source_documents).
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
        )

        # Create the RAG chain
        self.qa_chain = self._create_qa_chain()

    def _create_qa_chain(self):
        """Assemble the ConversationalRetrievalChain with the assessment prompt."""
        # Define the prompt template
        template = """You are an expert AI assistant specialized in autism research and diagnostics. You have access to a database of scientific papers, research documents, and diagnostic tools about autism. Use this knowledge to conduct a structured assessment and provide evidence-based therapy recommendations.
Context from scientific papers (use these details only for final therapy recommendations):
{context}
Chat History:
{chat_history}
Objective:
- Gather demographic information
- Present autism types for initial self-identification
- Conduct detailed assessment through naturalistic conversation
- Provide evidence-based therapy recommendations
Instructions:
1. Begin with collecting age and gender
2. Present main types of autism with brief descriptions
3. Ask targeted questions with relatable examples
4. Maintain a conversational, empathetic tone
5. Conclude with personalized therapy recommendations
Initial Introduction:
"Hello, I am an AI assistant specialized in autism research and diagnostics. To provide you with the most appropriate guidance, I'll need to gather some information. Let's start with some basic details:
1. Could you share the age and gender of the person seeking assessment?
Once you provide these details, I'll share some common types of autism spectrum conditions, and we can discuss which ones seem most relevant to your experience."
After receiving demographic information, present autism types:
"Thank you. There are several types of autism spectrum conditions. Please let me know which of these seems most relevant to your situation:
1. Social Communication Challenges
Example: Difficulty maintaining conversations, understanding social cues
2. Repetitive Behavior Patterns
Example: Strong adherence to routines, specific intense interests
3. Sensory Processing Differences
Example: Sensitivity to sounds, lights, or textures
4. Language Development Variations
Example: Delayed speech, unique communication patterns
5. Executive Function Challenges
Example: Difficulty with planning, organizing, and transitioning between tasks
Which of these patterns feels most familiar to your experience?"
Follow-up Questions Format:
"I understand you identify most with [selected type]. Let me ask you about some specific experiences:
[Question with example]
For instance: When you're in a social situation, do you find yourself [specific example from daily life]?"
Continue natural conversation flow with examples for each question:
- Include real-life scenarios
- Relate questions to age-appropriate situations
- Provide clear, concrete examples
- Allow for open-ended responses
Final Assessment and Therapy Recommendations:
"Based on our detailed discussion and the patterns you've described, I can now share some evidence-based therapy recommendations tailored to your specific needs..."
Question:
{question}
Answer:"""
        PROMPT = PromptTemplate(
            template=template,
            input_variables=["context", "chat_history", "question"],
        )

        # Create the chain; "stuff" concatenates the retrieved documents
        # directly into the {context} slot of the prompt.
        chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.db.as_retriever(
                search_type="similarity",
                search_kwargs={"k": 3},  # top-3 most similar chunks
            ),
            memory=self.memory,
            combine_docs_chain_kwargs={
                "prompt": PROMPT,
            },
            return_source_documents=True,
        )
        return chain

    def answer_question(self, question: str):
        """Process a question and return the answer along with source documents.

        Args:
            question: The user's natural-language question.

        Returns:
            Dict with ``'answer'`` (the model reply) and ``'sources'`` (a list
            of dicts carrying a 200-char content preview and the document
            metadata for each retrieved source).
        """
        result = self.qa_chain({"question": question})

        # Extract answer and sources
        answer = result['answer']
        sources = result['source_documents']

        # Format sources for reference: truncate content so callers can show
        # a short citation preview without dumping whole documents.
        source_info = [
            {
                'content': doc.page_content[:200] + "...",
                'metadata': doc.metadata,
            }
            for doc in sources
        ]

        return {
            'answer': answer,
            'sources': source_info,
        }