146 lines
4.7 KiB
Python
146 lines
4.7 KiB
Python
from dotenv import load_dotenv
|
|
from langchain_classic.retrievers import ContextualCompressionRetriever
|
|
from langchain_core.output_parsers import StrOutputParser
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
|
|
from langchain_google_community import VertexAISearchRetriever
|
|
from langchain_google_community.vertex_rank import VertexAIRank
|
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
|
|
# Load variables from a local .env file into the process environment.
load_dotenv()

# Google Cloud project, Vertex AI Search data store and location used by the
# retrieval components.
PROJECT: str = "akqa-ita-ai-poc1"
DATA_STORE: str = "akern-ds_1771234036654"
LOCATION: str = "eu"

# Chat model name and the output-token limit passed to it.
MODEL: str = "gemini-2.5-flash"
MAX_OUTPUT_TOKENS: int = 65535

# When True, retrieved and reranked documents are printed to stdout.
PRINT_SOURCES: bool = False
|
|
|
|
|
|
class RagChain:
    """Retrieval-augmented generation pipeline with a question-rewrite step.

    The pipeline runs in two stages:

    1. The incoming message is inserted as ``question`` into the template read
       from ``question_rewrite_prompt.md`` and sent to the LLM.
    2. The rewritten question is used to retrieve documents, rerank them, and
       fill the ``context`` / ``question`` variables of the template read from
       ``prompt.md``, which is sent to the LLM to produce the answer.

    Summaries of the documents seen during the most recent retrieval are kept
    on the instance and exposed via :meth:`getSources` and
    :meth:`getRankedSources`. They are overwritten on every retrieval, so
    overlapping calls on the same instance will overwrite each other's
    summaries.
    """

    def __init__(
        self,
        top_k: int,
        top_p: float,
        temperature: float,
        retriever_max_docs: int,
        reranker_max_results: int,
    ) -> None:
        """Build the LLM, retriever, reranker and the composed chain.

        Args:
            top_k: ``top_k`` sampling parameter passed to the chat model.
            top_p: ``top_p`` sampling parameter passed to the chat model.
            temperature: Sampling temperature passed to the chat model.
            retriever_max_docs: ``max_documents`` passed to the search
                retriever.
            reranker_max_results: ``top_n`` passed to the reranker.

        Raises:
            OSError: If ``prompt.md`` or ``question_rewrite_prompt.md`` cannot
                be opened (paths are relative to the current working
                directory).
        """
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        self.retriever_max_docs = retriever_max_docs
        self.reranker_max_results = reranker_max_results

        # Decode the templates as UTF-8 explicitly so the result does not
        # depend on the platform's default locale encoding.
        with open("prompt.md", encoding="utf-8") as f:
            question_template = f.read()

        with open("question_rewrite_prompt.md", encoding="utf-8") as f:
            question_rewrite_template = f.read()

        question_prompt = ChatPromptTemplate.from_template(question_template)
        question_rewrite_prompt = ChatPromptTemplate.from_template(
            question_rewrite_template
        )

        # Summaries of the last retrieval; filled in by _format_docs.
        self._retriever_sources: list[dict] = []
        self._reranked_sources: list[dict] = []

        self._llm = ChatGoogleGenerativeAI(
            model=MODEL,
            project=PROJECT,
            vertexai=True,
            top_p=self.top_p,
            top_k=self.top_k,
            temperature=self.temperature,
            max_output_tokens=MAX_OUTPUT_TOKENS,
        )

        self._base_retriever = VertexAISearchRetriever(
            project_id=PROJECT,
            data_store_id=DATA_STORE,
            max_documents=self.retriever_max_docs,
            location_id=LOCATION,
            beta=True,
        )

        self._reranker = VertexAIRank(
            project_id=PROJECT,
            location_id=LOCATION,
            ranking_config="default_ranking_config",
            top_n=self.reranker_max_results,
        )

        # Kept as an attribute for backward compatibility; _format_docs talks
        # to the retriever and the reranker directly so that the search is
        # only executed once per question.
        self._compression_retriever = ContextualCompressionRetriever(
            base_compressor=self._reranker, base_retriever=self._base_retriever
        )

        question_rewrite_chain = (
            {"question": RunnablePassthrough()}
            | question_rewrite_prompt
            | self._llm
            | StrOutputParser()
            | RunnableLambda(self._log_rewritten_question)
        )

        # The rewritten question feeds both branches: it is the retrieval
        # query for "context" and is passed through unchanged as "question".
        rag_chain = (
            {
                "context": RunnableLambda(self._format_docs),
                "question": RunnablePassthrough(),
            }
            | question_prompt
            | self._llm
            | StrOutputParser()
        )

        self._full_chain = question_rewrite_chain | rag_chain

    def _log_rewritten_question(self, rewritten_question: str) -> str:
        """Pass the rewritten question through unchanged.

        This is the hook between the rewrite step and the answer step; it
        currently performs no logging.
        """
        return rewritten_question

    def _format_docs(self, question: str) -> str:
        """Retrieve and rerank documents for *question* and build the context.

        The base retriever is queried once and its results are handed to the
        reranker. Summaries of both result sets are stored on the instance for
        :meth:`getSources` and :meth:`getRankedSources`.

        Args:
            question: The query used for retrieval and reranking.

        Returns:
            The ``page_content`` of the reranked documents joined by blank
            lines, or an empty string when nothing was retrieved.
        """
        retrieved_docs = self._base_retriever.invoke(question)

        # Rerank the documents that were just fetched instead of going through
        # the compression retriever, which would query the base retriever a
        # second time. Like the compression retriever, skip the reranker when
        # there is nothing to rank.
        if retrieved_docs:
            reranked_docs = list(
                self._reranker.compress_documents(
                    documents=retrieved_docs, query=question
                )
            )
        else:
            reranked_docs = []

        self._retriever_sources = [
            {
                "page_content": f"{doc.page_content[:50]}...",
                "source": doc.metadata.get("source", ""),
            }
            for doc in retrieved_docs
        ]

        self._reranked_sources = [
            {
                "relevance_score": doc.metadata.get("relevance_score", ""),
                "page_content": f"{doc.page_content[:50]}...",
            }
            for doc in reranked_docs
        ]

        if PRINT_SOURCES:
            # Use .get() here as well so that a document lacking a metadata
            # key cannot make the debug output raise KeyError.
            print("========== RETRIEVER DOCUMENTS ==========")
            for idx, doc in enumerate(retrieved_docs, start=1):
                snippet = doc.page_content[:200].replace("\n", " ")
                print(
                    f"[{idx}] metadata={doc.metadata.get('source', '')} | snippet=...{snippet}..."
                )

            print("========== RERANKED DOCUMENTS ==========")
            for idx, doc in enumerate(reranked_docs, start=1):
                snippet = doc.page_content[:200].replace("\n", " ")
                print(
                    f"[{idx}] metadata={doc.metadata.get('relevance_score', '')} | snippet=...{snippet}..."
                )

        return "\n\n".join(doc.page_content for doc in reranked_docs)

    def getSources(self) -> list[dict]:
        """Return a copy of the retrieved-document summaries of the last run.

        Each entry has the keys ``page_content`` (first 50 characters followed
        by ``...``) and ``source``.
        """
        return list(self._retriever_sources)

    def getRankedSources(self) -> list[dict]:
        """Return a copy of the reranked-document summaries of the last run.

        Each entry has the keys ``relevance_score`` and ``page_content`` (first
        50 characters followed by ``...``).
        """
        return list(self._reranked_sources)

    def stream(self, message: str):
        """Stream the answer for *message* through the full chain.

        Note that this returns the result of the chain's ``astream``; it is
        an asynchronous iterator and must be consumed with ``async for``.
        """
        return self._full_chain.astream(message)
|