From e1afb6e6c77973d15b929a2bfc63dd59c97cf584 Mon Sep 17 00:00:00 2001 From: Matteo Rosati Date: Wed, 18 Feb 2026 14:21:31 +0100 Subject: [PATCH] message streaming --- app.py | 41 +++++++-- chain.py | 194 ++++++++++++++++++++++++------------------- static/css/chat.css | 59 +++++++++++++ static/js/chat.js | 87 ++++++++++++++++++- templates/index.html | 19 +++-- 5 files changed, 299 insertions(+), 101 deletions(-) diff --git a/app.py b/app.py index 941afff..e4c988a 100644 --- a/app.py +++ b/app.py @@ -1,9 +1,12 @@ +import json + from fastapi import FastAPI, WebSocket from fastapi.templating import Jinja2Templates from fastapi import Request from fastapi.staticfiles import StaticFiles -from chain import full_chain +from chain import RagChain +from pprint import pprint app = FastAPI() @@ -22,10 +25,36 @@ async def websocket_endpoint(websocket: WebSocket): try: while True: - message = await websocket.receive_text() - await websocket.send_json({"type": "start"}) - async for chunk in full_chain.astream(message): - await websocket.send_json({"type": "chunk", "data": chunk}) - await websocket.send_json({"type": "end"}) + raw_message = await websocket.receive_text() + try: + await websocket.send_json({"type": "start"}) + try: + payload = json.loads(raw_message) + message = payload.get("message", "") + config = payload.get("config", {}) + except json.JSONDecodeError: + message = raw_message + config = {} + + rag_chain = RagChain( + top_k=int(config.get("top_k", 40)), + top_p=float(config.get("top_p", 0.0)), + temperature=float(config.get("temperature", 0.0)), + retriever_max_docs=int(config.get("retriever_max_docs", 40)), + reranker_max_results=int(config.get("reranker_max_results", 20)), + ) + + async for chunk in rag_chain.stream(message): + await websocket.send_json({"type": "chunk", "data": chunk}) + + await websocket.send_json( + { + "type": "report", + "sources": rag_chain.getSources(), + "reranked_sources": rag_chain.getRankedSources(), + } + ) + finally: + await websocket.send_json({"type": "end"}) except Exception: pass diff --git a/chain.py b/chain.py index 073542c..f175878 100644 --- a/chain.py +++ b/chain.py @@ -1,5 +1,3 @@ -import asyncio - from dotenv import load_dotenv from langchain_classic.retrievers import ContextualCompressionRetriever from langchain_core.output_parsers import StrOutputParser @@ -16,106 +14,132 @@ DATA_STORE = "akern-ds_1771234036654" MODEL = "gemini-2.5-flash" LOCATION = "eu" PRINT_SOURCES = False - -# LLM CONFIG -TOP_K = 40 -TOP_P = 1 -TEMPERATURE = 0.0 MAX_OUTPUT_TOKENS = 65535 -RETRIEVER_MAX_DOCS = 50 -RERANKER_MAX_RESULTS = 25 - -with open("prompt.md") as f: - question_template = f.read() - -with open("question_rewrite_prompt.md") as f: - question_rewrite_template = f.read() - -question_prompt = ChatPromptTemplate.from_template(question_template) -question_rewrite_prompt = ChatPromptTemplate.from_template(question_rewrite_template) -def format_docs(question: str) -> str: - retrieved_docs = base_retriever.invoke(question) - reranked_docs = compression_retriever.invoke(question) +class RagChain: + def __init__( + self, + top_k: int, + top_p: float, + temperature: float, + retriever_max_docs: int, + reranker_max_results: int, + ) -> None: + self.top_k = top_k + self.top_p = top_p + self.temperature = temperature + self.retriever_max_docs = retriever_max_docs + self.reranker_max_results = reranker_max_results - if PRINT_SOURCES: - print("========== RETRIEVER DOCUMENTS ==========") - for idx, doc in enumerate(retrieved_docs, start=1): - snippet = doc.page_content[:200].replace("\n", " ") - print( - f"[{idx}] metadata={doc.metadata['source']} | snippet=...{snippet}..." - ) + with open("prompt.md") as f: + question_template = f.read() - print("========== RERANKED DOCUMENTS ==========") - for idx, doc in enumerate(reranked_docs, start=1): - snippet = doc.page_content[:200].replace("\n", " ") - print( - f"[{idx}] metadata={doc.metadata['relevance_score']} | snippet=...{snippet}..." - ) + with open("question_rewrite_prompt.md") as f: + question_rewrite_template = f.read() - return "\n\n".join(doc.page_content for doc in reranked_docs) + question_prompt = ChatPromptTemplate.from_template(question_template) + question_rewrite_prompt = ChatPromptTemplate.from_template( + question_rewrite_template + ) + self._retriever_sources: list[dict] = [] + self._reranked_sources: list[dict] = [] -def log_rewritten_question(rewritten_question: str) -> str: - print("=== REWRITTEN QUESTION ===") - print(rewritten_question) - return rewritten_question + self._llm = ChatGoogleGenerativeAI( + model=MODEL, + project=PROJECT, + vertexai=True, + top_p=self.top_p, + top_k=self.top_k, + temperature=self.temperature, + max_output_tokens=MAX_OUTPUT_TOKENS, + ) + self._base_retriever = VertexAISearchRetriever( + project_id=PROJECT, + data_store_id=DATA_STORE, + max_documents=self.retriever_max_docs, + location_id=LOCATION, + beta=True, + ) -llm = ChatGoogleGenerativeAI( - model=MODEL, - project=PROJECT, - vertexai=True, - top_p=TOP_P, - top_k=TOP_K, - temperature=TEMPERATURE, - max_output_tokens=MAX_OUTPUT_TOKENS, -) + self._reranker = VertexAIRank( + project_id=PROJECT, + location_id=LOCATION, + ranking_config="default_ranking_config", + top_n=self.reranker_max_results, + ) -base_retriever = VertexAISearchRetriever( - project_id=PROJECT, - data_store_id=DATA_STORE, - max_documents=RETRIEVER_MAX_DOCS, - location_id=LOCATION, - beta=True, -) + self._compression_retriever = ContextualCompressionRetriever( + base_compressor=self._reranker, base_retriever=self._base_retriever + ) -reranker = VertexAIRank( - project_id=PROJECT, - location_id="eu", - ranking_config="default_ranking_config", - top_n=RERANKER_MAX_RESULTS, -) + question_rewrite_chain = ( + {"question": RunnablePassthrough()} + | question_rewrite_prompt + | self._llm + | StrOutputParser() + | RunnableLambda(self._log_rewritten_question) + ) -compression_retriever = ContextualCompressionRetriever( - base_compressor=reranker, base_retriever=base_retriever -) + rag_chain = ( + { + "context": RunnableLambda(self._format_docs), + "question": RunnablePassthrough(), + } + | question_prompt + | self._llm + | StrOutputParser() + ) -question_rewrite_chain = ( - {"question": RunnablePassthrough()} - | question_rewrite_prompt - | llm - | StrOutputParser() - | RunnableLambda(log_rewritten_question) -) + self._full_chain = question_rewrite_chain | rag_chain -rag_chain = ( - {"context": RunnableLambda(format_docs), "question": RunnablePassthrough()} - | question_prompt - | llm - | StrOutputParser() -) + def _log_rewritten_question(self, rewritten_question: str) -> str: + return rewritten_question -full_chain = question_rewrite_chain | rag_chain + def _format_docs(self, question: str) -> str: + retrieved_docs = self._base_retriever.invoke(question) + reranked_docs = self._compression_retriever.invoke(question) + self._retriever_sources = [ + { + "page_content": f"{doc.page_content[:50]}...", + "source": doc.metadata.get("source", ""), + } + for doc in retrieved_docs + ] -async def main(): - response = await full_chain.ainvoke( - "Buongiorno, non so se è la mail specifica ma volevo se possibile dei chiarimenti per l’interpretazione dei parametri BCM /SMM/ASMM. Mi capita a volte di trovare casi in cui la BCM è aumentata ma allo stesso tempo SMM/ASMM hanno subito una piccola flessione in negativo (o viceversa). Se la parte metabolicamente attiva aumenta perchè può succedere che gli altri compartimenti si riducono?? E allo stesso tempo phA e BCM possono essere inversamente proporzionali?? So che il phA correla con massa e struttura + idratazione." - ) - print(response) + self._reranked_sources = [ + { + "relevance_score": doc.metadata.get("relevance_score", ""), + "page_content": f"{doc.page_content[:50]}...", + } + for doc in reranked_docs + ] + if PRINT_SOURCES: + print("========== RETRIEVER DOCUMENTS ==========") + for idx, doc in enumerate(retrieved_docs, start=1): + snippet = doc.page_content[:200].replace("\n", " ") + print( + f"[{idx}] metadata={doc.metadata['source']} | snippet=...{snippet}..." + ) -if __name__ == "__main__": - asyncio.run(main()) + print("========== RERANKED DOCUMENTS ==========") + for idx, doc in enumerate(reranked_docs, start=1): + snippet = doc.page_content[:200].replace("\n", " ") + print( + f"[{idx}] metadata={doc.metadata['relevance_score']} | snippet=...{snippet}..." + ) + + return "\n\n".join(doc.page_content for doc in reranked_docs) + + def getSources(self) -> list[dict]: + return list(self._retriever_sources) + + def getRankedSources(self) -> list[dict]: + return list(self._reranked_sources) + + def stream(self, message: str): + return self._full_chain.astream(message) diff --git a/static/css/chat.css b/static/css/chat.css index 6e4dc1a..b8e9efa 100644 --- a/static/css/chat.css +++ b/static/css/chat.css @@ -94,6 +94,49 @@ body { color: #111827; } +/* Typing / loading indicator */ +.message--loading { + display: flex; + align-items: center; + gap: 5px; + padding: 12px 16px; +} + +.typing-dot { + width: 8px; + height: 8px; + border-radius: 50%; + background: #9ca3af; + animation: typing-bounce 1.2s infinite ease-in-out; +} + +.typing-dot:nth-child(1) { + animation-delay: 0s; +} + +.typing-dot:nth-child(2) { + animation-delay: 0.2s; +} + +.typing-dot:nth-child(3) { + animation-delay: 0.4s; +} + +@keyframes typing-bounce { + + 0%, + 60%, + 100% { + transform: translateY(0); + background: #9ca3af; + } + + 30% { + transform: translateY(-6px); + background: #6b7280; + } +} + .chat__footer { padding: 16px 20px; border-top: 1px solid #e5e7eb; @@ -125,6 +168,22 @@ body { cursor: pointer; } +.chat__button:disabled { + background: #cbd5f5; + color: #ffffff; + cursor: not-allowed; + opacity: 0.75; +} + +.chat__button--secondary { + background: #e5e7eb; + color: #1f2937; +} + +.chat__button--secondary:hover { + background: #d1d5db; +} + .chat__status { padding: 8px 20px 0; font-size: 12px; diff --git a/static/js/chat.js b/static/js/chat.js index 7cde62d..0e7f754 100644 --- a/static/js/chat.js +++ b/static/js/chat.js @@ -2,17 +2,33 @@ const statusEl = document.getElementById("status"); const messagesEl = document.getElementById("messages"); const formEl = document.getElementById("chat-form"); const inputEl = document.getElementById("chat-input"); +const clearEl = document.getElementById("chat-clear"); +const sendButtonEl = formEl?.querySelector("button[type='submit']"); const settingsToggle = document.getElementById("settings-toggle"); const settingsDrawer = document.getElementById("settings-drawer"); const settingsClose = document.getElementById("settings-close"); const rangeInputs = settingsDrawer ? Array.from(settingsDrawer.querySelectorAll("input[type='range']")) : []; +const configInputs = { + topK: document.getElementById("config-top-k"), + topP: document.getElementById("config-top-p"), + temperature: document.getElementById("config-temperature"), + retrieverMaxDocs: document.getElementById("config-retriever-max-docs"), + rerankerMaxDocs: document.getElementById("config-reranker-max-docs"), +}; const scheme = window.location.protocol === "https:" ? "wss" : "ws"; const socketUrl = `${scheme}://${window.location.host}/ws`; let socket; let streamingBubble = null; +let loadingBubble = null; +let isAwaitingResponse = false; + +const updateSendButtonState = () => { + if (!sendButtonEl) return; + sendButtonEl.disabled = isAwaitingResponse; +}; const addMessage = (text, direction) => { const bubble = document.createElement("div"); @@ -22,6 +38,29 @@ const addMessage = (text, direction) => { messagesEl.scrollTop = messagesEl.scrollHeight; }; +const showLoadingBubble = () => { + if (loadingBubble) return; + loadingBubble = document.createElement("div"); + loadingBubble.className = "message message--in message--loading"; + loadingBubble.innerHTML = + ''; + messagesEl.appendChild(loadingBubble); + messagesEl.scrollTop = messagesEl.scrollHeight; +}; + +const removeLoadingBubble = () => { + if (!loadingBubble) return; + loadingBubble.remove(); + loadingBubble = null; +}; + +const clearMessages = () => { + if (!messagesEl) return; + messagesEl.innerHTML = ""; + streamingBubble = null; + loadingBubble = null; +}; + const connect = () => { socket = new WebSocket(socketUrl); @@ -39,11 +78,11 @@ const connect = () => { } if (payload.type === "start") { + // Keep the loading bubble visible until the first chunk arrives. + // Just prepare the streaming bubble but don't append it yet. streamingBubble = document.createElement("div"); streamingBubble.className = "message message--in"; streamingBubble.textContent = ""; - messagesEl.appendChild(streamingBubble); - messagesEl.scrollTop = messagesEl.scrollHeight; return; } @@ -51,6 +90,10 @@ const connect = () => { if (!streamingBubble) { streamingBubble = document.createElement("div"); streamingBubble.className = "message message--in"; + } + // First chunk: swap loading bubble for the streaming bubble. + if (!streamingBubble.isConnected) { + removeLoadingBubble(); messagesEl.appendChild(streamingBubble); } streamingBubble.textContent += payload.data ?? ""; @@ -58,18 +101,36 @@ const connect = () => { return; } + if (payload.type === "report") { + console.log("%cSOURCES", 'border: 2px solid red; padding: 1em; color: red; font-size: 14px; font-weight: bold;') + console.log(payload.sources); + console.log("%cRE-RANKED SOURCES", 'border: 2px solid red; padding: 1em; color: red; font-size: 14px; font-weight: bold;') + console.log(payload.reranked_sources); + return; + } + if (payload.type === "end") { + // Safety net: remove loading bubble if no chunks were ever received. + removeLoadingBubble(); streamingBubble = null; + isAwaitingResponse = false; + updateSendButtonState(); return; } }); socket.addEventListener("close", () => { statusEl.textContent = "Disconnected"; + removeLoadingBubble(); + isAwaitingResponse = false; + updateSendButtonState(); }); socket.addEventListener("error", () => { statusEl.textContent = "Connection error"; + removeLoadingBubble(); + isAwaitingResponse = false; + updateSendButtonState(); }); }; @@ -106,6 +167,14 @@ settingsToggle?.addEventListener("click", () => { settingsClose?.addEventListener("click", closeDrawer); +clearEl?.addEventListener("click", () => { + clearMessages(); + inputEl?.focus(); + updateSendButtonState(); +}); + +inputEl?.addEventListener("input", updateSendButtonState); + formEl.addEventListener("submit", (event) => { event.preventDefault(); const text = inputEl.value.trim(); @@ -114,10 +183,22 @@ formEl.addEventListener("submit", (event) => { addMessage("Not connected.", "in"); return; } + if (isAwaitingResponse) return; + const config = { + top_k: Number(configInputs.topK?.value ?? 0), + top_p: Number(configInputs.topP?.value ?? 0), + temperature: Number(configInputs.temperature?.value ?? 0), + retriever_max_docs: Number(configInputs.retrieverMaxDocs?.value ?? 0), + reranker_max_results: Number(configInputs.rerankerMaxDocs?.value ?? 0), + }; addMessage(text, "out"); - socket.send(text); + socket.send(JSON.stringify({ message: text, config })); inputEl.value = ""; inputEl.focus(); + isAwaitingResponse = true; + updateSendButtonState(); + showLoadingBubble(); }); connect(); +updateSendButtonState(); diff --git a/templates/index.html b/templates/index.html index 6b319fd..c6477f7 100644 --- a/templates/index.html +++ b/templates/index.html @@ -4,7 +4,7 @@ - Chat + AKERN Assistant @@ -12,7 +12,7 @@
- WebSocket Echo Chat + AKERN Assistant
@@ -22,6 +22,7 @@ +
@@ -33,27 +34,31 @@

Solo frontend, nessuna logica applicata.