"""FastAPI app serving the chat page and a WebSocket that streams RagChain answers."""

import json
import logging
from typing import Any, Dict, Tuple

from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from chain import RagChain
from db import DB

logger = logging.getLogger(__name__)

app = FastAPI()
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/")
def home(request: Request):
    """Render ``index.html`` with the prompts returned by ``DB().get_prompts()``."""
    db = DB()
    prompts = db.get_prompts()
    return templates.TemplateResponse(
        "index.html", {"request": request, "prompts": prompts}
    )


def _parse_message(raw_message: str) -> Tuple[Any, Dict[str, Any]]:
    """Split a raw client frame into ``(message, config)``.

    A JSON object is read as ``{"message": ..., "config": {...}}``; missing
    keys default to ``""`` and ``{}``. Anything else (invalid JSON, or valid
    JSON that is not an object) is treated as the plain-text question with an
    empty config.
    """
    try:
        payload = json.loads(raw_message)
    except json.JSONDecodeError:
        return raw_message, {}
    if not isinstance(payload, dict):
        # e.g. a bare number/string/list: ``payload.get`` would raise here.
        return raw_message, {}
    message = payload.get("message", "")
    config = payload.get("config", {})
    if not isinstance(config, dict):
        config = {}
    return message, config


def _build_chain(prompt, config: Dict[str, Any]) -> RagChain:
    """Create a RagChain from ``prompt.text`` and the client-supplied tuning values.

    Raises:
        TypeError, ValueError, OverflowError: if a config value cannot be
            converted to the expected numeric type.
    """
    return RagChain(
        question_template=str(prompt.text),
        top_k=int(config.get("top_k", 40)),
        top_p=float(config.get("top_p", 0.0)),
        temperature=float(config.get("temperature", 0.0)),
        retriever_max_docs=int(config.get("retriever_max_docs", 40)),
        reranker_max_results=int(config.get("reranker_max_results", 20)),
    )


async def _answer(websocket: WebSocket, raw_message: str) -> None:
    """Answer one client frame over *websocket*.

    On success sends one ``chunk`` frame per streamed piece of the answer,
    then a single ``report`` frame. On an invalid ``prompt_id``, an unknown
    prompt, or unconvertible config values, sends one error ``chunk`` frame
    instead and returns, leaving the connection open.
    """
    message, config = _parse_message(raw_message)

    try:
        prompt_id = int(config.get("prompt_id", 0))
    except (TypeError, ValueError, OverflowError):
        await websocket.send_json({"type": "chunk", "data": "Error: invalid prompt_id."})
        return

    db = DB()
    prompt = db.get_prompt_by_id(prompt_id)
    if prompt is None:
        await websocket.send_json(
            {"type": "chunk", "data": "Error: prompt not found."}
        )
        return

    try:
        rag_chain = _build_chain(prompt, config)
    except (TypeError, ValueError, OverflowError):
        await websocket.send_json({"type": "chunk", "data": "Error: invalid config."})
        return

    async for chunk in rag_chain.stream(message):
        await websocket.send_json({"type": "chunk", "data": chunk})

    await websocket.send_json(
        {
            "type": "report",
            "sources": rag_chain.getSources(),
            "reranked_sources": rag_chain.getRankedSources(),
            "rephrased_question": rag_chain.getRephrasedQuestion(),
        }
    )


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve questions over a WebSocket until the client disconnects.

    Each received text frame is answered with a ``start`` frame, the answer
    frames produced by ``_answer``, and a closing ``end`` frame. The ``end``
    frame is sent even when answering fails.
    """
    await websocket.accept()
    try:
        while True:
            raw_message = await websocket.receive_text()
            try:
                await websocket.send_json({"type": "start"})
                await _answer(websocket, raw_message)
            finally:
                await websocket.send_json({"type": "end"})
    except WebSocketDisconnect:
        # Normal termination: the client closed the socket.
        pass
    except Exception:
        # Previously swallowed silently; keep the server alive but record why
        # this connection ended.
        logger.exception("Unhandled error in /ws handler")