"""FastAPI app serving the RAG web UI and its streaming WebSocket endpoint.

Routes:
    GET  /        Render ``templates/index.html`` with the prompts from ``DB``.
    GET  /prova   Return the prompts from ``DB`` as JSON (debug helper).
    WS   /ws      Stream ``RagChain`` answers back to the client.
    /static       Static files served from the ``static`` directory.
"""
import contextlib
import json
import logging
from pprint import pprint
from typing import Any, Dict, Tuple

from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.websockets import WebSocketState

from chain import RagChain
from db import DB

logger = logging.getLogger(__name__)

app = FastAPI()
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/")
def home(request: Request):
    """Render the main page with the prompts returned by ``DB.get_prompts()``."""
    # NOTE(review): a new DB() is created per request and never explicitly
    # closed -- confirm whether DB holds a connection that needs releasing.
    db = DB()
    prompts = db.get_prompts()
    return templates.TemplateResponse(
        "index.html", {"request": request, "prompts": prompts}
    )


@app.get("/prova")
def prova(request: Request):
    """Debug route: return ``DB.get_prompts()`` directly (serialized by FastAPI)."""
    db = DB()
    prompts = db.get_prompts()
    return prompts


def _parse_client_message(raw_message: str) -> Tuple[Any, Dict[str, Any]]:
    """Split one raw WebSocket text frame into ``(message, config)``.

    A frame is either a JSON object ``{"message": ..., "config": {...}}`` or
    plain text. Anything that is not a JSON *object* is treated as plain text
    with an empty config. That covers invalid JSON and also valid non-object
    JSON such as a user typing ``42``, which previously crashed on ``.get``.
    A ``config`` value that is not an object is likewise ignored.
    """
    try:
        payload = json.loads(raw_message)
    except json.JSONDecodeError:
        return raw_message, {}
    if not isinstance(payload, dict):
        return raw_message, {}

    message = payload.get("message", "")
    config = payload.get("config", {})
    if not isinstance(config, dict):
        config = {}
    return message, config


def _chain_settings(config: Dict[str, Any]) -> Dict[str, Any]:
    """Convert client ``config`` into ``RagChain`` keyword arguments.

    Missing keys fall back to the defaults below.

    Raises:
        ValueError, TypeError: if a supplied value cannot be converted to the
            required ``int``/``float`` (e.g. ``"abc"`` or ``null``).
    """
    return {
        "top_k": int(config.get("top_k", 40)),
        "top_p": float(config.get("top_p", 0.0)),
        "temperature": float(config.get("temperature", 0.0)),
        "retriever_max_docs": int(config.get("retriever_max_docs", 40)),
        "reranker_max_results": int(config.get("reranker_max_results", 20)),
    }


async def _stream_answer(websocket: WebSocket, rag_chain: RagChain, message) -> None:
    """Send every streamed chunk for ``message``, then the final ``report`` frame."""
    async for chunk in rag_chain.stream(message):
        await websocket.send_json({"type": "chunk", "data": chunk})
    await websocket.send_json(
        {
            "type": "report",
            "sources": rag_chain.getSources(),
            "reranked_sources": rag_chain.getRankedSources(),
            "rephrased_question": rag_chain.getRephrasedQuestion(),
        }
    )


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Answer each incoming text frame with a streamed ``RagChain`` response.

    Per client frame the server sends, in order:
        {"type": "start"}
        {"type": "chunk", "data": ...}   zero or more times
        {"type": "report", ...}          sources / reranked sources / rephrased question
        {"type": "end"}
    If the config contains values that cannot be converted, the frames are
    ``start``, ``{"type": "error", "message": ...}``, ``end``, and the session
    stays open for the next question.

    A client disconnect ends the handler quietly. Any other error is logged and
    the socket is closed with code 1011 (internal error).
    """
    await websocket.accept()
    try:
        while True:
            raw_message = await websocket.receive_text()
            try:
                await websocket.send_json({"type": "start"})
                message, config = _parse_client_message(raw_message)
                try:
                    settings = _chain_settings(config)
                except (TypeError, ValueError) as exc:
                    await websocket.send_json(
                        {"type": "error", "message": f"Invalid config: {exc}"}
                    )
                    continue  # the finally below still sends "end"
                rag_chain = RagChain(**settings)
                await _stream_answer(websocket, rag_chain, message)
            finally:
                # Skip "end" if the socket is already gone. Otherwise this send
                # would raise and replace the exception that got us here.
                if websocket.application_state == WebSocketState.CONNECTED:
                    await websocket.send_json({"type": "end"})
    except WebSocketDisconnect:
        logger.debug("WebSocket client disconnected")
    except Exception:
        logger.exception("Unhandled error in /ws handler; closing connection")
        if websocket.application_state == WebSocketState.CONNECTED:
            # Best effort: the transport may already be broken.
            with contextlib.suppress(Exception):
                await websocket.close(code=1011)