Files
AKERN-Langchain/chain.py
2026-02-18 16:09:31 +01:00

161 lines
5.4 KiB
Python

from dotenv import load_dotenv
from langchain_classic.retrievers import ContextualCompressionRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_google_community import VertexAISearchRetriever
from langchain_google_community.vertex_rank import VertexAIRank
from langchain_google_genai import ChatGoogleGenerativeAI
# Read a local .env file (if any) into the process environment before the
# Google clients are built.
# NOTE(review): presumably this is how credentials reach the process — confirm.
load_dotenv()

# --- Google Cloud resources -------------------------------------------------
# Project id shared by the LLM, the search retriever and the reranker.
PROJECT: str = "akqa-ita-ai-poc1"
# Location passed to both the search retriever and the reranker.
LOCATION: str = "eu"
# Vertex AI Search data store queried by the base retriever.
DATA_STORE: str = "akern-ds_1771234036654"

# --- Generation settings ----------------------------------------------------
# Chat model name handed to ChatGoogleGenerativeAI.
MODEL: str = "gemini-2.5-flash"
# Cap on generated tokens, forwarded as `max_output_tokens`.
MAX_OUTPUT_TOKENS: int = 65535

# --- Debugging --------------------------------------------------------------
# When True, RagChain._format_docs prints the retrieved and reranked documents.
PRINT_SOURCES: bool = False
class RagChain:
    """Question-rewrite + retrieve + rerank + answer pipeline.

    A run goes through two LLM calls:

    1. The incoming message is rewritten with the prompt stored in
       ``question_rewrite_prompt.md``.
    2. The rewritten question is used to fetch documents from Vertex AI
       Search, the hits are reranked, and the reranked text is injected as
       ``context`` (next to ``question``) into the caller-supplied answer
       prompt.

    Diagnostics about the most recent run (rewritten question, retrieved and
    reranked sources) are stored on the instance and overwritten by every
    run, so concurrent runs on the same instance will overwrite each other's
    diagnostics.
    """

    def __init__(
        self,
        question_template: str,
        top_k: int,
        top_p: float,
        temperature: float,
        retriever_max_docs: int,
        reranker_max_results: int,
    ) -> None:
        """Build the LLM, retriever, reranker and the composed chain.

        Args:
            question_template: Prompt template for the answering step. It is
                rendered with the keys ``context`` and ``question``.
            top_k: Top-k sampling parameter forwarded to the chat model.
            top_p: Top-p sampling parameter forwarded to the chat model.
            temperature: Sampling temperature forwarded to the chat model.
            retriever_max_docs: Maximum number of documents requested from
                the search retriever.
            reranker_max_results: Number of documents the reranker keeps
                (its ``top_n``).

        Raises:
            OSError: If ``question_rewrite_prompt.md`` cannot be opened. The
                path is relative, so it is resolved against the current
                working directory.
        """
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        self.retriever_max_docs = retriever_max_docs
        self.reranker_max_results = reranker_max_results

        # The answer prompt comes from the caller; only the rewrite prompt is
        # read from disk. Decode explicitly as UTF-8 so the prompt text does
        # not depend on the platform's default encoding.
        with open("question_rewrite_prompt.md", encoding="utf-8") as f:
            question_rewrite_template = f.read()

        question_prompt = ChatPromptTemplate.from_template(question_template)
        question_rewrite_prompt = ChatPromptTemplate.from_template(
            question_rewrite_template
        )

        # Diagnostics of the most recent run, exposed through the getters.
        self._retriever_sources: list[dict] = []
        self._reranked_sources: list[dict] = []
        self._rephrased_question: str = ""

        self._llm = ChatGoogleGenerativeAI(
            model=MODEL,
            project=PROJECT,
            vertexai=True,
            top_p=self.top_p,
            top_k=self.top_k,
            temperature=self.temperature,
            max_output_tokens=MAX_OUTPUT_TOKENS,
        )
        self._base_retriever = VertexAISearchRetriever(
            project_id=PROJECT,
            data_store_id=DATA_STORE,
            max_documents=self.retriever_max_docs,
            location_id=LOCATION,
            beta=True,
        )
        self._reranker = VertexAIRank(
            project_id=PROJECT,
            location_id=LOCATION,
            ranking_config="default_ranking_config",
            top_n=self.reranker_max_results,
        )
        # Not used by _format_docs, which reranks the already-fetched
        # documents directly so the search is not executed a second time.
        # The attribute is still built so it stays available on the instance.
        self._compression_retriever = ContextualCompressionRetriever(
            base_compressor=self._reranker, base_retriever=self._base_retriever
        )

        # Step 1: raw message -> rewritten question (also recorded on self).
        question_rewrite_chain = (
            {"question": RunnablePassthrough()}
            | question_rewrite_prompt
            | self._llm
            | StrOutputParser()
            | RunnableLambda(self._log_rephrased_question)
        )
        # Step 2: the rewritten question is used both as the retrieval query
        # (producing "context") and as the "question" shown to the model.
        rag_chain = (
            {
                "context": RunnableLambda(self._format_docs),
                "question": RunnablePassthrough(),
            }
            | question_prompt
            | self._llm
            | StrOutputParser()
        )
        self._full_chain = question_rewrite_chain | rag_chain

    def _log_rephrased_question(self, rephrased_question: str) -> str:
        """Record the rewritten question and pass it through unchanged."""
        self._rephrased_question = rephrased_question
        return rephrased_question

    def _format_docs(self, question: str) -> str:
        """Retrieve and rerank documents, returning their text as context.

        The search retriever is queried exactly once; the reranker is then
        applied to those same documents. Summaries of both stages are stored
        on the instance for the getters.

        Args:
            question: Query sent to the retriever and the reranker.

        Returns:
            The ``page_content`` of the reranked documents joined by blank
            lines; an empty string when nothing was retrieved.
        """
        retrieved_docs = self._base_retriever.invoke(question)
        # Skip the ranking call entirely when there is nothing to rank.
        reranked_docs = (
            list(self._reranker.compress_documents(retrieved_docs, question))
            if retrieved_docs
            else []
        )

        # Build a lookup map from page_content -> source using the original
        # retrieved docs, because VertexAIRank strips metadata (including
        # source) from the documents it returns.
        source_lookup: dict[str, str] = {
            doc.page_content: doc.metadata.get("source", "") for doc in retrieved_docs
        }

        self._retriever_sources = [
            {
                "page_content": f"{doc.page_content[:50]}...",
                "source": doc.metadata.get("source", ""),
            }
            for doc in retrieved_docs
        ]
        self._reranked_sources = [
            {
                "relevance_score": doc.metadata.get("relevance_score", ""),
                "page_content": f"{doc.page_content[:50]}...",
                "source": source_lookup.get(doc.page_content, ""),
            }
            for doc in reranked_docs
        ]

        if PRINT_SOURCES:
            self._print_sources(retrieved_docs, reranked_docs, source_lookup)

        return "\n\n".join(doc.page_content for doc in reranked_docs)

    def _print_sources(
        self,
        retrieved_docs: list,
        reranked_docs: list,
        source_lookup: dict[str, str],
    ) -> None:
        """Print a debug listing of the retrieved and the reranked documents."""
        print("========== RETRIEVER DOCUMENTS ==========")
        for idx, doc in enumerate(retrieved_docs, start=1):
            snippet = doc.page_content[:200].replace("\n", " ")
            # .get() keeps debug output from crashing on a document that has
            # no "source" metadata, matching every other access in the class.
            print(
                f"[{idx}] metadata={doc.metadata.get('source', '')} | snippet=...{snippet}..."
            )
        print("========== RERANKED DOCUMENTS ==========")
        for idx, doc in enumerate(reranked_docs, start=1):
            snippet = doc.page_content[:200].replace("\n", " ")
            source = source_lookup.get(doc.page_content, "")
            print(
                f"[{idx}] source={source} | relevance_score={doc.metadata.get('relevance_score', '')} | snippet=...{snippet}..."
            )

    def getSources(self) -> list[dict]:
        """Return a copy of the retrieved-document summaries of the last run."""
        return list(self._retriever_sources)

    def getRankedSources(self) -> list[dict]:
        """Return a copy of the reranked-document summaries of the last run."""
        return list(self._reranked_sources)

    def getRephrasedQuestion(self) -> str:
        """Return the rewritten question produced by the last run."""
        return self._rephrased_question

    def stream(self, message: str):
        """Stream the answer for ``message``.

        This delegates to the chain's ``astream``, so the returned object is
        asynchronous and must be consumed with ``async for`` even though this
        method itself is a plain (non-``async``) function.
        """
        return self._full_chain.astream(message)