76 lines
2.3 KiB
Python
76 lines
2.3 KiB
Python
import json
|
|
|
|
from fastapi import FastAPI, WebSocket
|
|
from fastapi.templating import Jinja2Templates
|
|
from fastapi import Request
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
from db import DB
|
|
from chain import RagChain
|
|
from pprint import pprint
|
|
|
|
# The ASGI application; every route below registers itself on this object.
app = FastAPI()

# Template renderer for pages stored in ./templates (resolved relative to the
# process working directory).
templates = Jinja2Templates(directory="templates")

# Expose the files in ./static under the /static URL prefix.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
|
|
|
|
|
|
@app.get("/")
def home(request: Request):
    """Render ``index.html`` with the prompts returned by ``DB().get_prompts()``.

    The template context carries the incoming ``request`` together with the
    prompt collection under the ``prompts`` key.
    """
    # Keep the DB object referenced until the response has been built.
    store = DB()
    return templates.TemplateResponse(
        "index.html",
        {"request": request, "prompts": store.get_prompts()},
    )
|
|
|
|
|
|
@app.get("/prova")
def prova(request: Request):
    """Return whatever ``DB().get_prompts()`` yields, for FastAPI to serialize.

    ``request`` is not used by this handler; it is kept so the signature stays
    unchanged.
    """
    database = DB()
    all_prompts = database.get_prompts()
    return all_prompts
|
|
|
|
|
|
def _parse_payload(raw_message):
    """Split one client frame into ``(message, config)``.

    A frame that is a JSON object is read as the envelope
    ``{"message": ..., "config": {...}}``: a missing ``message`` becomes ``""``
    and a missing or non-object ``config`` becomes ``{}``. Any other frame
    (plain text, or valid JSON that is not an object, such as ``42``) is taken
    verbatim as the question, with an empty config.
    """
    try:
        payload = json.loads(raw_message)
    except json.JSONDecodeError:
        return raw_message, {}
    if not isinstance(payload, dict):
        # A number, list or bare JSON string is not an envelope; treat it like
        # plain text instead of failing on payload.get().
        return raw_message, {}
    config = payload.get("config")
    if not isinstance(config, dict):
        config = {}
    return payload.get("message", ""), config


def _build_chain(config):
    """Create a RagChain from client settings, using the defaults below for absent keys.

    Raises ValueError/TypeError when a supplied value cannot be converted
    (e.g. ``"top_k": "abc"`` or ``"top_k": null``).
    """
    return RagChain(
        top_k=int(config.get("top_k", 40)),
        top_p=float(config.get("top_p", 0.0)),
        temperature=float(config.get("temperature", 0.0)),
        retriever_max_docs=int(config.get("retriever_max_docs", 40)),
        reranker_max_results=int(config.get("reranker_max_results", 20)),
    )


async def _stream_answer(websocket, raw_message):
    """Answer one frame: stream ``chunk`` messages, then send a single ``report``."""
    message, config = _parse_payload(raw_message)
    rag_chain = _build_chain(config)
    async for chunk in rag_chain.stream(message):
        await websocket.send_json({"type": "chunk", "data": chunk})
    await websocket.send_json(
        {
            "type": "report",
            "sources": rag_chain.getSources(),
            "reranked_sources": rag_chain.getRankedSources(),
            "rephrased_question": rag_chain.getRephrasedQuestion(),
        }
    )


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve RAG answers over a WebSocket, one answer per received text frame.

    For every frame the server sends ``{"type": "start"}``, zero or more
    ``{"type": "chunk", "data": ...}`` messages and one ``{"type": "report"}``
    carrying sources, reranked sources and the rephrased question. It then
    always sends ``{"type": "end"}``, even when answering that frame failed.
    A failing frame is logged and the session keeps serving later frames.
    The session ends when the client disconnects; any other session-level
    error is logged instead of being silently discarded.
    """
    # Function-scope imports keep this block self-contained.
    import logging

    from fastapi import WebSocketDisconnect

    logger = logging.getLogger(__name__)

    await websocket.accept()
    try:
        while True:
            raw_message = await websocket.receive_text()
            await websocket.send_json({"type": "start"})
            try:
                await _stream_answer(websocket, raw_message)
            except WebSocketDisconnect:
                # The client is gone; sending "end" would fail, so stop now.
                raise
            except Exception:
                # One bad request (malformed settings, retrieval failure, ...)
                # must not tear down the whole session.
                logger.exception("Failed to answer websocket message")
            await websocket.send_json({"type": "end"})
    except WebSocketDisconnect:
        # Normal end of a session: the client closed the socket.
        pass
    except Exception:
        logger.exception("WebSocket session aborted")
|