Files
AKERN-Langchain/app.py
Matteo Rosati b64e97c9d0 add db
2026-02-18 15:39:01 +01:00

76 lines
2.3 KiB
Python

import json
import logging
from pprint import pprint

from fastapi import FastAPI, WebSocket
from fastapi import Request
from fastapi import WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from chain import RagChain
from db import DB
# ASGI application object; the routes below are registered on it.
app = FastAPI()

# Jinja2 renderer that loads page templates from ./templates.
templates = Jinja2Templates(directory="templates")

# Serve ./static under the /static URL prefix, as the route named "static".
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)
@app.get("/")
def home(request: Request):
    """Render ``index.html`` with the prompts loaded from the database.

    Args:
        request: the incoming request; Jinja2Templates requires it in the
            template context.
    """
    database = DB()
    context = {"request": request, "prompts": database.get_prompts()}
    return templates.TemplateResponse("index.html", context)
@app.get("/prova")
def prova(request: Request):
    """Return the stored prompts directly as the response body.

    Presumably a test route ("prova" is Italian for "test") — confirm before
    relying on it. ``request`` is accepted but not used.
    """
    database = DB()  # a DB wrapper object, not a cursor
    return database.get_prompts()
logger = logging.getLogger(__name__)

# (config key, converter, default) for every RagChain setting a client may
# override through the "config" object of a websocket message.
_CHAIN_SETTINGS = (
    ("top_k", int, 40),
    ("top_p", float, 0.0),
    ("temperature", float, 0.0),
    ("retriever_max_docs", int, 40),
    ("reranker_max_results", int, 20),
)


def _parse_request(raw_message: str):
    """Turn one websocket text frame into ``(message, RagChain kwargs)``.

    A JSON-object frame supplies ``message`` (default ``""``) and an optional
    ``config`` object whose keys override the defaults in ``_CHAIN_SETTINGS``.
    Any other frame — plain text, or JSON that is not an object such as a
    bare string or number — is used verbatim as the message with default
    settings.

    Raises:
        ValueError: if ``message`` is not a string, ``config`` is neither
            null nor an object, or a config value cannot be converted to
            its expected numeric type.
    """
    try:
        payload = json.loads(raw_message)
    except json.JSONDecodeError:
        payload = None
    if not isinstance(payload, dict):
        # Not a JSON object: treat the raw text as the question itself.
        payload = {"message": raw_message}

    message = payload.get("message", "")
    if not isinstance(message, str):
        raise ValueError("'message' must be a string")

    config = payload.get("config")
    if config is None:
        config = {}
    elif not isinstance(config, dict):
        raise ValueError("'config' must be a JSON object")

    chain_kwargs = {}
    for key, convert, default in _CHAIN_SETTINGS:
        value = config.get(key, default)
        try:
            chain_kwargs[key] = convert(value)
        except (TypeError, ValueError, OverflowError) as exc:
            # e.g. "abc", null, a list, or int(Infinity)
            raise ValueError(f"invalid value for {key!r}: {value!r}") from exc
    return message, chain_kwargs


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Answer questions over a websocket, streaming the RAG chain output.

    For every incoming text frame the server sends, in order:
    ``{"type": "start"}``; zero or more ``{"type": "chunk", "data": ...}``
    frames produced by ``RagChain.stream``; a ``{"type": "report", ...}``
    frame with the sources, reranked sources and rephrased question; and
    finally ``{"type": "end"}``. If the frame cannot be parsed into valid
    settings, a single ``{"type": "error", "message": ...}`` frame replaces
    the chunks and report, and the connection stays open.

    A client disconnect ends the loop quietly; any other exception is
    logged and ends the session.
    """
    await websocket.accept()
    try:
        while True:
            raw_message = await websocket.receive_text()
            try:
                await websocket.send_json({"type": "start"})
                try:
                    message, chain_kwargs = _parse_request(raw_message)
                except ValueError as exc:
                    # Bad client input must not tear down the whole session.
                    await websocket.send_json(
                        {"type": "error", "message": str(exc)}
                    )
                    continue  # the finally clause still sends "end"

                rag_chain = RagChain(**chain_kwargs)
                async for chunk in rag_chain.stream(message):
                    await websocket.send_json({"type": "chunk", "data": chunk})
                await websocket.send_json(
                    {
                        "type": "report",
                        "sources": rag_chain.getSources(),
                        "reranked_sources": rag_chain.getRankedSources(),
                        "rephrased_question": rag_chain.getRephrasedQuestion(),
                    }
                )
            finally:
                # Every turn is terminated with an "end" frame, even on error.
                await websocket.send_json({"type": "end"})
    except WebSocketDisconnect:
        # Normal termination: the client closed the connection.
        pass
    except Exception:
        # Previously swallowed silently; keep the best-effort behaviour
        # (no re-raise) but leave a trace for debugging.
        logger.exception("Unhandled error in /ws handler; ending session")