89 lines
2.8 KiB
Python
89 lines
2.8 KiB
Python
import json
import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sqlalchemy.ext.asyncio import create_async_engine

from chain import RagChain
from db import DB
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Open the database engine for the app's lifetime and dispose of it after.

    Creates an async SQLAlchemy engine for the ``example.db`` SQLite file,
    wraps it in ``DB`` and stores the wrapper on ``app.state.db`` so that
    request handlers can reach it through ``request.app.state.db``.

    Args:
        app: The FastAPI application whose ``state`` receives the ``DB``.

    Yields:
        None. The application serves requests while suspended at the yield.
    """
    engine = create_async_engine("sqlite+aiosqlite:///example.db")
    try:
        app.state.db = DB(engine)
        yield
    finally:
        # Runs on normal shutdown AND when an exception (or cancellation) is
        # thrown into the generator at the yield point, or when DB() itself
        # fails; without the finally, the engine would not be disposed of in
        # those cases.
        await engine.dispose()
|
|
|
|
|
|
# Application instance. The `lifespan` context manager sets up
# `app.state.db` before requests are served and disposes of the engine
# afterwards.
app = FastAPI(lifespan=lifespan)

# Serve the contents of the `static` directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Template renderer that loads templates from the `templates` directory.
templates = Jinja2Templates(directory="templates")
|
|
|
|
|
|
# Landing page: renders "index.html" with every prompt returned by the DB.
# NOTE: documented with comments rather than a docstring on purpose -- FastAPI
# publishes route docstrings as the operation description in the OpenAPI
# schema, and this rewrite must not change the generated schema.
@app.get("/")
async def home(request: Request):
    # The DB wrapper is placed on app.state by the lifespan handler.
    database: DB = request.app.state.db
    available_prompts = await database.get_prompts()

    # The template context carries the request object alongside the prompts.
    context = {"request": request, "prompts": available_prompts}
    return templates.TemplateResponse("index.html", context)
|
|
|
|
|
|
logger = logging.getLogger(__name__)


def _parse_ws_message(raw_message: str) -> "tuple[str, dict]":
    """Split a raw WebSocket text frame into ``(message, config)``.

    A frame holding a JSON object supplies the ``message`` and ``config``
    keys (defaulting to ``""`` and ``{}``). Anything else -- text that is not
    JSON, or JSON that is not an object -- is treated as the message itself
    with an empty config. A ``config`` value that is not an object is
    replaced by an empty config.
    """
    try:
        payload = json.loads(raw_message)
    except json.JSONDecodeError:
        return raw_message, {}

    if not isinstance(payload, dict):
        # Valid JSON such as `"hi"`, `42` or `[]` has no .get(); handle it
        # the same way as plain text instead of failing.
        return raw_message, {}

    config = payload.get("config", {})
    if not isinstance(config, dict):
        config = {}
    return payload.get("message", ""), config


async def _answer_message(websocket: WebSocket, db: DB, raw_message: str) -> None:
    """Answer one incoming frame: stream the chain output, then a report.

    Sends ``chunk`` frames while the chain streams and a final ``report``
    frame with the sources, reranked sources and rephrased question. Invalid
    config values or an unknown prompt id are reported to the client as an
    error ``chunk`` and the frame is skipped.
    """
    message, config = _parse_ws_message(raw_message)

    # Validate every client-supplied number before touching the database so
    # that a malformed value yields an error frame, not a dropped session.
    try:
        prompt_id = int(config.get("prompt_id", 0))
        chain_options = {
            "top_k": int(config.get("top_k", 40)),
            "top_p": float(config.get("top_p", 0.0)),
            "temperature": float(config.get("temperature", 0.0)),
            "retriever_max_docs": int(config.get("retriever_max_docs", 40)),
            "reranker_max_results": int(config.get("reranker_max_results", 20)),
        }
    except (TypeError, ValueError):
        await websocket.send_json(
            {"type": "chunk", "data": "Error: invalid config."}
        )
        return

    prompt = await db.get_prompt_by_id(prompt_id)
    if prompt is None:
        await websocket.send_json(
            {"type": "chunk", "data": "Error: prompt not found."}
        )
        return

    rag_chain = RagChain(question_template=str(prompt.text), **chain_options)

    async for chunk in rag_chain.stream(message):
        await websocket.send_json({"type": "chunk", "data": chunk})

    await websocket.send_json(
        {
            "type": "report",
            "sources": rag_chain.getSources(),
            "reranked_sources": rag_chain.getRankedSources(),
            "rephrased_question": rag_chain.getRephrasedQuestion(),
        }
    )


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve a chat session over a WebSocket.

    For every text frame received, the reply is bracketed by a ``start`` and
    an ``end`` frame; in between come the ``chunk`` frames and the closing
    ``report`` frame produced for that question. The session runs until the
    client disconnects or an unexpected error occurs, in which case the
    error is logged and the handler returns.
    """
    await websocket.accept()

    db: DB = websocket.app.state.db

    try:
        while True:
            raw_message = await websocket.receive_text()
            try:
                await websocket.send_json({"type": "start"})
                await _answer_message(websocket, db, raw_message)
            finally:
                # Always close the start/end bracket, including after an
                # error frame or a failure inside the chain.
                await websocket.send_json({"type": "end"})
    except WebSocketDisconnect:
        # The client going away is the normal way for a session to finish.
        logger.debug("WebSocket client disconnected")
    except Exception:
        # Last-resort boundary: keep the previous behaviour of ending the
        # session quietly for the client, but leave a traceback in the log
        # instead of discarding the error.
        logger.exception("Unhandled error in WebSocket session")
|