"""Retrieval-augmented question answering over a Vertex AI Search data store.

For each question, documents are retrieved from Vertex AI Search, reranked
with the Vertex AI ranking API, joined into a context string, and passed with
the question to a Gemini model using the prompt template in ``prompt.md``.
"""

import asyncio

from dotenv import load_dotenv
from langchain_classic.retrievers import ContextualCompressionRetriever
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_google_community import VertexAISearchRetriever
from langchain_google_community.vertex_rank import VertexAIRank
from langchain_google_genai import ChatGoogleGenerativeAI

load_dotenv()

PROJECT = "akqa-ita-ai-poc1"
DATA_STORE = "akern-ds_1771234036654"
MODEL = "gemini-2.5-flash"
LOCATION = "eu"
# When True, format_docs prints the retrieved and reranked documents.
PRINT_SOURCES = False

# LLM CONFIG
TOP_K = 40
TOP_P = 1
TEMPERATURE = 0.0
MAX_OUTPUT_TOKENS = 65535

# Retrieval / reranking sizes.
RETRIEVER_MAX_DOCS = 50
RERANKER_MAX_RESULTS = 25

# Characters of each document's content shown in the PRINT_SOURCES output.
SNIPPET_CHARS = 200

# Read explicitly as UTF-8 so non-ASCII prompt text does not depend on the
# platform's default locale encoding.
with open("prompt.md", encoding="utf-8") as f:
    template = f.read()

prompt = ChatPromptTemplate.from_template(template)

llm = ChatGoogleGenerativeAI(
    model=MODEL,
    project=PROJECT,
    vertexai=True,
    top_p=TOP_P,
    top_k=TOP_K,
    temperature=TEMPERATURE,
    max_output_tokens=MAX_OUTPUT_TOKENS,
)

base_retriever = VertexAISearchRetriever(
    project_id=PROJECT,
    data_store_id=DATA_STORE,
    max_documents=RETRIEVER_MAX_DOCS,
    location_id=LOCATION,
    beta=True,
)

reranker = VertexAIRank(
    project_id=PROJECT,
    location_id=LOCATION,
    ranking_config="default_ranking_config",
    top_n=RERANKER_MAX_RESULTS,
)

# Retrieval + reranking combined into a single retriever. format_docs performs
# the same two steps itself so it can reuse the un-reranked results for
# PRINT_SOURCES without querying the data store twice.
compression_retriever = ContextualCompressionRetriever(
    base_compressor=reranker, base_retriever=base_retriever
)


def _print_docs(title: str, docs: list[Document], metadata_key: str) -> None:
    """Print a header, then one numbered line per doc: a metadata value and a snippet."""
    print(f"========== {title} ==========")
    for idx, doc in enumerate(docs, start=1):
        snippet = doc.page_content[:SNIPPET_CHARS].replace("\n", " ")
        # .get: a missing key prints None instead of aborting the request.
        print(
            f"[{idx}] metadata={doc.metadata.get(metadata_key)} | snippet=...{snippet}..."
        )


def format_docs(question: str) -> str:
    """Retrieve and rerank documents for *question* and join them into one string.

    The data store is queried exactly once. The results are then reranked
    with ``reranker``, keeping at most ``RERANKER_MAX_RESULTS`` documents.

    Args:
        question: The user question, used both as search query and rerank query.

    Returns:
        The page contents of the reranked documents, separated by blank lines,
        or an empty string when nothing was retrieved.
    """
    retrieved_docs = base_retriever.invoke(question)
    # Skip the ranking API call when there is nothing to rank.
    reranked_docs = (
        list(reranker.compress_documents(retrieved_docs, question))
        if retrieved_docs
        else []
    )

    if PRINT_SOURCES:
        # NOTE: with concurrent questions these printouts may interleave.
        _print_docs("RETRIEVER DOCUMENTS", retrieved_docs, "source")
        _print_docs("RERANKED DOCUMENTS", reranked_docs, "relevance_score")

    return "\n\n".join(doc.page_content for doc in reranked_docs)


rag_chain = (
    {"context": RunnableLambda(format_docs), "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)


async def async_invoke(rag_chain, prompt: str) -> str:
    """Run *rag_chain* on the question *prompt* without blocking the event loop.

    ``ainvoke`` is used instead of the blocking ``invoke`` so that several
    calls scheduled together with ``asyncio.gather`` actually overlap.
    """
    return await rag_chain.ainvoke(prompt)


async def main() -> None:
    """Answer two sample questions concurrently and print both answers."""
    res1, res2 = await asyncio.gather(
        async_invoke(rag_chain, "come si calcola l'angolo di fase?"),
        async_invoke(rag_chain, "cos'e' la massa magra?"),
    )

    print("RES1")
    print(res1)
    print("\n\nRES2")
    print(res2)


if __name__ == "__main__":
    asyncio.run(main())