Bases: BaseRetriever
Retriever wrapper that applies query rewriting before retrieval.
This wrapper intercepts queries, applies domain-specific expansions
to improve semantic search, then delegates to the underlying retriever.
Source code in `src/augmentation/components/retrievers/query_rewriting_retriever.py` (lines 13–73):
class QueryRewritingRetriever(BaseRetriever):
    """Retriever wrapper that applies query rewriting before retrieval.

    This wrapper intercepts queries, passes the query string through the
    configured ``QueryRewriter`` (domain-specific expansions to improve
    semantic search), then delegates retrieval to the underlying retriever.
    Both the sync and async retrieval paths apply the same rewriting.
    """

    def __init__(
        self,
        base_retriever: BaseRetriever,
        query_rewriter: Optional[QueryRewriter] = None,
        logger: Optional[logging.Logger] = None,
    ):
        """Initialize the query rewriting retriever.

        Args:
            base_retriever: The underlying retriever to delegate to.
            query_rewriter: Query rewriter instance; a default
                ``QueryRewriter()`` is created when this is falsy (e.g. None).
            logger: Logger instance; falls back to
                ``LoggerConfiguration.get_logger(__name__)`` when falsy.
        """
        super().__init__()
        self._base_retriever = base_retriever
        self._query_rewriter = query_rewriter or QueryRewriter()
        self._logger = logger or LoggerConfiguration.get_logger(__name__)

    def _rewrite_bundle(self, query_bundle: QueryBundle) -> QueryBundle:
        """Return a bundle whose query string has been passed through the rewriter.

        Args:
            query_bundle: Query bundle containing the user query.

        Returns:
            The original ``query_bundle`` unchanged if the rewriter left the
            query string as-is; otherwise a new ``QueryBundle`` built from the
            rewritten text.
        """
        original_query = query_bundle.query_str
        rewritten_query = self._query_rewriter.rewrite(original_query)

        if rewritten_query == original_query:
            self._logger.debug("[QueryRewritingRetriever] No rewriting applied")
            return query_bundle

        self._logger.info(
            "[QueryRewritingRetriever] Query rewritten\n"
            " Original: %s...\n"
            " Rewritten: %s...",
            original_query[:100],
            rewritten_query[:150],
        )
        # A fresh bundle is built from the rewritten text only; no embedding
        # from the original bundle is carried over, so retrieval embeds the
        # rewritten query via ``custom_embedding_strs``.
        return QueryBundle(
            query_str=rewritten_query,
            custom_embedding_strs=[rewritten_query],
        )

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes with query rewriting.

        Args:
            query_bundle: Query bundle containing the user query.

        Returns:
            List of nodes with scores from the underlying retriever.
        """
        bundle = self._rewrite_bundle(query_bundle)
        # Go through the base retriever's public ``retrieve`` rather than its
        # private ``_retrieve`` hook, so any wrapping logic in the framework's
        # ``retrieve`` (callbacks/events in LlamaIndex) is not bypassed.
        return self._base_retriever.retrieve(bundle)

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Asynchronously retrieve nodes with query rewriting.

        Args:
            query_bundle: Query bundle containing the user query.

        Returns:
            List of nodes with scores from the underlying retriever.
        """
        bundle = self._rewrite_bundle(query_bundle)
        # Await the base retriever's async path instead of running the sync
        # path inside the coroutine, which would block the event loop.
        return await self._base_retriever.aretrieve(bundle)
`__init__(base_retriever, query_rewriter=None, logger=None)`

Initialize the query rewriting retriever.

Parameters:

- `base_retriever` (`BaseRetriever`) – The underlying retriever to delegate to.
- `query_rewriter` (`Optional[QueryRewriter]`, default: `None`) – Query rewriter instance (creates a default if `None`).
- `logger` (`Optional[Logger]`, default: `None`) – Logger instance.
Source code in `src/augmentation/components/retrievers/query_rewriting_retriever.py` (lines 20–36):
def __init__(
    self,
    base_retriever: BaseRetriever,
    query_rewriter: Optional[QueryRewriter] = None,
    logger: Optional[logging.Logger] = None,
):
    """Wrap ``base_retriever`` so queries are rewritten before retrieval.

    Args:
        base_retriever: Retriever that performs the actual lookup.
        query_rewriter: Rewriter applied to incoming queries; a default
            ``QueryRewriter()`` is constructed when this is falsy (e.g. None).
        logger: Logger for rewrite diagnostics; falls back to
            ``LoggerConfiguration.get_logger(__name__)`` when falsy.
    """
    super().__init__()
    self._base_retriever = base_retriever
    # ``or`` (not ``is None``) is kept deliberately: any falsy value
    # triggers the default, exactly as before.
    rewriter = query_rewriter or QueryRewriter()
    self._query_rewriter = rewriter
    resolved_logger = logger or LoggerConfiguration.get_logger(__name__)
    self._logger = resolved_logger
|