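"""RAG pipeline over a Vertex AI Search data store: retrieve, rerank with
Vertex AI Rank, and answer with Gemini through an async LCEL chain."""
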
import asyncio

from dotenv import load_dotenv
from langchain_classic.retrievers import ContextualCompressionRetriever
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_google_community import VertexAISearchRetriever
from langchain_google_community.vertex_rank import VertexAIRank
from langchain_google_genai import ChatGoogleGenerativeAI

load_dotenv()
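
# Vertex AI project, data store, and model configuration.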
PROJECT = "akqa-ita-ai-poc1"
DATA_STORE = "akern-ds_1771234036654"
MODEL = "gemini-2.5-flash"
LOCATION = "eu"
PRINT_SOURCES = False

# LLM CONFIG
TOP_K = 40
TOP_P = 1
TEMPERATURE = 0.0
MAX_OUTPUT_TOKENS = 65535
RETRIEVER_MAX_DOCS = 50
RERANKER_MAX_RESULTS = 25
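
# prompt.md holds the RAG prompt template; it must contain {context} and
# {question} placeholders, matching the keys assembled in rag_chain below.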
with open("prompt.md", encoding="utf-8") as f:
    template = f.read()

prompt = ChatPromptTemplate.from_template(template)
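

# Context builder: retrieves, reranks, and joins documents for one question.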
def format_docs(question: str) -> str:
    """Retrieve, rerank, and join document contents into one context string."""
    # ContextualCompressionRetriever runs base_retriever first and then
    # passes its hits through the reranker.
    reranked_docs = compression_retriever.invoke(question)

    if PRINT_SOURCES:
        # Debug-only second retrieval, to show the pre-rerank ordering.
        retrieved_docs = base_retriever.invoke(question)

        print("========== RETRIEVER DOCUMENTS ==========")
        for idx, doc in enumerate(retrieved_docs, start=1):
            snippet = doc.page_content[:200].replace("\n", " ")
            print(
                f"[{idx}] source={doc.metadata['source']} | snippet=...{snippet}..."
            )

        print("========== RERANKED DOCUMENTS ==========")
        for idx, doc in enumerate(reranked_docs, start=1):
            snippet = doc.page_content[:200].replace("\n", " ")
            print(
                f"[{idx}] relevance_score={doc.metadata['relevance_score']} | snippet=...{snippet}..."
            )

    return "\n\n".join(doc.page_content for doc in reranked_docs)
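

# Gemini chat model served through Vertex AI.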
llm = ChatGoogleGenerativeAI(
    model=MODEL,
    project=PROJECT,
    vertexai=True,
    top_p=TOP_P,
    top_k=TOP_K,
    temperature=TEMPERATURE,
    max_output_tokens=MAX_OUTPUT_TOKENS,
)
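
# First-stage retriever over the Vertex AI Search data store.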
base_retriever = VertexAISearchRetriever(
    project_id=PROJECT,
    data_store_id=DATA_STORE,
    max_documents=RETRIEVER_MAX_DOCS,
    location_id=LOCATION,
    beta=True,
)
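
# Semantic reranker (Vertex AI Ranking API) applied to the retriever hits.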
reranker = VertexAIRank(
    project_id=PROJECT,
    location_id=LOCATION,
    ranking_config="default_ranking_config",
    top_n=RERANKER_MAX_RESULTS,
)
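
# Wire the two stages together: documents flow base_retriever -> reranker.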
compression_retriever = ContextualCompressionRetriever(
    base_compressor=reranker, base_retriever=base_retriever
)
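
# LCEL pipeline: the incoming question is fanned out to format_docs (context)
# and passed through unchanged (question), templated, sent to Gemini, and
# parsed to a plain string.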
rag_chain = (
    {"context": RunnableLambda(format_docs), "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
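

# ASYNC ENTRY POINTS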
async def async_invoke(rag_chain, question: str):
    # Use the chain's async entry point so the calls gathered in main() run
    # concurrently; the synchronous invoke() would block the event loop.
    return await rag_chain.ainvoke(question)


async def main():
    # Answer two questions concurrently against the same chain.
    (
        res1,
        res2,
    ) = await asyncio.gather(
        # "how is the phase angle calculated?"
        async_invoke(rag_chain, "come si calcola l'angolo di fase?"),
        # "what is lean body mass?"
        async_invoke(rag_chain, "cos'è la massa magra?"),
    )

    print("RES1")
    print(res1)
    print("\n\nRES2")
    print(res2)


if __name__ == "__main__":
    asyncio.run(main())