Query_rewriting_retriever

This module contains functionality related to the query_rewriting_retriever module in augmentation.components.retrievers.

Query_rewriting_retriever

Query rewriting retriever wrapper that enhances retrieval with query expansion.

QueryRewritingRetriever

Bases: BaseRetriever

Retriever wrapper that applies query rewriting before retrieval.

This wrapper intercepts queries, applies domain-specific expansions to improve semantic search, then delegates to the underlying retriever.

Source code in src/augmentation/components/retrievers/query_rewriting_retriever.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class QueryRewritingRetriever(BaseRetriever):
    """Retriever wrapper that applies query rewriting before retrieval.

    This wrapper intercepts queries, applies domain-specific expansions
    to improve semantic search, then delegates to the underlying retriever.
    Both the sync and the async retrieval paths share the same rewriting
    step and forward to the matching method of the wrapped retriever.
    """

    def __init__(
        self,
        base_retriever: BaseRetriever,
        query_rewriter: Optional[QueryRewriter] = None,
        logger: Optional[logging.Logger] = None,
    ):
        """Initialize the query rewriting retriever.

        Args:
            base_retriever: The underlying retriever to delegate to
            query_rewriter: Query rewriter instance (creates default if None)
            logger: Logger instance (a module logger is obtained from
                LoggerConfiguration when omitted)
        """
        super().__init__()
        self._base_retriever = base_retriever
        self._query_rewriter = query_rewriter or QueryRewriter()
        self._logger = logger or LoggerConfiguration.get_logger(__name__)

    def _rewrite_bundle(self, query_bundle: QueryBundle) -> QueryBundle:
        """Return the bundle to hand to the underlying retriever.

        Args:
            query_bundle: Query bundle containing the user query

        Returns:
            The incoming bundle unchanged when the rewriter returns the same
            string; otherwise a new bundle built from the rewritten query.
        """
        original_query = query_bundle.query_str
        rewritten_query = self._query_rewriter.rewrite(original_query)

        if rewritten_query == original_query:
            self._logger.debug("[QueryRewritingRetriever] No rewriting applied")
            return query_bundle

        self._logger.info(
            f"[QueryRewritingRetriever] Query rewritten\n"
            f"  Original: {original_query[:100]}...\n"
            f"  Rewritten: {rewritten_query[:150]}..."
        )
        # A fresh bundle is built from the rewritten text only; other fields
        # of the incoming bundle are not carried over.
        return QueryBundle(
            query_str=rewritten_query,
            custom_embedding_strs=[rewritten_query],
        )

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes with query rewriting.

        Args:
            query_bundle: Query bundle containing the user query

        Returns:
            List of nodes with scores from the underlying retriever
        """
        # NOTE(review): this calls the wrapped retriever's private hook
        # directly, so anything its public entry point adds is skipped —
        # confirm this is intended.
        return self._base_retriever._retrieve(self._rewrite_bundle(query_bundle))

    async def _aretrieve(
        self, query_bundle: QueryBundle
    ) -> List[NodeWithScore]:
        """Async retrieve nodes with query rewriting.

        Delegates to the underlying retriever's async hook instead of the
        sync one, so the awaiting event loop is not blocked by a synchronous
        retrieval call.

        Args:
            query_bundle: Query bundle containing the user query

        Returns:
            List of nodes with scores from the underlying retriever
        """
        return await self._base_retriever._aretrieve(
            self._rewrite_bundle(query_bundle)
        )

__init__(base_retriever, query_rewriter=None, logger=None)

Initialize the query rewriting retriever.

Parameters:
  • base_retriever (BaseRetriever) –

    The underlying retriever to delegate to

  • query_rewriter (Optional[QueryRewriter], default: None ) –

    Query rewriter instance (creates default if None)

  • logger (Optional[Logger], default: None ) –

    Logger instance

Source code in src/augmentation/components/retrievers/query_rewriting_retriever.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def __init__(
    self,
    base_retriever: BaseRetriever,
    query_rewriter: Optional[QueryRewriter] = None,
    logger: Optional[logging.Logger] = None,
):
    """Set up the wrapper around an existing retriever.

    Args:
        base_retriever: Retriever that performs the actual lookup
        query_rewriter: Rewriter applied to incoming queries; a default
            QueryRewriter is constructed when none is supplied
        logger: Logger to use; obtained from LoggerConfiguration for this
            module when none is supplied
    """
    super().__init__()
    self._base_retriever = base_retriever

    # Fall back to a default rewriter when the caller gave none.
    if not query_rewriter:
        query_rewriter = QueryRewriter()
    self._query_rewriter = query_rewriter

    # Fall back to the module logger when the caller gave none.
    if not logger:
        logger = LoggerConfiguration.get_logger(__name__)
    self._logger = logger