Files
AKERN-Langchain/app.py
2026-02-18 14:21:31 +01:00

61 lines
2.0 KiB
Python

import json
import logging
from pprint import pprint

from fastapi import FastAPI, WebSocket
from fastapi import Request
from fastapi import WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from chain import RagChain
# ASGI application object; the route decorators below register handlers on it.
app = FastAPI()
# Jinja2 template loader rooted at the relative "templates" directory.
templates = Jinja2Templates(directory="templates")
# Serve files from the relative "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/")
def home(request: Request):
    """Render the ``index.html`` template for the site root."""
    # The template context exposes the incoming request under the "request" key.
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
_logger = logging.getLogger(__name__)


def _parse_payload(raw_message):
    """Split a raw text frame into ``(message, config)``.

    A frame holding a JSON object supplies the ``message`` and ``config``
    keys (defaulting to ``""`` and ``{}``). Any other frame -- invalid JSON,
    or valid JSON that is not an object, such as a bare number or a quoted
    string -- is treated as a plain-text message with an empty config.
    A ``config`` value that is not an object is replaced by an empty config.
    """
    try:
        payload = json.loads(raw_message)
    except json.JSONDecodeError:
        return raw_message, {}
    if not isinstance(payload, dict):
        # e.g. the user typed "123": valid JSON, but not a request envelope.
        return raw_message, {}
    config = payload.get("config", {})
    if not isinstance(config, dict):
        config = {}
    return payload.get("message", ""), config


def _build_chain(config):
    """Create a ``RagChain`` from client settings, with defaults for omitted keys.

    Raises:
        ValueError, TypeError: if a supplied value cannot be converted to the
            numeric type of its setting.
    """
    return RagChain(
        top_k=int(config.get("top_k", 40)),
        top_p=float(config.get("top_p", 0.0)),
        temperature=float(config.get("temperature", 0.0)),
        retriever_max_docs=int(config.get("retriever_max_docs", 40)),
        reranker_max_results=int(config.get("reranker_max_results", 20)),
    )


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Answer each incoming text frame with a streamed RAG reply.

    For every frame the client is sent, in order: ``{"type": "start"}``, one
    ``{"type": "chunk", "data": ...}`` per item yielded by
    ``RagChain.stream``, a ``{"type": "report", ...}`` message carrying
    ``sources`` and ``reranked_sources``, and finally ``{"type": "end"}``.

    A new ``RagChain`` is built per frame from the frame's ``config``.
    If a turn fails, ``end`` is still sent when possible, the error is
    logged, and the session loop stops. A client disconnect stops the loop
    without being treated as an error.
    """
    await websocket.accept()
    try:
        while True:
            raw_message = await websocket.receive_text()
            try:
                await websocket.send_json({"type": "start"})
                message, config = _parse_payload(raw_message)
                rag_chain = _build_chain(config)
                async for chunk in rag_chain.stream(message):
                    await websocket.send_json({"type": "chunk", "data": chunk})
                await websocket.send_json(
                    {
                        "type": "report",
                        "sources": rag_chain.getSources(),
                        "reranked_sources": rag_chain.getRankedSources(),
                    }
                )
            except WebSocketDisconnect:
                # The peer is gone, so an "end" message could not be delivered;
                # trying anyway would only raise a second, misleading error.
                raise
            except Exception:
                # Close the turn for the client before the failure ends the
                # session, so the UI is not left waiting on an open reply.
                await websocket.send_json({"type": "end"})
                raise
            else:
                await websocket.send_json({"type": "end"})
    except WebSocketDisconnect:
        _logger.info("WebSocket client disconnected")
    except Exception:
        # Still best-effort (the handler returns normally), but no longer silent.
        _logger.exception("WebSocket session aborted by an unexpected error")