prompt template from db

This commit is contained in:
Matteo Rosati
2026-02-18 16:09:31 +01:00
parent b64e97c9d0
commit 3c6c367600
6 changed files with 36 additions and 108 deletions

37
app.py
View File

@@ -1,13 +1,11 @@
import json
from fastapi import FastAPI, WebSocket
from fastapi.templating import Jinja2Templates
from fastapi import Request
from fastapi import FastAPI, Request, WebSocket
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from db import DB
from chain import RagChain
from pprint import pprint
from db import DB
app = FastAPI()
@@ -20,15 +18,9 @@ def home(request: Request):
db = DB()
prompts = db.get_prompts()
return templates.TemplateResponse("index.html", {"request": request, "prompts": prompts})
@app.get("/prova")
def prova(request: Request):
cursor = DB()
prompts = cursor.get_prompts()
return prompts
return templates.TemplateResponse(
"index.html", {"request": request, "prompts": prompts}
)
@app.websocket("/ws")
@@ -48,14 +40,23 @@ async def websocket_endpoint(websocket: WebSocket):
message = raw_message
config = {}
db = DB()
prompt_id = int(config.get("prompt_id", 0))
prompt = db.get_prompt_by_id(prompt_id)
if prompt is None:
await websocket.send_json(
{"type": "chunk", "data": "Error: prompt not found."}
)
continue
rag_chain = RagChain(
question_template=str(prompt.text),
top_k=int(config.get("top_k", 40)),
top_p=float(config.get("top_p", 0.0)),
temperature=float(config.get("temperature", 0.0)),
retriever_max_docs=int(
config.get("retriever_max_docs", 40)),
reranker_max_results=int(
config.get("reranker_max_results", 20)),
retriever_max_docs=int(config.get("retriever_max_docs", 40)),
reranker_max_results=int(config.get("reranker_max_results", 20)),
)
async for chunk in rag_chain.stream(message):