message streaming

Matteo Rosati
2026-02-18 14:21:31 +01:00
parent 3e6fefabbd
commit e1afb6e6c7
5 changed files with 299 additions and 101 deletions

35
app.py
View File

@@ -1,9 +1,12 @@
import json
from fastapi import FastAPI, WebSocket
from fastapi.templating import Jinja2Templates
from fastapi import Request
from fastapi.staticfiles import StaticFiles
from chain import full_chain
from chain import RagChain
from pprint import pprint
app = FastAPI()
@@ -22,10 +25,36 @@ async def websocket_endpoint(websocket: WebSocket):
try:
while True:
message = await websocket.receive_text()
raw_message = await websocket.receive_text()
try:
await websocket.send_json({"type": "start"})
async for chunk in full_chain.astream(message):
try:
payload = json.loads(raw_message)
message = payload.get("message", "")
config = payload.get("config", {})
except json.JSONDecodeError:
message = raw_message
config = {}
rag_chain = RagChain(
top_k=int(config.get("top_k", 40)),
top_p=float(config.get("top_p", 0.0)),
temperature=float(config.get("temperature", 0.0)),
retriever_max_docs=int(config.get("retriever_max_docs", 40)),
reranker_max_results=int(config.get("reranker_max_results", 20)),
)
async for chunk in rag_chain.stream(message):
await websocket.send_json({"type": "chunk", "data": chunk})
await websocket.send_json(
{
"type": "report",
"sources": rag_chain.getSources(),
"reranked_sources": rag_chain.getRankedSources(),
}
)
finally:
await websocket.send_json({"type": "end"})
except Exception:
pass
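
With this change the WebSocket endpoint expects a JSON envelope of the form {"message": ..., "config": {...}} and falls back to treating the raw text as the message if parsing fails. The server then replies with a typed event stream: "start", one "chunk" per streamed piece of the answer, a "report" carrying the retrieved and reranked sources, and a closing "end". A minimal client sketch of that protocol, not part of the commit: it assumes the third-party `websockets` package and a server running on localhost:8000; only the /ws path, the config keys, and the event types come from the diff above.

import asyncio
import json

import websockets


async def ask(question: str) -> None:
    # Connect to the FastAPI WebSocket endpoint (host/port are assumptions).
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        # Send the JSON envelope the endpoint now parses.
        await ws.send(json.dumps({
            "message": question,
            "config": {
                "top_k": 40,
                "top_p": 0.0,
                "temperature": 0.0,
                "retriever_max_docs": 40,
                "reranker_max_results": 20,
            },
        }))
        # Consume the typed event stream until "end".
        while True:
            event = json.loads(await ws.recv())
            if event["type"] == "chunk":
                print(event["data"], end="", flush=True)
            elif event["type"] == "report":
                print("\nsources:", len(event["sources"]),
                      "reranked:", len(event["reranked_sources"]))
            elif event["type"] == "end":
                break


asyncio.run(ask("How is phase angle interpreted?"))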

178
chain.py
View File

@@ -1,5 +1,3 @@
import asyncio
from dotenv import load_dotenv
from langchain_classic.retrievers import ContextualCompressionRetriever
from langchain_core.output_parsers import StrOutputParser
@@ -16,28 +14,109 @@ DATA_STORE = "akern-ds_1771234036654"
MODEL = "gemini-2.5-flash"
LOCATION = "eu"
PRINT_SOURCES = False
# LLM CONFIG
TOP_K = 40
TOP_P = 1
TEMPERATURE = 0.0
MAX_OUTPUT_TOKENS = 65535
RETRIEVER_MAX_DOCS = 50
RERANKER_MAX_RESULTS = 25
with open("prompt.md") as f:
class RagChain:
def __init__(
self,
top_k: int,
top_p: float,
temperature: float,
retriever_max_docs: int,
reranker_max_results: int,
) -> None:
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.retriever_max_docs = retriever_max_docs
self.reranker_max_results = reranker_max_results
with open("prompt.md") as f:
question_template = f.read()
with open("question_rewrite_prompt.md") as f:
with open("question_rewrite_prompt.md") as f:
question_rewrite_template = f.read()
question_prompt = ChatPromptTemplate.from_template(question_template)
question_rewrite_prompt = ChatPromptTemplate.from_template(question_rewrite_template)
question_prompt = ChatPromptTemplate.from_template(question_template)
question_rewrite_prompt = ChatPromptTemplate.from_template(
question_rewrite_template
)
self._retriever_sources: list[dict] = []
self._reranked_sources: list[dict] = []
def format_docs(question: str) -> str:
retrieved_docs = base_retriever.invoke(question)
reranked_docs = compression_retriever.invoke(question)
self._llm = ChatGoogleGenerativeAI(
model=MODEL,
project=PROJECT,
vertexai=True,
top_p=self.top_p,
top_k=self.top_k,
temperature=self.temperature,
max_output_tokens=MAX_OUTPUT_TOKENS,
)
self._base_retriever = VertexAISearchRetriever(
project_id=PROJECT,
data_store_id=DATA_STORE,
max_documents=self.retriever_max_docs,
location_id=LOCATION,
beta=True,
)
self._reranker = VertexAIRank(
project_id=PROJECT,
location_id=LOCATION,
ranking_config="default_ranking_config",
top_n=self.reranker_max_results,
)
self._compression_retriever = ContextualCompressionRetriever(
base_compressor=self._reranker, base_retriever=self._base_retriever
)
question_rewrite_chain = (
{"question": RunnablePassthrough()}
| question_rewrite_prompt
| self._llm
| StrOutputParser()
| RunnableLambda(self._log_rewritten_question)
)
rag_chain = (
{
"context": RunnableLambda(self._format_docs),
"question": RunnablePassthrough(),
}
| question_prompt
| self._llm
| StrOutputParser()
)
self._full_chain = question_rewrite_chain | rag_chain
def _log_rewritten_question(self, rewritten_question: str) -> str:
return rewritten_question
def _format_docs(self, question: str) -> str:
retrieved_docs = self._base_retriever.invoke(question)
reranked_docs = self._compression_retriever.invoke(question)
self._retriever_sources = [
{
"page_content": f"{doc.page_content[:50]}...",
"source": doc.metadata.get("source", ""),
}
for doc in retrieved_docs
]
self._reranked_sources = [
{
"relevance_score": doc.metadata.get("relevance_score", ""),
"page_content": f"{doc.page_content[:50]}...",
}
for doc in reranked_docs
]
if PRINT_SOURCES:
print("========== RETRIEVER DOCUMENTS ==========")
@@ -56,66 +135,11 @@ def format_docs(question: str) -> str:
return "\n\n".join(doc.page_content for doc in reranked_docs)
def getSources(self) -> list[dict]:
return list(self._retriever_sources)
def log_rewritten_question(rewritten_question: str) -> str:
print("=== REWRITTEN QUESTION ===")
print(rewritten_question)
return rewritten_question
def getRankedSources(self) -> list[dict]:
return list(self._reranked_sources)
llm = ChatGoogleGenerativeAI(
model=MODEL,
project=PROJECT,
vertexai=True,
top_p=TOP_P,
top_k=TOP_K,
temperature=TEMPERATURE,
max_output_tokens=MAX_OUTPUT_TOKENS,
)
base_retriever = VertexAISearchRetriever(
project_id=PROJECT,
data_store_id=DATA_STORE,
max_documents=RETRIEVER_MAX_DOCS,
location_id=LOCATION,
beta=True,
)
reranker = VertexAIRank(
project_id=PROJECT,
location_id="eu",
ranking_config="default_ranking_config",
top_n=RERANKER_MAX_RESULTS,
)
compression_retriever = ContextualCompressionRetriever(
base_compressor=reranker, base_retriever=base_retriever
)
question_rewrite_chain = (
{"question": RunnablePassthrough()}
| question_rewrite_prompt
| llm
| StrOutputParser()
| RunnableLambda(log_rewritten_question)
)
rag_chain = (
{"context": RunnableLambda(format_docs), "question": RunnablePassthrough()}
| question_prompt
| llm
| StrOutputParser()
)
full_chain = question_rewrite_chain | rag_chain
async def main():
response = await full_chain.ainvoke(
"Buongiorno, non so se è la mail specifica ma volevo se possibile dei chiarimenti per linterpretazione dei parametri BCM /SMM/ASMM. Mi capita a volte di trovare casi in cui la BCM è aumentata ma allo stesso tempo SMM/ASMM hanno subito una piccola flessione in negativo (o viceversa). Se la parte metabolicamente attiva aumenta perchè può succedere che gli altri compartimenti si riducono?? E allo stesso tempo phA e BCM possono essere inversamente proporzionali?? So che il phA correla con massa e struttura + idratazione."
)
print(response)
if __name__ == "__main__":
asyncio.run(main())
def stream(self, message: str):
return self._full_chain.astream(message)
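
The module-level full_chain and the old main() smoke test are gone: the chain is now built per request from a RagChain instance, so each message can carry its own sampling and retrieval parameters, and the instance keeps the retrieved and reranked sources for the follow-up report. A usage sketch based only on the constructor and methods added above (not part of the commit); it mirrors the defaults used in app.py and assumes the Vertex AI project, data store, and credentials configured at the top of chain.py are available, with a placeholder question.

import asyncio

from chain import RagChain


async def main() -> None:
    # Parameter values mirror the defaults the WebSocket handler falls back to.
    rag = RagChain(
        top_k=40,
        top_p=0.0,
        temperature=0.0,
        retriever_max_docs=40,
        reranker_max_results=20,
    )
    # stream() returns the underlying astream() async iterator.
    async for chunk in rag.stream("How is phase angle interpreted?"):
        print(chunk, end="", flush=True)
    print()
    print(rag.getSources())        # first 50 chars + source of each retrieved doc
    print(rag.getRankedSources())  # relevance score + first 50 chars after reranking


asyncio.run(main())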

View File

@@ -94,6 +94,49 @@ body {
color: #111827;
}
/* Typing / loading indicator */
.message--loading {
display: flex;
align-items: center;
gap: 5px;
padding: 12px 16px;
}
.typing-dot {
width: 8px;
height: 8px;
border-radius: 50%;
background: #9ca3af;
animation: typing-bounce 1.2s infinite ease-in-out;
}
.typing-dot:nth-child(1) {
animation-delay: 0s;
}
.typing-dot:nth-child(2) {
animation-delay: 0.2s;
}
.typing-dot:nth-child(3) {
animation-delay: 0.4s;
}
@keyframes typing-bounce {
0%,
60%,
100% {
transform: translateY(0);
background: #9ca3af;
}
30% {
transform: translateY(-6px);
background: #6b7280;
}
}
.chat__footer {
padding: 16px 20px;
border-top: 1px solid #e5e7eb;
@@ -125,6 +168,22 @@ body {
cursor: pointer;
}
.chat__button:disabled {
background: #cbd5f5;
color: #ffffff;
cursor: not-allowed;
opacity: 0.75;
}
.chat__button--secondary {
background: #e5e7eb;
color: #1f2937;
}
.chat__button--secondary:hover {
background: #d1d5db;
}
.chat__status {
padding: 8px 20px 0;
font-size: 12px;

View File

@@ -2,17 +2,33 @@ const statusEl = document.getElementById("status");
const messagesEl = document.getElementById("messages");
const formEl = document.getElementById("chat-form");
const inputEl = document.getElementById("chat-input");
const clearEl = document.getElementById("chat-clear");
const sendButtonEl = formEl?.querySelector("button[type='submit']");
const settingsToggle = document.getElementById("settings-toggle");
const settingsDrawer = document.getElementById("settings-drawer");
const settingsClose = document.getElementById("settings-close");
const rangeInputs = settingsDrawer
? Array.from(settingsDrawer.querySelectorAll("input[type='range']"))
: [];
const configInputs = {
topK: document.getElementById("config-top-k"),
topP: document.getElementById("config-top-p"),
temperature: document.getElementById("config-temperature"),
retrieverMaxDocs: document.getElementById("config-retriever-max-docs"),
rerankerMaxDocs: document.getElementById("config-reranker-max-docs"),
};
const scheme = window.location.protocol === "https:" ? "wss" : "ws";
const socketUrl = `${scheme}://${window.location.host}/ws`;
let socket;
let streamingBubble = null;
let loadingBubble = null;
let isAwaitingResponse = false;
const updateSendButtonState = () => {
if (!sendButtonEl) return;
sendButtonEl.disabled = isAwaitingResponse;
};
const addMessage = (text, direction) => {
const bubble = document.createElement("div");
@@ -22,6 +38,29 @@ const addMessage = (text, direction) => {
messagesEl.scrollTop = messagesEl.scrollHeight;
};
const showLoadingBubble = () => {
if (loadingBubble) return;
loadingBubble = document.createElement("div");
loadingBubble.className = "message message--in message--loading";
loadingBubble.innerHTML =
'<span class="typing-dot"></span><span class="typing-dot"></span><span class="typing-dot"></span>';
messagesEl.appendChild(loadingBubble);
messagesEl.scrollTop = messagesEl.scrollHeight;
};
const removeLoadingBubble = () => {
if (!loadingBubble) return;
loadingBubble.remove();
loadingBubble = null;
};
const clearMessages = () => {
if (!messagesEl) return;
messagesEl.innerHTML = "";
streamingBubble = null;
loadingBubble = null;
};
const connect = () => {
socket = new WebSocket(socketUrl);
@@ -39,11 +78,11 @@ const connect = () => {
}
if (payload.type === "start") {
// Keep the loading bubble visible until the first chunk arrives.
// Just prepare the streaming bubble but don't append it yet.
streamingBubble = document.createElement("div");
streamingBubble.className = "message message--in";
streamingBubble.textContent = "";
messagesEl.appendChild(streamingBubble);
messagesEl.scrollTop = messagesEl.scrollHeight;
return;
}
@@ -51,6 +90,10 @@ const connect = () => {
if (!streamingBubble) {
streamingBubble = document.createElement("div");
streamingBubble.className = "message message--in";
}
// First chunk: swap loading bubble for the streaming bubble.
if (!streamingBubble.isConnected) {
removeLoadingBubble();
messagesEl.appendChild(streamingBubble);
}
streamingBubble.textContent += payload.data ?? "";
@@ -58,18 +101,36 @@ const connect = () => {
return;
}
if (payload.type === "report") {
console.log("%cSOURCES", 'border: 2px solid red; padding: 1em; color: red; font-size: 14px; font-weight: bold;')
console.log(payload.sources);
console.log("%cRE-RANKED SOURCES", 'border: 2px solid red; padding: 1em; color: red; font-size: 14px; font-weight: bold;')
console.log(payload.reranked_sources);
return;
}
if (payload.type === "end") {
// Safety net: remove loading bubble if no chunks were ever received.
removeLoadingBubble();
streamingBubble = null;
isAwaitingResponse = false;
updateSendButtonState();
return;
}
});
socket.addEventListener("close", () => {
statusEl.textContent = "Disconnected";
removeLoadingBubble();
isAwaitingResponse = false;
updateSendButtonState();
});
socket.addEventListener("error", () => {
statusEl.textContent = "Connection error";
removeLoadingBubble();
isAwaitingResponse = false;
updateSendButtonState();
});
};
@@ -106,6 +167,14 @@ settingsToggle?.addEventListener("click", () => {
settingsClose?.addEventListener("click", closeDrawer);
clearEl?.addEventListener("click", () => {
clearMessages();
inputEl?.focus();
updateSendButtonState();
});
inputEl?.addEventListener("input", updateSendButtonState);
formEl.addEventListener("submit", (event) => {
event.preventDefault();
const text = inputEl.value.trim();
@@ -114,10 +183,22 @@ formEl.addEventListener("submit", (event) => {
addMessage("Not connected.", "in");
return;
}
if (isAwaitingResponse) return;
const config = {
top_k: Number(configInputs.topK?.value ?? 0),
top_p: Number(configInputs.topP?.value ?? 0),
temperature: Number(configInputs.temperature?.value ?? 0),
retriever_max_docs: Number(configInputs.retrieverMaxDocs?.value ?? 0),
reranker_max_results: Number(configInputs.rerankerMaxDocs?.value ?? 0),
};
addMessage(text, "out");
socket.send(text);
socket.send(JSON.stringify({ message: text, config }));
inputEl.value = "";
inputEl.focus();
isAwaitingResponse = true;
updateSendButtonState();
showLoadingBubble();
});
connect();
updateSendButtonState();
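
The frontend only logs the "report" event to the browser console. For reference, the payload that lands there has the following shape; this is a sketch with placeholder values, the field names come from _format_docs in chain.py and the send_json call in app.py.

# Shape of the "report" event the console handlers above receive.
# Field names are taken from the diff; the values are placeholders.
report = {
    "type": "report",
    "sources": [
        {"page_content": "<first 50 characters of a retrieved chunk>...",
         "source": "<source taken from the retriever metadata>"},
    ],
    "reranked_sources": [
        {"relevance_score": "<score set by VertexAIRank>",
         "page_content": "<first 50 characters of a reranked chunk>..."},
    ],
}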

View File

@@ -4,7 +4,7 @@
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Chat</title>
<title>AKERN Assistant</title>
<link rel="stylesheet" href="/static/css/chat.css" />
</head>
@@ -12,7 +12,7 @@
<div class="app" aria-label="Chat application">
<section class="chat" aria-label="Chat">
<div class="chat__header">
<span>WebSocket Echo Chat</span>
<span>AKERN Assistant</span>
<button class="chat__settings" id="settings-toggle" type="button" aria-haspopup="dialog"
aria-expanded="false" aria-controls="settings-drawer">Configurazione</button>
</div>
@@ -22,6 +22,7 @@
<input class="chat__input" id="chat-input" type="text" placeholder="Type a message" autocomplete="off"
required />
<button class="chat__button" type="submit">Send</button>
<button class="chat__button chat__button--secondary" id="chat-clear" type="button">Clear</button>
</form>
</section>
@@ -33,27 +34,31 @@
<form class="drawer__form">
<label class="drawer__field">
<span class="drawer__label">top_k</span>
<input class="drawer__range" type="range" min="0" max="100" step="1" value="40" />
<input class="drawer__range" id="config-top-k" type="range" min="0" max="100" step="1" value="40" />
<output class="drawer__value">40</output>
</label>
<label class="drawer__field">
<span class="drawer__label">top_p</span>
<input class="drawer__range" type="range" min="0" max="1" step="0.1" value="0.0" />
<input class="drawer__range" id="config-top-p" type="range" min="0" max="1" step="0.1"
value="0.0" />
<output class="drawer__value">0.0</output>
</label>
<label class="drawer__field">
<span class="drawer__label">temperature</span>
<input class="drawer__range" type="range" min="0" max="1.5" step="0.1" value="0.0" />
<input class="drawer__range" id="config-temperature" type="range" min="0" max="1.5" step="0.1"
value="0.0" />
<output class="drawer__value">0.0</output>
</label>
<label class="drawer__field">
<span class="drawer__label">retriever max docs</span>
<input class="drawer__range" type="range" min="5" max="100" step="1" value="40" />
<input class="drawer__range" id="config-retriever-max-docs" type="range" min="5" max="100" step="1"
value="40" />
<output class="drawer__value">40</output>
</label>
<label class="drawer__field">
<span class="drawer__label">reranker max docs</span>
<input class="drawer__range" type="range" min="5" max="100" step="1" value="20" />
<input class="drawer__range" id="config-reranker-max-docs" type="range" min="5" max="100" step="1"
value="20" />
<output class="drawer__value">20</output>
</label>
<p class="drawer__hint">Solo frontend, nessuna logica applicata.</p>