Spaces:
Runtime error
Runtime error
Refactor LimitedEnsembleRetriever for improved compatibility and functionality
Browse files- Updated the `LimitedEnsembleRetriever` class to remove inheritance from `BaseRetriever`, simplifying its structure.
- Changed the method names to align with the new invoke interface, replacing deprecated methods with `invoke` and `ainvoke`.
- Added a compatibility method to handle both string input and other data types for the `invoke` method, enhancing usability.
- Improved documentation within the class to clarify the purpose and functionality of methods.
- src/rag/vector_store.py +13 -8
src/rag/vector_store.py
CHANGED
|
@@ -16,25 +16,30 @@ from src.core.logging_config import get_logger
|
|
| 16 |
logger = get_logger(__name__)
|
| 17 |
|
| 18 |
|
| 19 |
-
class LimitedEnsembleRetriever
|
| 20 |
-
"""
|
| 21 |
|
| 22 |
def __init__(self, ensemble_retriever: EnsembleRetriever, k: int):
|
| 23 |
-
super().__init__()
|
| 24 |
self.ensemble_retriever = ensemble_retriever
|
| 25 |
self.k = k
|
| 26 |
|
| 27 |
-
def
|
| 28 |
"""Get relevant documents, limited to k results."""
|
| 29 |
-
#
|
| 30 |
-
docs = self.ensemble_retriever.
|
| 31 |
# Limit to k results
|
| 32 |
return docs[:self.k]
|
| 33 |
|
| 34 |
-
async def
|
| 35 |
"""Async version of get_relevant_documents."""
|
| 36 |
-
docs = await self.ensemble_retriever.
|
| 37 |
return docs[:self.k]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
class VectorStoreManager:
|
|
|
|
| 16 |
logger = get_logger(__name__)
|
| 17 |
|
| 18 |
|
| 19 |
+
class LimitedEnsembleRetriever:
|
| 20 |
+
"""Simple wrapper around EnsembleRetriever that limits total results to k."""
|
| 21 |
|
| 22 |
def __init__(self, ensemble_retriever: EnsembleRetriever, k: int):
|
|
|
|
| 23 |
self.ensemble_retriever = ensemble_retriever
|
| 24 |
self.k = k
|
| 25 |
|
| 26 |
+
def get_relevant_documents(self, query: str) -> List[Document]:
|
| 27 |
"""Get relevant documents, limited to k results."""
|
| 28 |
+
# Use invoke method instead of deprecated get_relevant_documents
|
| 29 |
+
docs = self.ensemble_retriever.invoke(query)
|
| 30 |
# Limit to k results
|
| 31 |
return docs[:self.k]
|
| 32 |
|
| 33 |
+
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
| 34 |
"""Async version of get_relevant_documents."""
|
| 35 |
+
docs = await self.ensemble_retriever.ainvoke(query)
|
| 36 |
return docs[:self.k]
|
| 37 |
+
|
| 38 |
+
def invoke(self, input_data, config=None, **kwargs):
|
| 39 |
+
"""Compatibility method for invoke interface."""
|
| 40 |
+
if isinstance(input_data, str):
|
| 41 |
+
return self.get_relevant_documents(input_data)
|
| 42 |
+
return self.get_relevant_documents(input_data)
|
| 43 |
|
| 44 |
|
| 45 |
class VectorStoreManager:
|