100 lines
2.6 KiB
Python
100 lines
2.6 KiB
Python
import os
|
|
|
|
from dotenv import load_dotenv
|
|
from langchain_classic.retrievers import ContextualCompressionRetriever
|
|
from langchain_core.output_parsers import StrOutputParser
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain_core.runnables import RunnablePassthrough
|
|
from langchain_google_community import VertexAISearchRetriever
|
|
from langchain_google_community.vertex_rank import VertexAIRank
|
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
|
|
# Populate os.environ from a local .env file (python-dotenv), if one exists,
# before any client below is created.
load_dotenv()

# Google Cloud project used by the LLM, the search retriever and the reranker.
PROJECT = "akqa-ita-ai-poc1"
# Vertex AI Search data store queried by the retriever.
DATA_STORE = "akern-ds_1771234036654"
# Gemini model used to generate the answers.
MODEL = "gemini-2.5-flash"
# Location of the Vertex AI Search data store.
LOCATION = "eu"

# The prompt template is read from "prompt.md" relative to the current working
# directory. It is expected to use the {context} and {question} variables that
# rag_chain supplies. Decode explicitly as UTF-8: the platform default encoding
# (e.g. cp1252 on Windows) would garble or reject non-ASCII text.
with open("prompt.md", encoding="utf-8") as prompt_file:
    template = prompt_file.read()

prompt = ChatPromptTemplate.from_template(template)
|
|
|
|
|
|
def format_docs(docs):
    """Flatten retrieved documents into a single context string.

    The ``page_content`` of each document is kept in iteration order, and
    consecutive passages are separated by one blank line (``"\\n\\n"``).
    An empty input yields an empty string.
    """
    separator = "\n\n"
    passages = [document.page_content for document in docs]
    return separator.join(passages)
|
|
|
|
|
|
# Gemini chat model reached through Vertex AI (vertexai=True) in PROJECT.
# Sampling: temperature 0.0 to minimise randomness, with top_p / top_k as
# additional bounds. The output budget is capped at 65535 tokens.
llm = ChatGoogleGenerativeAI(
    # Which model, and where it is served.
    model=MODEL,
    project=PROJECT,
    vertexai=True,
    # Generation settings.
    temperature=0.0,
    top_p=0.95,
    top_k=40,
    max_output_tokens=65535,
)
|
|
|
|
# First-stage retriever: queries the Vertex AI Search data store DATA_STORE
# (in LOCATION) and returns up to 50 candidate documents. The candidates are
# then reranked by compression_retriever. beta=True selects the beta variant
# of the search API.
base_retriever = VertexAISearchRetriever(
    project_id=PROJECT,
    location_id=LOCATION,
    data_store_id=DATA_STORE,
    max_documents=50,
    beta=True,
)
|
|
|
|
# Second stage: Vertex AI ranking service. It reorders the retrieved candidates
# and keeps only the 5 best. The ranking config is addressed in the "global"
# location, independently of the data store's LOCATION.
reranker = VertexAIRank(
    project_id=PROJECT,
    location_id="global",
    ranking_config="default_ranking_config",
    top_n=5,
)
|
|
|
|
# Two-stage retriever: documents fetched by base_retriever are passed through
# the reranker (used as the "compressor") before being returned.
compression_retriever = ContextualCompressionRetriever(
    base_retriever=base_retriever,
    base_compressor=reranker,
)
|
|
|
|
# LCEL pipeline. The leading dict is coerced into a parallel step that
# receives the raw question string:
#   - "context":  retrieve + rerank, then flatten the documents to text;
#   - "question": the question itself, passed through unchanged.
# The resulting mapping fills the prompt template. The prompt goes to the LLM,
# and the reply is parsed to a plain string.
rag_chain = (
    {
        "context": compression_retriever | format_docs,
        "question": RunnablePassthrough(),
    }
    | prompt
    | llm
    | StrOutputParser()
)
|
|
|
|
|
|
def answer_questions(questions_dir: str = "domande", output_dir=None, chain=None) -> None:
    """Answer every ``.txt`` question file in a directory with the RAG chain.

    Question files are processed in sorted file-name order. Each answer is
    written as UTF-8 text to a file named after the question, with "domanda"
    replaced by "risposta" (e.g. ``domanda1.txt`` -> ``risposta1.txt``).

    Args:
        questions_dir: Directory scanned for ``*.txt`` question files.
            Defaults to ``"domande"`` relative to the working directory.
        output_dir: Directory where answer files are written. ``None`` (the
            default) writes them to the current working directory.
        chain: Object exposing ``invoke(question) -> str``. ``None`` (the
            default) uses the module-level ``rag_chain``.

    A failure on one file (undecodable text, retrieval/model error, write
    error) is printed, and the remaining files are still processed. Question
    files that are empty or whitespace-only are skipped.
    """
    if chain is None:
        chain = rag_chain

    # isdir rather than exists: a regular file with this name would make
    # os.listdir raise NotADirectoryError.
    if not os.path.isdir(questions_dir):
        print(f"Errore: la directory '{questions_dir}' non esiste.")
        return

    # Only regular files: a sub-directory named "*.txt" cannot be opened.
    filenames = sorted(
        name
        for name in os.listdir(questions_dir)
        if name.endswith(".txt")
        and os.path.isfile(os.path.join(questions_dir, name))
    )

    for filename in filenames:
        filepath = os.path.join(questions_dir, filename)
        print(f"Elaborazione: {filename}...")

        # Reading happens inside the try too, so one unreadable or
        # non-UTF-8 file does not abort the whole batch.
        try:
            with open(filepath, "r", encoding="utf-8") as question_file:
                question_content = question_file.read()

            if not question_content.strip():
                # Nothing to ask: do not spend a retrieval + model call.
                print(f"Domanda vuota, file ignorato: {filename}")
                continue

            response = chain.invoke(question_content)

            # Answer file name, e.g. domanda1.txt -> risposta1.txt.
            output_filename = filename.replace("domanda", "risposta")
            if output_dir is not None:
                output_filename = os.path.join(output_dir, output_filename)
            # Without "domanda" in the name and with output_dir equal to
            # questions_dir, the answer would overwrite its own question.
            if os.path.abspath(output_filename) == os.path.abspath(filepath):
                output_filename = os.path.join(
                    os.path.dirname(output_filename), "risposta_" + filename
                )

            with open(output_filename, "w", encoding="utf-8") as answer_file:
                answer_file.write(response)

            print(f"Risposta salvata in: {output_filename}")
        except Exception as e:
            # Batch boundary: report the failure and move on to the next file.
            print(f"Errore durante l'elaborazione di {filename}: {e}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: run the chain on one fixed sample question and print
    # the answer. The batch helper answer_questions() is not invoked here.
    sample_question = "come si calcola il rapporto sodio potassio?"
    answer = rag_chain.invoke(sample_question)
    print(answer)
|