message streaming
This commit is contained in:
41
app.py
41
app.py
@@ -1,9 +1,12 @@
|
||||
import json
|
||||
|
||||
from fastapi import FastAPI, WebSocket
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi import Request
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from chain import full_chain
|
||||
from chain import RagChain
|
||||
from pprint import pprint
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@@ -22,10 +25,36 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_text()
|
||||
await websocket.send_json({"type": "start"})
|
||||
async for chunk in full_chain.astream(message):
|
||||
await websocket.send_json({"type": "chunk", "data": chunk})
|
||||
await websocket.send_json({"type": "end"})
|
||||
raw_message = await websocket.receive_text()
|
||||
try:
|
||||
await websocket.send_json({"type": "start"})
|
||||
try:
|
||||
payload = json.loads(raw_message)
|
||||
message = payload.get("message", "")
|
||||
config = payload.get("config", {})
|
||||
except json.JSONDecodeError:
|
||||
message = raw_message
|
||||
config = {}
|
||||
|
||||
rag_chain = RagChain(
|
||||
top_k=int(config.get("top_k", 40)),
|
||||
top_p=float(config.get("top_p", 0.0)),
|
||||
temperature=float(config.get("temperature", 0.0)),
|
||||
retriever_max_docs=int(config.get("retriever_max_docs", 40)),
|
||||
reranker_max_results=int(config.get("reranker_max_results", 20)),
|
||||
)
|
||||
|
||||
async for chunk in rag_chain.stream(message):
|
||||
await websocket.send_json({"type": "chunk", "data": chunk})
|
||||
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "report",
|
||||
"sources": rag_chain.getSources(),
|
||||
"reranked_sources": rag_chain.getRankedSources(),
|
||||
}
|
||||
)
|
||||
finally:
|
||||
await websocket.send_json({"type": "end"})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user