"""FastAPI app serving the chat page and a WebSocket that streams RagChain answers."""

import json
import logging
from pprint import pprint

from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from chain import RagChain

logger = logging.getLogger(__name__)

app = FastAPI()
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/")
def home(request: Request):
    """Render ``templates/index.html``."""
    return templates.TemplateResponse("index.html", {"request": request})


def _parse_payload(raw_message: str) -> tuple:
    """Split one incoming text frame into ``(message, config)``.

    A frame that is a JSON object is read as ``{"message": ..., "config": {...}}``.
    Anything else (invalid JSON, or valid JSON that is not an object) is treated
    as the literal message text with an empty config.
    """
    try:
        payload = json.loads(raw_message)
    except json.JSONDecodeError:
        return raw_message, {}
    if not isinstance(payload, dict):
        # Valid JSON scalars/arrays (e.g. "42") have no .get(); treat as plain text.
        return raw_message, {}
    message = payload.get("message", "")
    config = payload.get("config")
    if not isinstance(config, dict):
        # Missing, null, or malformed config: fall back to all defaults.
        config = {}
    return message, config


def _build_chain(config: dict) -> RagChain:
    """Create a RagChain from client-supplied settings, defaulting missing keys.

    Raises:
        ValueError / TypeError: if a supplied value cannot be converted to the
            expected int/float.
    """
    return RagChain(
        top_k=int(config.get("top_k", 40)),
        top_p=float(config.get("top_p", 0.0)),
        temperature=float(config.get("temperature", 0.0)),
        retriever_max_docs=int(config.get("retriever_max_docs", 40)),
        reranker_max_results=int(config.get("reranker_max_results", 20)),
    )


async def _handle_message(websocket: WebSocket, raw_message: str) -> None:
    """Answer one client frame.

    Sends ``start``, then one ``chunk`` per streamed piece, then ``report`` with
    the sources, and always ``end`` -- unless the client has already
    disconnected, in which case nothing more can be sent.
    """
    disconnected = False
    try:
        await websocket.send_json({"type": "start"})
        message, config = _parse_payload(raw_message)
        rag_chain = _build_chain(config)
        async for chunk in rag_chain.stream(message):
            await websocket.send_json({"type": "chunk", "data": chunk})
        await websocket.send_json(
            {
                "type": "report",
                "sources": rag_chain.getSources(),
                "reranked_sources": rag_chain.getRankedSources(),
            }
        )
    except WebSocketDisconnect:
        # Sending "end" to a gone peer would raise again and mask this exception.
        disconnected = True
        raise
    finally:
        if not disconnected:
            await websocket.send_json({"type": "end"})


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Accept a connection and answer each received text frame in turn."""
    await websocket.accept()
    try:
        while True:
            raw_message = await websocket.receive_text()
            await _handle_message(websocket, raw_message)
    except WebSocketDisconnect:
        # Normal termination: the client closed the connection.
        pass
    except Exception:
        # Previously swallowed silently; log so failures are diagnosable.
        logger.exception("Unhandled error in /ws handler; dropping connection")