message streaming

This commit is contained in:
Matteo Rosati
2026-02-18 14:21:31 +01:00
parent 3e6fefabbd
commit e1afb6e6c7
5 changed files with 299 additions and 101 deletions

41
app.py
View File

@@ -1,9 +1,12 @@
import json
from fastapi import FastAPI, WebSocket
from fastapi.templating import Jinja2Templates
from fastapi import Request
from fastapi.staticfiles import StaticFiles
from chain import full_chain
from chain import RagChain
from pprint import pprint
app = FastAPI()
@@ -22,10 +25,36 @@ async def websocket_endpoint(websocket: WebSocket):
try:
while True:
message = await websocket.receive_text()
await websocket.send_json({"type": "start"})
async for chunk in full_chain.astream(message):
await websocket.send_json({"type": "chunk", "data": chunk})
await websocket.send_json({"type": "end"})
raw_message = await websocket.receive_text()
try:
await websocket.send_json({"type": "start"})
try:
payload = json.loads(raw_message)
message = payload.get("message", "")
config = payload.get("config", {})
except json.JSONDecodeError:
message = raw_message
config = {}
rag_chain = RagChain(
top_k=int(config.get("top_k", 40)),
top_p=float(config.get("top_p", 0.0)),
temperature=float(config.get("temperature", 0.0)),
retriever_max_docs=int(config.get("retriever_max_docs", 40)),
reranker_max_results=int(config.get("reranker_max_results", 20)),
)
async for chunk in rag_chain.stream(message):
await websocket.send_json({"type": "chunk", "data": chunk})
await websocket.send_json(
{
"type": "report",
"sources": rag_chain.getSources(),
"reranked_sources": rag_chain.getRankedSources(),
}
)
finally:
await websocket.send_json({"type": "end"})
except Exception:
pass