Files
AKERN-Langchain/main.py
2026-02-17 12:03:31 +01:00

124 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from dotenv import load_dotenv
from langchain_classic.retrievers import ContextualCompressionRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_google_community import VertexAISearchRetriever
from langchain_google_community.vertex_rank import VertexAIRank
from langchain_google_genai import ChatGoogleGenerativeAI
# Load environment variables (e.g. Google application credentials) from a
# local .env file before any Vertex AI client is constructed.
load_dotenv()

# Google Cloud / Vertex AI Search configuration.
PROJECT = "akqa-ita-ai-poc1"  # GCP project id
DATA_STORE = "akern-ds_1771234036654"  # Vertex AI Search data store id
MODEL = "gemini-2.5-flash"  # Gemini model name
LOCATION = "eu"  # region for the Vertex AI Search data store

# LLM CONFIG
TOP_K = 40  # sample from the 40 most likely tokens
TOP_P = 1  # nucleus sampling disabled (full probability mass)
TEMPERATURE = 0.0  # deterministic generation for reproducible answers
MAX_OUTPUT_TOKENS = 65535
RETRIEVER_MAX_DOCS = 50  # documents fetched by the base retriever
RERANKER_MAX_RESULTS = 25  # documents kept after reranking

# The system/RAG prompt template lives in prompt.md next to this script;
# it is expected to contain {context} and {question} placeholders.
with open("prompt.md") as f:
    template = f.read()
prompt = ChatPromptTemplate.from_template(template)
def format_docs(question: str) -> str:
    """Retrieve, rerank, and concatenate context documents for *question*.

    Runs the base Vertex AI Search retriever (debug logging only), then the
    compression retriever (base retrieval + Vertex AI reranking), printing a
    short snippet of every hit, and returns the reranked documents joined by
    blank lines for use as LLM context.

    NOTE(review): ``compression_retriever`` re-runs ``base_retriever``
    internally, so the search backend is queried twice per question; the
    first call here exists only to log the pre-rerank result set.
    """
    retrieved_docs = base_retriever.invoke(question)
    print("========== RETRIEVER DOCUMENTS ==========")
    for idx, doc in enumerate(retrieved_docs, start=1):
        snippet = doc.page_content[:200].replace("\n", " ")
        # .get() keeps the debug log from raising KeyError on docs that
        # carry no "source" metadata.
        print(f"[{idx}] source={doc.metadata.get('source')} | snippet=...{snippet}...")
    reranked_docs = compression_retriever.invoke(question)
    print("========== RERANKED DOCUMENTS ==========")
    for idx, doc in enumerate(reranked_docs, start=1):
        snippet = doc.page_content[:200].replace("\n", " ")
        # Label fixed: this prints the reranker's relevance score, not the
        # whole metadata dict.
        print(
            f"[{idx}] relevance_score={doc.metadata.get('relevance_score')} | snippet=...{snippet}..."
        )
    return "\n\n".join(doc.page_content for doc in reranked_docs)
# Sampling/generation parameters for the Gemini model, gathered in one place.
_generation_params = {
    "top_p": TOP_P,
    "top_k": TOP_K,
    "temperature": TEMPERATURE,
    "max_output_tokens": MAX_OUTPUT_TOKENS,
}

# Gemini chat model served through Vertex AI (vertexai=True routes requests
# via the Vertex endpoint of the configured project).
llm = ChatGoogleGenerativeAI(
    model=MODEL,
    project=PROJECT,
    vertexai=True,
    **_generation_params,
)
# First-stage retriever: Vertex AI Search over the AKERN data store,
# returning up to RETRIEVER_MAX_DOCS candidate documents per query.
base_retriever = VertexAISearchRetriever(
    project_id=PROJECT,
    location_id=LOCATION,
    data_store_id=DATA_STORE,
    max_documents=RETRIEVER_MAX_DOCS,
    beta=True,
)
# Second-stage semantic reranker: scores the retrieved documents and keeps
# the RERANKER_MAX_RESULTS most relevant ones. Uses the shared LOCATION
# constant (previously a hard-coded "eu") so the ranking service stays in
# the same region as the search data store.
reranker = VertexAIRank(
    project_id=PROJECT,
    location_id=LOCATION,
    ranking_config="default_ranking_config",
    top_n=RERANKER_MAX_RESULTS,
)
# Two-stage pipeline retriever: results from the base Vertex AI Search
# retriever are post-processed (reranked and truncated) by the ranker.
compression_retriever = ContextualCompressionRetriever(
    base_retriever=base_retriever,
    base_compressor=reranker,
)
# LCEL RAG chain: the incoming question string is fanned out into the
# prompt's two inputs — "context" (built by format_docs via retrieval +
# reranking) and "question" (passed through unchanged) — then rendered by
# the prompt template, sent to the Gemini model, and parsed to a plain str.
rag_chain = (
    {"context": RunnableLambda(format_docs), "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
# def parse_questions():
# domande_dir = "domande"
# risposte_dir = "risposte"
# os.makedirs(risposte_dir, exist_ok=True)
# for filename in sorted(os.listdir(domande_dir)):
# if not filename.lower().endswith(".txt"):
# continue
# domanda_path = os.path.join(domande_dir, filename)
# with open(domanda_path, "r", encoding="utf-8") as f:
# contents = f.read()
# print(f"========== DOMANDA ({domanda_path}) ==========")
# print(contents)
# response = rag_chain.invoke(contents)
# print("========== RISPOSTA ==========")
# print(response)
# print("\n\n")
# base_name = os.path.splitext(filename)[0]
# suffix = "".join(ch for ch in base_name if ch.isdigit()) or base_name
# risposta_path = os.path.join(risposte_dir, f"risposta{suffix}.txt")
# with open(risposta_path, "w", encoding="utf-8") as f:
# f.write(response)
if __name__ == "__main__":
    # Demo run: answer a sample customer question (in Italian) about the
    # interpretation of BIA parameters (BCM / SMM / ASMM / phase angle).
    sample_question = """Buongiorno, non so se è la mail specifica ma volevo se possibile dei chiarimenti per linterpretazione dei parametri BCM /SMM/ASMM. Mi capita a volte di trovare casi in cui la BCM è aumentata ma allo stesso tempo SMM/ASMM hanno subito una piccola flessione in negativo (o viceversa). Se la parte metabolicamente attiva aumenta perchè può succedere che gli altri compartimenti si riducono?? E allo stesso tempo phA e BCM possono essere inversamente proporzionali?? So che il phA correla con massa e struttura + idratazione."""
    answer = rag_chain.invoke(sample_question)
    print(answer)