message streaming

This commit is contained in:
Matteo Rosati
2026-02-18 14:21:31 +01:00
parent 3e6fefabbd
commit e1afb6e6c7
5 changed files with 299 additions and 101 deletions

35
app.py
View File

@@ -1,9 +1,12 @@
import json
from fastapi import FastAPI, WebSocket from fastapi import FastAPI, WebSocket
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from fastapi import Request from fastapi import Request
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from chain import full_chain from chain import RagChain
from pprint import pprint
app = FastAPI() app = FastAPI()
@@ -22,10 +25,36 @@ async def websocket_endpoint(websocket: WebSocket):
try: try:
while True: while True:
message = await websocket.receive_text() raw_message = await websocket.receive_text()
try:
await websocket.send_json({"type": "start"}) await websocket.send_json({"type": "start"})
async for chunk in full_chain.astream(message): try:
payload = json.loads(raw_message)
message = payload.get("message", "")
config = payload.get("config", {})
except json.JSONDecodeError:
message = raw_message
config = {}
rag_chain = RagChain(
top_k=int(config.get("top_k", 40)),
top_p=float(config.get("top_p", 0.0)),
temperature=float(config.get("temperature", 0.0)),
retriever_max_docs=int(config.get("retriever_max_docs", 40)),
reranker_max_results=int(config.get("reranker_max_results", 20)),
)
async for chunk in rag_chain.stream(message):
await websocket.send_json({"type": "chunk", "data": chunk}) await websocket.send_json({"type": "chunk", "data": chunk})
await websocket.send_json(
{
"type": "report",
"sources": rag_chain.getSources(),
"reranked_sources": rag_chain.getRankedSources(),
}
)
finally:
await websocket.send_json({"type": "end"}) await websocket.send_json({"type": "end"})
except Exception: except Exception:
pass pass

172
chain.py
View File

