"""Web front end for ``RagChain``.

Exposes an HTTP-Basic-protected index page that lists the stored prompts and
a WebSocket endpoint that streams ``RagChain`` answers back to the client.
"""

import json
import logging
import os
import secrets
from contextlib import asynccontextmanager
from typing import Any, Dict, Tuple

from dotenv import load_dotenv
from fastapi import (
    Depends,
    FastAPI,
    HTTPException,
    Request,
    WebSocket,
    WebSocketDisconnect,
    status,
)
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sqlalchemy.ext.asyncio import create_async_engine

from chain import RagChain
from db import DB

load_dotenv()

logger = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Attach a ``DB`` to ``app.state`` for the app's lifetime.

    The engine is disposed when the application shuts down, including when
    shutdown is triggered by an exception raised into this context manager.
    """
    engine = create_async_engine("sqlite+aiosqlite:///example.db")
    try:
        app.state.db = DB(engine)
        yield
    finally:
        await engine.dispose()


app = FastAPI(lifespan=lifespan)
security = HTTPBasic()
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")


def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)):
    """Reject the request unless the HTTP Basic credentials match the environment.

    The expected values are read from the ``AUTH_USER`` and ``PASSWORD``
    environment variables.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when the
            username or the password does not match.
        KeyError: if either environment variable is not set.
    """
    # compare_digest() only accepts ASCII-only str; comparing bytes keeps
    # non-ASCII credentials from raising TypeError (which would surface as a
    # 500 instead of a 401).
    username_ok = secrets.compare_digest(
        credentials.username.encode("utf-8"),
        os.environ["AUTH_USER"].encode("utf-8"),
    )
    # Both comparisons always run so the response time does not reveal which
    # of the two fields was wrong.
    password_ok = secrets.compare_digest(
        credentials.password.encode("utf-8"),
        os.environ["PASSWORD"].encode("utf-8"),
    )
    if not (username_ok and password_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Unauthorized",
            headers={"WWW-Authenticate": "Basic"},
        )


@app.get("/")
async def home(request: Request, _: None = Depends(verify_credentials)):
    """Render ``index.html`` with the prompts returned by ``DB.get_prompts``."""
    db: DB = request.app.state.db
    prompts = await db.get_prompts()
    return templates.TemplateResponse(
        "index.html", {"request": request, "prompts": prompts}
    )


def _parse_message(raw_message: str) -> Tuple[Any, Dict[str, Any]]:
    """Split an incoming WebSocket frame into ``(message, config)``.

    A JSON object supplies ``message`` (default ``""``) and ``config``
    (default ``{}``). Anything else -- invalid JSON, or valid JSON that is not
    an object -- is treated as a plain-text message with an empty config.
    """
    try:
        payload = json.loads(raw_message)
    except json.JSONDecodeError:
        return raw_message, {}
    if not isinstance(payload, dict):
        # e.g. the user typed "42" or "true": valid JSON, but not an envelope.
        return raw_message, {}
    config = payload.get("config", {})
    if not isinstance(config, dict):
        config = {}
    return payload.get("message", ""), config


def _chain_options(config: Dict[str, Any]) -> Dict[str, Any]:
    """Coerce the numeric ``RagChain`` options from ``config``, applying defaults.

    Raises:
        TypeError, ValueError, OverflowError: if a supplied value cannot be
            converted to the expected numeric type.
    """
    return {
        "top_k": int(config.get("top_k", 40)),
        "top_p": float(config.get("top_p", 0.0)),
        "temperature": float(config.get("temperature", 0.0)),
        "retriever_max_docs": int(config.get("retriever_max_docs", 40)),
        "reranker_max_results": int(config.get("reranker_max_results", 20)),
    }


async def _handle_message(websocket: WebSocket, db: DB, raw_message: str) -> None:
    """Answer one client message: stream the chunks, then send the report."""
    message, config = _parse_message(raw_message)
    try:
        prompt_id = int(config.get("prompt_id", 0))
        options = _chain_options(config)
    except (TypeError, ValueError, OverflowError):
        # Bad client input should not tear down the whole session.
        await websocket.send_json(
            {"type": "chunk", "data": "Error: invalid configuration."}
        )
        return

    prompt = await db.get_prompt_by_id(prompt_id)
    if prompt is None:
        await websocket.send_json(
            {"type": "chunk", "data": "Error: prompt not found."}
        )
        return

    rag_chain = RagChain(question_template=str(prompt.text), **options)
    async for chunk in rag_chain.stream(message):
        await websocket.send_json({"type": "chunk", "data": chunk})
    await websocket.send_json(
        {
            "type": "report",
            "sources": rag_chain.getSources(),
            "reranked_sources": rag_chain.getRankedSources(),
            "rephrased_question": rag_chain.getRephrasedQuestion(),
        }
    )


# NOTE(review): unlike "/", this endpoint performs no authentication --
# confirm whether that is intentional.
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve a chat session over a WebSocket.

    For every text frame received the server sends ``{"type": "start"}``,
    zero or more ``{"type": "chunk", ...}`` frames, a ``{"type": "report", ...}``
    frame on success, and finally ``{"type": "end"}``. The loop runs until the
    client disconnects or an unexpected error occurs; unexpected errors are
    logged and end the session.
    """
    await websocket.accept()
    db: DB = websocket.app.state.db
    try:
        while True:
            raw_message = await websocket.receive_text()
            try:
                await websocket.send_json({"type": "start"})
                await _handle_message(websocket, db, raw_message)
            finally:
                # Always close the exchange so the client is not left waiting,
                # even when handling the message failed.
                await websocket.send_json({"type": "end"})
    except WebSocketDisconnect:
        # Normal termination: the client went away.
        logger.debug("WebSocket client disconnected")
    except Exception:
        # Best effort: keep the server alive, but leave a trace of the failure.
        logger.exception("Unhandled error in WebSocket session; closing it")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=int(os.getenv("PORT", "8080")))