File size: 4,312 Bytes
6c5ce7a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import argparse
from collections.abc import Sequence
from typing import Optional

from langchain_community.retrievers import PineconeHybridSearchRetriever
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_groq import ChatGroq

from rag_pipelines.pipelines.self_rag import SelfRAGPipeline
from rag_pipelines.query_transformer.query_transformer import QueryTransformer
from rag_pipelines.retrieval_evaluator.document_grader import DocumentGrader
from rag_pipelines.retrieval_evaluator.retrieval_evaluator import RetrievalEvaluator
from rag_pipelines.websearch.web_search import WebSearch
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the Self-RAG pipeline CLI.

    Returns:
        A parser with Pinecone, query-transformer, LLM, web-search, prompt
        and query options. ``--pinecone_api_key``, ``--llm_api_key``,
        ``--web_search_api_key``, ``--prompt_template_path`` and ``--query``
        are required; every other option has a default.
    """
    parser = argparse.ArgumentParser(description="Run the Self-RAG pipeline.")

    # Pinecone retriever arguments
    parser.add_argument("--pinecone_api_key", type=str, required=True, help="Pinecone API key.")
    parser.add_argument("--index_name", type=str, default="edgar", help="Pinecone index name.")
    parser.add_argument("--dimension", type=int, default=384, help="Dimension of embeddings.")
    # Note: Pinecone's hybrid (sparse-dense) queries require an index that uses
    # the dotproduct metric.
    parser.add_argument("--metric", type=str, default="dotproduct", help="Metric for similarity search.")
    parser.add_argument("--region", type=str, default="us-east-1", help="Pinecone region.")
    parser.add_argument(
        "--namespace",
        type=str,
        default="edgar-all",
        help="Namespace for Pinecone retriever.",
    )

    # Query Transformer arguments
    parser.add_argument(
        "--query_transformer_model",
        type=str,
        default="t5-small",
        help="Model used for query transformation.",
    )

    # LLM arguments: shared by the retrieval evaluator and the ChatGroq LLM
    parser.add_argument(
        "--llm_model",
        type=str,
        default="llama-3.2-90b-vision-preview",
        help="Model name used by both the retrieval evaluator and the pipeline's ChatGroq LLM.",
    )
    parser.add_argument("--llm_api_key", type=str, required=True, help="API key for the language model.")
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Temperature for the language model.",
    )
    parser.add_argument(
        "--relevance_threshold",
        type=float,
        default=0.7,
        help="Relevance threshold for document grading.",
    )

    # Web Search arguments
    parser.add_argument("--web_search_api_key", type=str, required=True, help="API key for web search.")

    # Prompt arguments
    parser.add_argument(
        "--prompt_template_path",
        type=str,
        required=True,
        help="Path to the prompt template for LLM.",
    )

    # Query
    parser.add_argument(
        "--query",
        type=str,
        required=True,
        help="Query to run through the Self-RAG pipeline.",
    )
    return parser


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Parse CLI arguments, assemble the Self-RAG pipeline and run one query.

    Builds a Pinecone hybrid retriever, a query transformer, a retrieval
    evaluator and document grader configured from the LLM options, a web
    search client, a chat prompt read from ``--prompt_template_path`` and a
    ChatGroq model. It then hands them all to ``SelfRAGPipeline`` and prints
    the result of running ``--query`` through it.

    Args:
        argv: Command-line arguments to parse, excluding the program name.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``, so
            running the script behaves exactly as before.

    Raises:
        SystemExit: If argument parsing fails or ``--help`` is requested.
        OSError: If the prompt template file cannot be opened or read.
    """
    args = _build_arg_parser().parse_args(argv)

    # NOTE(review): langchain_community's PineconeHybridSearchRetriever
    # documents its fields as embeddings, sparse_encoder, index, top_k, alpha,
    # namespace and text_key. api_key, index_name, dimension, metric and region
    # are not among them. Confirm this call against the installed version; it
    # may need a pre-built Pinecone index plus dense/sparse encoders instead.
    retriever = PineconeHybridSearchRetriever(
        api_key=args.pinecone_api_key,
        index_name=args.index_name,
        dimension=args.dimension,
        metric=args.metric,
        region=args.region,
        namespace=args.namespace,
    )

    query_transformer = QueryTransformer(model_name=args.query_transformer_model)

    # The grader wraps the evaluator and applies --relevance_threshold.
    retrieval_evaluator = RetrievalEvaluator(
        llm_model=args.llm_model,
        llm_api_key=args.llm_api_key,
        temperature=args.temperature,
    )
    document_grader = DocumentGrader(
        evaluator=retrieval_evaluator,
        threshold=args.relevance_threshold,
    )

    web_search = WebSearch(api_key=args.web_search_api_key)

    # Decode the template as UTF-8 explicitly so the result does not depend on
    # the platform's locale-default encoding.
    with open(args.prompt_template_path, encoding="utf-8") as template_file:
        prompt_template_str = template_file.read()
    prompt = ChatPromptTemplate.from_template(prompt_template_str)

    # Temperature is ChatGroq's own ``temperature`` field; ChatGroq does not
    # read sampling settings from an ``llm_params`` mapping.
    llm = ChatGroq(
        model=args.llm_model,
        api_key=args.llm_api_key,
        temperature=args.temperature,
    )

    self_rag_pipeline = SelfRAGPipeline(
        retriever=retriever,
        query_transformer=query_transformer,
        retrieval_evaluator=retrieval_evaluator,
        document_grader=document_grader,
        web_search=web_search,
        prompt=prompt,
        llm=llm,
    )

    output = self_rag_pipeline.run(args.query)
    print(output)
if __name__ == "__main__":
    # Start the CLI only when this file is executed directly, never on import.
    main()
|