@@ -1,5 +1,3 @@
import asyncio
from dotenv import load_dotenv from dotenv import load_dotenv
from langchain_classic.retrievers import ContextualCompressionRetriever from langchain_classic.retrievers import ContextualCompressionRetriever
from langchain_core.output_parsers import StrOutputParser from langchain_core.output_parsers import StrOutputParser
@@ -16,14 +14,23 @@ DATA_STORE = "akern-ds_1771234036654"
MODEL = "gemini-2.5-flash" MODEL = "gemini-2.5-flash"
LOCATION = "eu" LOCATION = "eu"
PRINT_SOURCES = False PRINT_SOURCES = False
# LLM CONFIG
TOP_K = 40
TOP_P = 1
TEMPERATURE = 0.0
MAX_OUTPUT_TOKENS = 65535 MAX_OUTPUT_TOKENS = 65535
RETRIEVER_MAX_DOCS = 50
RERANKER_MAX_RESULTS = 25
class RagChain:
def __init__(
self,
top_k: int,
top_p: float,
temperature: float,
retriever_max_docs: int,
reranker_max_results: int,
) -> None:
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.retriever_max_docs = retriever_max_docs
self.reranker_max_results = reranker_max_results
with open("prompt.md") as f: with open("prompt.md") as f:
question_template = f.read() question_template = f.read()
@@ -32,12 +39,84 @@ with open("question_rewrite_prompt.md") as f:
question_rewrite_template = f.read() question_rewrite_template = f.read()
question_prompt = ChatPromptTemplate.from_template(question_template) question_prompt = ChatPromptTemplate.from_template(question_template)
question_rewrite_prompt = ChatPromptTemplate.from_template(question_rewrite_template) question_rewrite_prompt = ChatPromptTemplate.from_template(
question_rewrite_template
)
self._retriever_sources: list[dict] = []
self._reranked_sources: list[dict] = []
def format_docs(question: str) -> str: self._llm = ChatGoogleGenerativeAI(
retrieved_docs = base_retriever.invoke(question) model=MODEL,
reranked_docs = compression_retriever.invoke(question) project=PROJECT,
vertexai=True,
top_p=self.top_p,
top_k=self.top_k,
temperature=self.temperature,
max_output_tokens=MAX_OUTPUT_TOKENS,
)
self._base_retriever = VertexAISearchRetriever(
project_id=PROJECT,
data_store_id=DATA_STORE,
max_documents=self.retriever_max_docs,
location_id=LOCATION,
beta=True,
)
self._reranker = VertexAIRank(
project_id=PROJECT,
location_id=LOCATION,
ranking_config="default_ranking_config",
top_n=self.reranker_max_results,
)
self._compression_retriever = ContextualCompressionRetriever(
base_compressor=self._reranker, base_retriever=self._base_retriever
)
question_rewrite_chain = (
{"question": RunnablePassthrough()}
| question_rewrite_prompt
| self._llm
| StrOutputParser()
| RunnableLambda(self._log_rewritten_question)
)
rag_chain = (
{
"context": RunnableLambda(self._format_docs),
"question": RunnablePassthrough(),
}
| question_prompt
| self._llm
| StrOutputParser()
)
self._full_chain = question_rewrite_chain | rag_chain
def _log_rewritten_question(self, rewritten_question: str) -> str:
return rewritten_question
def _format_docs(self, question: str) -> str:
retrieved_docs = self._base_retriever.invoke(question)
reranked_docs = self._compression_retriever.invoke(question)
self._retriever_sources = [
{
"page_content": f"{doc.page_content[:50]}...",
"source": doc.metadata.get("source", ""),
}
for doc in retrieved_docs
]
self._reranked_sources = [
{
"relevance_score": doc.metadata.get("relevance_score", ""),
"page_content": f"{doc.page_content[:50]}...",
}
for doc in reranked_docs
]
if PRINT_SOURCES: if PRINT_SOURCES:
print("========== RETRIEVER DOCUMENTS ==========") print("========== RETRIEVER DOCUMENTS ==========")
@@ -56,66 +135,11 @@ def format_docs(question: str) -> str:
return "\n\n".join(doc.page_content for doc in reranked_docs) return "\n\n".join(doc.page_content for doc in reranked_docs)
def getSources(self) -> list[dict]:
return list(self._retriever_sources)
def log_rewritten_question(rewritten_question: str) -> str: def getRankedSources(self) -> list[dict]:
print("=== REWRITTEN QUESTION ===") return list(self._reranked_sources)
print(rewritten_question)
return rewritten_question
def stream(self, message: str):
llm = ChatGoogleGenerativeAI( return self._full_chain.astream(message)
model=MODEL,
project=PROJECT,
vertexai=True,
top_p=TOP_P,
top_k=TOP_K,
temperature=TEMPERATURE,
max_output_tokens=MAX_OUTPUT_TOKENS,
)
base_retriever = VertexAISearchRetriever(
project_id=PROJECT,
data_store_id=DATA_STORE,
max_documents=RETRIEVER_MAX_DOCS,
location_id=LOCATION,
beta=True,
)
reranker = VertexAIRank(
project_id=PROJECT,
location_id="eu",
ranking_config="default_ranking_config",
top_n=RERANKER_MAX_RESULTS,
)
compression_retriever = ContextualCompressionRetriever(
base_compressor=reranker, base_retriever=base_retriever
)
question_rewrite_chain = (
{"question": RunnablePassthrough()}
| question_rewrite_prompt
| llm
| StrOutputParser()
| RunnableLambda(log_rewritten_question)
)
rag_chain = (
{"context": RunnableLambda(format_docs), "question": RunnablePassthrough()}
| question_prompt
| llm
| StrOutputParser()
)
full_chain = question_rewrite_chain | rag_chain
async def main():
response = await full_chain.ainvoke(
"Buongiorno, non so se è la mail specifica ma volevo se possibile dei chiarimenti per linterpretazione dei parametri BCM /SMM/ASMM. Mi capita a volte di trovare casi in cui la BCM è aumentata ma allo stesso tempo SMM/ASMM hanno subito una piccola flessione in negativo (o viceversa). Se la parte metabolicamente attiva aumenta perchè può succedere che gli altri compartimenti si riducono?? E allo stesso tempo phA e BCM possono essere inversamente proporzionali?? So che il phA correla con massa e struttura + idratazione."
)
print(response)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -94,6 +94,49 @@ body {
color: #111827; color: #111827;
} }
/* Typing / loading indicator */
.message--loading {
display: flex;
align-items: center;
gap: 5px;
padding: 12px 16px;
}
.typing-dot {
width: 8px;
height: 8px;
border-radius: 50%;
background: #9ca3af;
animation: typing-bounce 1.2s infinite ease-in-out;
}
.typing-dot:nth-child(1) {
animation-delay: 0s;
}
.typing-dot:nth-child(2) {
animation-delay: 0.2s;
}
.typing-dot:nth-child(3) {
animation-delay: 0.4s;
}
@keyframes typing-bounce {
0%,
60%,
100% {
transform: translateY(0);
background: #9ca3af;
}
30% {
transform: translateY(-6px);
background: #6b7280;
}
}
.chat__footer { .chat__footer {
padding: 16px 20px; padding: 16px 20px;
border-top: 1px solid #e5e7eb; border-top: 1px solid #e5e7eb;
@@ -125,6 +168,22 @@ body {
cursor: pointer; cursor: pointer;
} }
.chat__button:disabled {
background: #cbd5f5;
color: #ffffff;
cursor: not-allowed;
opacity: 0.75;
}
.chat__button--secondary {
background: #e5e7eb;
color: #1f2937;
}
.chat__button--secondary:hover {
background: #d1d5db;
}
.chat__status { .chat__status {
padding: 8px 20px 0; padding: 8px 20px 0;
font-size: 12px; font-size: 12px;

View File

@@ -2,17 +2,33 @@ const statusEl = document.getElementById("status");
const messagesEl = document.getElementById("messages"); const messagesEl = document.getElementById("messages");
const formEl = document.getElementById("chat-form"); const formEl = document.getElementById("chat-form");
const inputEl = document.getElementById("chat-input"); const inputEl = document.getElementById("chat-input");
const clearEl = document.getElementById("chat-clear");
const sendButtonEl = formEl?.querySelector("button[type='submit']");
const settingsToggle = document.getElementById("settings-toggle"); const settingsToggle = document.getElementById("settings-toggle");
const settingsDrawer = document.getElementById("settings-drawer"); const settingsDrawer = document.getElementById("settings-drawer");
const settingsClose = document.getElementById("settings-close"); const settingsClose = document.getElementById("settings-close");
const rangeInputs = settingsDrawer const rangeInputs = settingsDrawer
? Array.from(settingsDrawer.querySelectorAll("input[type='range']")) ? Array.from(settingsDrawer.querySelectorAll("input[type='range']"))
: []; : [];
// Settings-drawer range inputs, looked up once by element id.
// An entry is null when its element is absent from the page; the submit
// handler reads each one with optional chaining and falls back to 0.
const configInputs = {
topK: document.getElementById("config-top-k"),
topP: document.getElementById("config-top-p"),
temperature: document.getElementById("config-temperature"),
retrieverMaxDocs: document.getElementById("config-retriever-max-docs"),
// NOTE: this key says "MaxDocs", but the submit handler sends its value as `reranker_max_results`.
rerankerMaxDocs: document.getElementById("config-reranker-max-docs"),
};
const scheme = window.location.protocol === "https:" ? "wss" : "ws"; const scheme = window.location.protocol === "https:" ? "wss" : "ws";
const socketUrl = `${scheme}://${window.location.host}/ws`; const socketUrl = `${scheme}://${window.location.host}/ws`;
let socket; let socket;
let streamingBubble = null; let streamingBubble = null;
let loadingBubble = null;
let isAwaitingResponse = false;
/**
 * Mirror the "awaiting response" flag onto the submit button's disabled state.
 * Does nothing when the form has no submit button.
 */
const updateSendButtonState = () => {
  if (sendButtonEl) {
    sendButtonEl.disabled = isAwaitingResponse;
  }
};
const addMessage = (text, direction) => { const addMessage = (text, direction) => {
const bubble = document.createElement("div"); const bubble = document.createElement("div");
@@ -22,6 +38,29 @@ const addMessage = (text, direction) => {
messagesEl.scrollTop = messagesEl.scrollHeight; messagesEl.scrollTop = messagesEl.scrollHeight;
}; };
/**
 * Append a three-dot typing indicator to the message list and scroll it into
 * view. A second call while an indicator is already tracked is a no-op.
 */
const showLoadingBubble = () => {
  if (loadingBubble) {
    return;
  }

  const bubble = document.createElement("div");
  // Track the element right away so later calls see it.
  loadingBubble = bubble;

  bubble.className = "message message--in message--loading";
  bubble.innerHTML =
    '<span class="typing-dot"></span><span class="typing-dot"></span><span class="typing-dot"></span>';

  messagesEl.appendChild(bubble);
  messagesEl.scrollTop = messagesEl.scrollHeight;
};
/**
 * Detach the typing indicator from the DOM, if one is tracked, and forget it.
 */
const removeLoadingBubble = () => {
  if (loadingBubble) {
    loadingBubble.remove();
    loadingBubble = null;
  }
};
/**
 * Empty the message list and reset the bubble bookkeeping.
 * Does nothing when the message container is missing.
 */
const clearMessages = () => {
  if (messagesEl) {
    messagesEl.innerHTML = "";
    // Whatever bubbles were tracked are no longer in the list; drop the references.
    streamingBubble = null;
    loadingBubble = null;
  }
};
const connect = () => { const connect = () => {
socket = new WebSocket(socketUrl); socket = new WebSocket(socketUrl);
@@ -39,11 +78,11 @@ const connect = () => {
} }
if (payload.type === "start") { if (payload.type === "start") {
// Keep the loading bubble visible until the first chunk arrives.
// Just prepare the streaming bubble but don't append it yet.
streamingBubble = document.createElement("div"); streamingBubble = document.createElement("div");
streamingBubble.className = "message message--in"; streamingBubble.className = "message message--in";
streamingBubble.textContent = ""; streamingBubble.textContent = "";
messagesEl.appendChild(streamingBubble);
messagesEl.scrollTop = messagesEl.scrollHeight;
return; return;
} }
@@ -51,6 +90,10 @@ const connect = () => {
if (!streamingBubble) { if (!streamingBubble) {
streamingBubble = document.createElement("div"); streamingBubble = document.createElement("div");
streamingBubble.className = "message message--in"; streamingBubble.className = "message message--in";
}
// First chunk: swap loading bubble for the streaming bubble.
if (!streamingBubble.isConnected) {
removeLoadingBubble();
messagesEl.appendChild(streamingBubble); messagesEl.appendChild(streamingBubble);
} }
streamingBubble.textContent += payload.data ?? ""; streamingBubble.textContent += payload.data ?? "";
@@ -58,18 +101,36 @@ const connect = () => {
return; return;
} }
if (payload.type === "report") {
console.log("%cSOURCES", 'border: 2px solid red; padding: 1em; color: red; font-size: 14px; font-weight: bold;')
console.log(payload.sources);
console.log("%cRE-RANKED SOURCES", 'border: 2px solid red; padding: 1em; color: red; font-size: 14px; font-weight: bold;')
console.log(payload.reranked_sources);
return;
}
if (payload.type === "end") { if (payload.type === "end") {
// Safety net: remove loading bubble if no chunks were ever received.
removeLoadingBubble();
streamingBubble = null; streamingBubble = null;
isAwaitingResponse = false;
updateSendButtonState();
return; return;
} }
}); });
socket.addEventListener("close", () => { socket.addEventListener("close", () => {
statusEl.textContent = "Disconnected"; statusEl.textContent = "Disconnected";
removeLoadingBubble();
isAwaitingResponse = false;
updateSendButtonState();
}); });
socket.addEventListener("error", () => { socket.addEventListener("error", () => {
statusEl.textContent = "Connection error"; statusEl.textContent = "Connection error";
removeLoadingBubble();
isAwaitingResponse = false;
updateSendButtonState();
}); });
}; };
@@ -106,6 +167,14 @@ settingsToggle?.addEventListener("click", () => {
settingsClose?.addEventListener("click", closeDrawer); settingsClose?.addEventListener("click", closeDrawer);
clearEl?.addEventListener("click", () => {
clearMessages();
inputEl?.focus();
updateSendButtonState();
});
inputEl?.addEventListener("input", updateSendButtonState);
formEl.addEventListener("submit", (event) => { formEl.addEventListener("submit", (event) => {
event.preventDefault(); event.preventDefault();
const text = inputEl.value.trim(); const text = inputEl.value.trim();
@@ -114,10 +183,22 @@ formEl.addEventListener("submit", (event) => {
addMessage("Not connected.", "in"); addMessage("Not connected.", "in");
return; return;
} }
if (isAwaitingResponse) return;
const config = {
top_k: Number(configInputs.topK?.value ?? 0),
top_p: Number(configInputs.topP?.value ?? 0),
temperature: Number(configInputs.temperature?.value ?? 0),
retriever_max_docs: Number(configInputs.retrieverMaxDocs?.value ?? 0),
reranker_max_results: Number(configInputs.rerankerMaxDocs?.value ?? 0),
};
addMessage(text, "out"); addMessage(text, "out");
socket.send(text); socket.send(JSON.stringify({ message: text, config }));
inputEl.value = ""; inputEl.value = "";
inputEl.focus(); inputEl.focus();
isAwaitingResponse = true;
updateSendButtonState();
showLoadingBubble();
}); });
connect(); connect();
updateSendButtonState();

View File

@@ -4,7 +4,7 @@
<head> <head>
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Chat</title> <title>AKERN Assistant</title>
<link rel="stylesheet" href="/static/css/chat.css" /> <link rel="stylesheet" href="/static/css/chat.css" />
</head> </head>
@@ -12,7 +12,7 @@
<div class="app" aria-label="Chat application"> <div class="app" aria-label="Chat application">
<section class="chat" aria-label="Chat"> <section class="chat" aria-label="Chat">
<div class="chat__header"> <div class="chat__header">
<span>WebSocket Echo Chat</span> <span>AKERN Assistant</span>
<button class="chat__settings" id="settings-toggle" type="button" aria-haspopup="dialog" <button class="chat__settings" id="settings-toggle" type="button" aria-haspopup="dialog"
aria-expanded="false" aria-controls="settings-drawer">Configurazione</button> aria-expanded="false" aria-controls="settings-drawer">Configurazione</button>
</div> </div>
@@ -22,6 +22,7 @@
<input class="chat__input" id="chat-input" type="text" placeholder="Type a message" autocomplete="off" <input class="chat__input" id="chat-input" type="text" placeholder="Type a message" autocomplete="off"
required /> required />
<button class="chat__button" type="submit">Send</button> <button class="chat__button" type="submit">Send</button>
<button class="chat__button chat__button--secondary" id="chat-clear" type="button">Clear</button>
</form> </form>
</section> </section>
@@ -33,27 +34,31 @@
<form class="drawer__form"> <form class="drawer__form">
<label class="drawer__field"> <label class="drawer__field">
<span class="drawer__label">top_k</span> <span class="drawer__label">top_k</span>
<input class="drawer__range" type="range" min="0" max="100" step="1" value="40" /> <input class="drawer__range" id="config-top-k" type="range" min="0" max="100" step="1" value="40" />
<output class="drawer__value">40</output> <output class="drawer__value">40</output>
</label> </label>
<label class="drawer__field"> <label class="drawer__field">
<span class="drawer__label">top_p</span> <span class="drawer__label">top_p</span>
<input class="drawer__range" type="range" min="0" max="1" step="0.1" value="0.0" /> <input class="drawer__range" id="config-top-p" type="range" min="0" max="1" step="0.1"
value="0.0" />
<output class="drawer__value">0.0</output> <output class="drawer__value">0.0</output>
</label> </label>
<label class="drawer__field"> <label class="drawer__field">
<span class="drawer__label">temperature</span> <span class="drawer__label">temperature</span>
<input class="drawer__range" type="range" min="0" max="1.5" step="0.1" value="0.0" /> <input class="drawer__range" id="config-temperature" type="range" min="0" max="1.5" step="0.1"
value="0.0" />
<output class="drawer__value">0.0</output> <output class="drawer__value">0.0</output>
</label> </label>
<label class="drawer__field"> <label class="drawer__field">
<span class="drawer__label">retriever max docs</span> <span class="drawer__label">retriever max docs</span>
<input class="drawer__range" type="range" min="5" max="100" step="1" value="40" /> <input class="drawer__range" id="config-retriever-max-docs" type="range" min="5" max="100" step="1"
value="40" />
<output class="drawer__value">40</output> <output class="drawer__value">40</output>
</label> </label>
<label class="drawer__field"> <label class="drawer__field">
<span class="drawer__label">reranker max docs</span> <span class="drawer__label">reranker max docs</span>
<input class="drawer__range" type="range" min="5" max="100" step="1" value="20" /> <input class="drawer__range" id="config-reranker-max-docs" type="range" min="5" max="100" step="1"
value="20" />
<output class="drawer__value">20</output> <output class="drawer__value">20</output>
</label> </label>
<p class="drawer__hint">Solo frontend, nessuna logica applicata.</p> <p class="drawer__hint">Solo frontend, nessuna logica applicata.</p>