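"""Command-line entry point for the Self-RAG pipeline (scripts/self_rag.py).

Wires together a Pinecone hybrid-search retriever, a query transformer, an
LLM-based retrieval evaluator with a document grader, a web-search fallback,
and a Groq-hosted chat model, then runs a single query through the pipeline
and prints the result.
"""
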
import argparse
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.retrievers import PineconeHybridSearchRetriever
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_groq import ChatGroq
from pinecone import Pinecone, ServerlessSpec
from pinecone_text.sparse import BM25Encoder
from rag_pipelines.pipelines.self_rag import SelfRAGPipeline
from rag_pipelines.query_transformer.query_transformer import QueryTransformer
from rag_pipelines.retrieval_evaluator.document_grader import DocumentGrader
from rag_pipelines.retrieval_evaluator.retrieval_evaluator import RetrievalEvaluator
from rag_pipelines.websearch.web_search import WebSearch


def main():
parser = argparse.ArgumentParser(description="Run the Self-RAG pipeline.")
# Pinecone retriever arguments
parser.add_argument("--pinecone_api_key", type=str, required=True, help="Pinecone API key.")
parser.add_argument("--index_name", type=str, default="edgar", help="Pinecone index name.")
parser.add_argument("--dimension", type=int, default=384, help="Dimension of embeddings.")
parser.add_argument("--metric", type=str, default="dotproduct", help="Metric for similarity search.")
parser.add_argument("--region", type=str, default="us-east-1", help="Pinecone region.")
parser.add_argument(
"--namespace",
type=str,
default="edgar-all",
help="Namespace for Pinecone retriever.",
)
# Query Transformer arguments
parser.add_argument(
"--query_transformer_model",
type=str,
default="t5-small",
help="Model used for query transformation.",
)
# Retrieval Evaluator arguments
parser.add_argument(
"--llm_model",
type=str,
default="llama-3.2-90b-vision-preview",
help="Language model name for retrieval evaluator.",
)
parser.add_argument("--llm_api_key", type=str, required=True, help="API key for the language model.")
parser.add_argument(
"--temperature",
type=float,
default=0.7,
help="Temperature for the language model.",
)
parser.add_argument(
"--relevance_threshold",
type=float,
default=0.7,
help="Relevance threshold for document grading.",
)
# Web Search arguments
parser.add_argument("--web_search_api_key", type=str, required=True, help="API key for web search.")
# Prompt arguments
parser.add_argument(
"--prompt_template_path",
type=str,
required=True,
help="Path to the prompt template for LLM.",
)
# Query
parser.add_argument(
"--query",
type=str,
required=True,
help="Query to run through the Self-RAG pipeline.",
)
args = parser.parse_args()
    # Connect to Pinecone and build the hybrid (dense + sparse) retriever, creating
    # the index if it does not exist. The dense embedding model is an assumption
    # chosen to match the default 384-dimensional index.
    pc = Pinecone(api_key=args.pinecone_api_key)
    if args.index_name not in pc.list_indexes().names():
        pc.create_index(
            name=args.index_name,
            dimension=args.dimension,
            metric=args.metric,
            spec=ServerlessSpec(cloud="aws", region=args.region),
        )
    retriever = PineconeHybridSearchRetriever(
        embeddings=HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2"),
        sparse_encoder=BM25Encoder().default(),
        index=pc.Index(args.index_name),
        namespace=args.namespace,
    )
# Initialize Query Transformer
query_transformer = QueryTransformer(model_name=args.query_transformer_model)
# Initialize Retrieval Evaluator and Document Grader
retrieval_evaluator = RetrievalEvaluator(
llm_model=args.llm_model,
llm_api_key=args.llm_api_key,
temperature=args.temperature,
)
document_grader = DocumentGrader(
evaluator=retrieval_evaluator,
threshold=args.relevance_threshold,
)
# Initialize Web Search
web_search = WebSearch(api_key=args.web_search_api_key)
# Load the prompt template
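    # The file is read as plain text; ChatPromptTemplate.from_template treats
    # brace-delimited names (e.g. a hypothetical "{context}" / "{question}" pair)
    # as input variables, which SelfRAGPipeline is assumed to fill at run time.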
with open(args.prompt_template_path) as file:
prompt_template_str = file.read()
prompt = ChatPromptTemplate.from_template(prompt_template_str)
# Initialize the LLM
llm = ChatGroq(
model=args.llm_model,
api_key=args.llm_api_key,
        temperature=args.temperature,
)
# Initialize Self-RAG Pipeline
self_rag_pipeline = SelfRAGPipeline(
retriever=retriever,
query_transformer=query_transformer,
retrieval_evaluator=retrieval_evaluator,
document_grader=document_grader,
web_search=web_search,
prompt=prompt,
llm=llm,
)
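    # As wired above, the pipeline is expected to transform the query, retrieve
    # and grade candidate documents against the relevance threshold, fall back to
    # web search when grading fails, and generate the final answer with the
    # prompt/LLM pair; the exact control flow lives in SelfRAGPipeline.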
# Run the pipeline
output = self_rag_pipeline.run(args.query)
    print(output)


if __name__ == "__main__":
main()
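
# Example invocation (all values are illustrative; substitute your own keys,
# template path, and query):
#
#   python scripts/self_rag.py \
#       --pinecone_api_key "$PINECONE_API_KEY" \
#       --llm_api_key "$GROQ_API_KEY" \
#       --web_search_api_key "$WEB_SEARCH_API_KEY" \
#       --prompt_template_path prompts/self_rag.txt \
#       --query "What risk factors did the company disclose in its latest 10-K?"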