Files
AKERN-Langchain/app.py
2026-02-18 16:09:31 +01:00

77 lines
2.5 KiB
Python

import json
import logging

from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from chain import RagChain
from db import DB
# HTML templates are looked up in the local "templates" directory.
templates = Jinja2Templates(directory="templates")

# The ASGI application; files from the local "static" directory are
# served under the /static URL prefix.
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/")
def home(request: Request):
    """Render the index page together with the prompts read from the database.

    Args:
        request: The incoming HTTP request; it is passed into the template
            context under the ``"request"`` key.

    Returns:
        The ``index.html`` template response, whose context also carries
        the value of ``DB().get_prompts()`` under the ``"prompts"`` key.
    """
    database = DB()
    context = {"request": request, "prompts": database.get_prompts()}
    return templates.TemplateResponse("index.html", context)
_logger = logging.getLogger(__name__)


def _parse_client_message(raw_message):
    """Split a raw WebSocket text frame into ``(message, config)``.

    A frame holding a JSON object contributes its ``"message"`` entry
    (default ``""``) and its ``"config"`` entry (default ``{}``). Text that
    is not JSON, or JSON that is not an object (a bare number, string,
    list, ...), is treated as the message itself with an empty config.
    A ``"config"`` entry that is not an object is replaced by ``{}``.
    """
    try:
        payload = json.loads(raw_message)
    except json.JSONDecodeError:
        return raw_message, {}
    if not isinstance(payload, dict):
        # Valid JSON without keys to look up: handle it like plain text
        # instead of failing on ``payload.get``.
        return raw_message, {}
    config = payload.get("config", {})
    if not isinstance(config, dict):
        config = {}
    return payload.get("message", ""), config


def _parse_chain_options(config):
    """Convert the client config into ``(prompt_id, RagChain keyword options)``.

    Missing entries fall back to the defaults spelled out below.

    Raises:
        TypeError, ValueError, OverflowError: If a value cannot be converted
            to the numeric type expected for its key.
    """
    prompt_id = int(config.get("prompt_id", 0))
    options = {
        "top_k": int(config.get("top_k", 40)),
        "top_p": float(config.get("top_p", 0.0)),
        "temperature": float(config.get("temperature", 0.0)),
        "retriever_max_docs": int(config.get("retriever_max_docs", 40)),
        "reranker_max_results": int(config.get("reranker_max_results", 20)),
    }
    return prompt_id, options


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Answer questions over a WebSocket, streaming the RAG chain output.

    For every text frame received the handler sends, in order:

    * ``{"type": "start"}``
    * one ``{"type": "chunk", "data": ...}`` per item yielded by
      ``RagChain.stream`` (or a single chunk carrying an ``"Error: ..."``
      text when the config is invalid or the prompt does not exist)
    * ``{"type": "report", ...}`` with the sources, reranked sources and
      rephrased question reported by the chain
    * ``{"type": "end"}`` -- sent even when processing fails, unless the
      client has already disconnected

    Args:
        websocket: The client connection; it is accepted here.

    A client disconnect ends the handler quietly. Any other unexpected
    error is logged and also ends the handler.
    """
    await websocket.accept()
    try:
        while True:
            raw_message = await websocket.receive_text()
            client_gone = False
            try:
                await websocket.send_json({"type": "start"})
                message, config = _parse_client_message(raw_message)

                try:
                    prompt_id, chain_options = _parse_chain_options(config)
                except (TypeError, ValueError, OverflowError):
                    # Bad client input is reported to the client rather
                    # than tearing down the whole session.
                    await websocket.send_json(
                        {"type": "chunk", "data": "Error: invalid configuration."}
                    )
                    continue

                db = DB()
                prompt = db.get_prompt_by_id(prompt_id)
                if prompt is None:
                    await websocket.send_json(
                        {"type": "chunk", "data": "Error: prompt not found."}
                    )
                    continue

                rag_chain = RagChain(
                    question_template=str(prompt.text),
                    **chain_options,
                )
                async for chunk in rag_chain.stream(message):
                    await websocket.send_json({"type": "chunk", "data": chunk})
                await websocket.send_json(
                    {
                        "type": "report",
                        "sources": rag_chain.getSources(),
                        "reranked_sources": rag_chain.getRankedSources(),
                        "rephrased_question": rag_chain.getRephrasedQuestion(),
                    }
                )
            except WebSocketDisconnect:
                client_gone = True
                raise
            finally:
                # Sending to a departed client would only raise again and
                # hide the disconnect behind a secondary error.
                if not client_gone:
                    await websocket.send_json({"type": "end"})
    except WebSocketDisconnect:
        # The client closed the connection; this is the normal way out.
        return
    except Exception:
        # Keep the original "never propagate" behaviour, but leave a trace
        # instead of silently discarding the failure.
        _logger.exception("WebSocket session ended by an unexpected error")