"""Retrieval-augmented generation chain built on Vertex AI Search and Gemini."""

from collections.abc import AsyncIterator

from dotenv import load_dotenv
from langchain_classic.retrievers import ContextualCompressionRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_google_community import VertexAISearchRetriever
from langchain_google_community.vertex_rank import VertexAIRank
from langchain_google_genai import ChatGoogleGenerativeAI

load_dotenv()

PROJECT = "akqa-ita-ai-poc1"
DATA_STORE = "akern-ds_1771234036654"
MODEL = "gemini-2.5-flash"
LOCATION = "eu"
PRINT_SOURCES = False
MAX_OUTPUT_TOKENS = 65535


class RagChain:
    """Two-stage chain: rewrite the user question, then answer it from documents.

    The first stage feeds the incoming message to the question-rewrite prompt.
    The second stage retrieves documents for the rewritten question, reranks
    them, and passes the reranked text plus the rewritten question to the
    answer prompt.

    The sources used for the most recent question are kept on the instance and
    exposed through ``getSources`` / ``getRankedSources``.
    """

    def __init__(
        self,
        top_k: int,
        top_p: float,
        temperature: float,
        retriever_max_docs: int,
        reranker_max_results: int,
    ) -> None:
        """Load the prompt templates and wire up the LLM, retriever and reranker.

        Args:
            top_k: ``top_k`` sampling parameter forwarded to the LLM.
            top_p: ``top_p`` sampling parameter forwarded to the LLM.
            temperature: Sampling temperature forwarded to the LLM.
            retriever_max_docs: ``max_documents`` for the search retriever.
            reranker_max_results: ``top_n`` for the reranker.

        Raises:
            OSError: If ``prompt.md`` or ``question_rewrite_prompt.md`` cannot
                be read from the current working directory.
        """
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        self.retriever_max_docs = retriever_max_docs
        self.reranker_max_results = reranker_max_results

        # Explicit encoding so the templates decode the same way on every
        # platform instead of depending on the locale default.
        with open("prompt.md", encoding="utf-8") as f:
            question_template = f.read()
        with open("question_rewrite_prompt.md", encoding="utf-8") as f:
            question_rewrite_template = f.read()

        question_prompt = ChatPromptTemplate.from_template(question_template)
        question_rewrite_prompt = ChatPromptTemplate.from_template(
            question_rewrite_template
        )

        # Populated by _format_docs on every retrieval.
        self._retriever_sources: list[dict] = []
        self._reranked_sources: list[dict] = []

        self._llm = ChatGoogleGenerativeAI(
            model=MODEL,
            project=PROJECT,
            vertexai=True,
            top_p=self.top_p,
            top_k=self.top_k,
            temperature=self.temperature,
            max_output_tokens=MAX_OUTPUT_TOKENS,
        )
        self._base_retriever = VertexAISearchRetriever(
            project_id=PROJECT,
            data_store_id=DATA_STORE,
            max_documents=self.retriever_max_docs,
            location_id=LOCATION,
            beta=True,
        )
        self._reranker = VertexAIRank(
            project_id=PROJECT,
            location_id=LOCATION,
            ranking_config="default_ranking_config",
            top_n=self.reranker_max_results,
        )
        # Kept as an attribute for backward compatibility. _format_docs calls
        # the retriever and the reranker separately so that the search is only
        # executed once per question.
        self._compression_retriever = ContextualCompressionRetriever(
            base_compressor=self._reranker, base_retriever=self._base_retriever
        )

        question_rewrite_chain = (
            {"question": RunnablePassthrough()}
            | question_rewrite_prompt
            | self._llm
            | StrOutputParser()
            | RunnableLambda(self._log_rewritten_question)
        )
        rag_chain = (
            {
                "context": RunnableLambda(self._format_docs),
                "question": RunnablePassthrough(),
            }
            | question_prompt
            | self._llm
            | StrOutputParser()
        )
        self._full_chain = question_rewrite_chain | rag_chain

    def _log_rewritten_question(self, rewritten_question: str) -> str:
        """Pass the rewritten question through unchanged (hook between stages)."""
        return rewritten_question

    def _format_docs(self, question: str) -> str:
        """Retrieve and rerank documents for *question* and build the context.

        Side effect: replaces ``_retriever_sources`` and ``_reranked_sources``
        with summaries of the documents used for this question.

        Args:
            question: The (rewritten) question to search for.

        Returns:
            The page contents of the reranked documents joined by blank lines;
            an empty string when nothing was retrieved.
        """
        # Search once and rerank exactly those documents. Going through the
        # compression retriever as well would run the same search a second
        # time, and its results would not be guaranteed to match the documents
        # the source lookup below is built from.
        retrieved_docs = self._base_retriever.invoke(question)
        if retrieved_docs:
            reranked_docs = list(
                self._reranker.compress_documents(
                    documents=retrieved_docs, query=question
                )
            )
        else:
            # Same short-circuit the compression retriever applies: nothing to
            # rerank, so do not call the ranking service.
            reranked_docs = []

        self._retriever_sources = [
            {
                "page_content": f"{doc.page_content[:50]}...",
                "source": doc.metadata.get("source", ""),
            }
            for doc in retrieved_docs
        ]

        # Build a lookup map from page_content -> source using the original
        # retrieved docs, because VertexAIRank strips metadata (including source)
        # from the documents it returns.
        source_lookup: dict[str, str] = {
            doc.page_content: doc.metadata.get("source", "")
            for doc in retrieved_docs
        }
        self._reranked_sources = [
            {
                "relevance_score": doc.metadata.get("relevance_score", ""),
                "page_content": f"{doc.page_content[:50]}...",
                "source": source_lookup.get(doc.page_content, ""),
            }
            for doc in reranked_docs
        ]

        if PRINT_SOURCES:
            self._print_sources(retrieved_docs, reranked_docs, source_lookup)

        return "\n\n".join(doc.page_content for doc in reranked_docs)

    @staticmethod
    def _print_sources(
        retrieved_docs: list,
        reranked_docs: list,
        source_lookup: dict[str, str],
    ) -> None:
        """Print a debug listing of the retrieved and the reranked documents."""
        print("========== RETRIEVER DOCUMENTS ==========")
        for idx, doc in enumerate(retrieved_docs, start=1):
            snippet = doc.page_content[:200].replace("\n", " ")
            # .get() rather than ['source']: a document without a source must
            # not crash the debug output.
            print(
                f"[{idx}] metadata={doc.metadata.get('source', '')} | snippet=...{snippet}..."
            )
        print("========== RERANKED DOCUMENTS ==========")
        for idx, doc in enumerate(reranked_docs, start=1):
            snippet = doc.page_content[:200].replace("\n", " ")
            source = source_lookup.get(doc.page_content, "")
            print(
                f"[{idx}] source={source} | relevance_score={doc.metadata.get('relevance_score', '')} | snippet=...{snippet}..."
            )

    def getSources(self) -> list[dict]:
        """Return a copy of the source summaries from the last retrieval."""
        return list(self._retriever_sources)

    def getRankedSources(self) -> list[dict]:
        """Return a copy of the reranked source summaries from the last retrieval."""
        return list(self._reranked_sources)

    def stream(self, message: str) -> AsyncIterator[str]:
        """Start streaming the answer to *message*.

        Despite the synchronous-sounding name this delegates to the chain's
        ``astream``, so the result is an async iterator and has to be consumed
        with ``async for``.

        Args:
            message: The user's original question.

        Returns:
            An async iterator over the chunks of the generated answer.
        """
        return self._full_chain.astream(message)