prompt template from db

This commit is contained in:
Matteo Rosati
2026-02-18 16:09:31 +01:00
parent b64e97c9d0
commit 3c6c367600
6 changed files with 36 additions and 108 deletions

37
app.py
View File

@@ -1,13 +1,11 @@
import json
from fastapi import FastAPI, WebSocket
from fastapi.templating import Jinja2Templates
from fastapi import Request
from fastapi import FastAPI, Request, WebSocket
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from db import DB
from chain import RagChain
from pprint import pprint
from db import DB
app = FastAPI()
@@ -20,15 +18,9 @@ def home(request: Request):
db = DB()
prompts = db.get_prompts()
return templates.TemplateResponse("index.html", {"request": request, "prompts": prompts})
@app.get("/prova")
def prova(request: Request):
cursor = DB()
prompts = cursor.get_prompts()
return prompts
return templates.TemplateResponse(
"index.html", {"request": request, "prompts": prompts}
)
@app.websocket("/ws")
@@ -48,14 +40,23 @@ async def websocket_endpoint(websocket: WebSocket):
message = raw_message
config = {}
db = DB()
prompt_id = int(config.get("prompt_id", 0))
prompt = db.get_prompt_by_id(prompt_id)
if prompt is None:
await websocket.send_json(
{"type": "chunk", "data": "Error: prompt not found."}
)
continue
rag_chain = RagChain(
question_template=str(prompt.text),
top_k=int(config.get("top_k", 40)),
top_p=float(config.get("top_p", 0.0)),
temperature=float(config.get("temperature", 0.0)),
retriever_max_docs=int(
config.get("retriever_max_docs", 40)),
reranker_max_results=int(
config.get("reranker_max_results", 20)),
retriever_max_docs=int(config.get("retriever_max_docs", 40)),
reranker_max_results=int(config.get("reranker_max_results", 20)),
)
async for chunk in rag_chain.stream(message):

View File

@@ -20,6 +20,7 @@ MAX_OUTPUT_TOKENS = 65535
class RagChain:
def __init__(
self,
question_template: str,
top_k: int,
top_p: float,
temperature: float,
@@ -32,8 +33,8 @@ class RagChain:
self.retriever_max_docs = retriever_max_docs
self.reranker_max_results = reranker_max_results
with open("prompt.md") as f:
question_template = f.read()
# with open("prompt.md") as f:
# question_template = f.read()
with open("question_rewrite_prompt.md") as f:
question_rewrite_template = f.read()

8
db.py
View File

@@ -1,4 +1,5 @@
from typing import List
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
@@ -7,9 +8,12 @@ from models.orm import Prompt
class DB:
def __init__(self, db: str = "sqlite:///example.db"):
self.engine = create_engine(
db, connect_args={"check_same_thread": False})
self.engine = create_engine(db, connect_args={"check_same_thread": False})
def get_prompts(self) -> List[Prompt]:
with Session(self.engine) as session:
return session.query(Prompt).all()
def get_prompt_by_id(self, prompt_id: int) -> Prompt | None:
with Session(self.engine) as session:
return session.query(Prompt).filter(Prompt.id == prompt_id).first()

View File

@@ -53,21 +53,6 @@ body {
gap: 12px;
}
.chat__settings {
padding: 6px 12px;
border-radius: 999px;
border: 1px solid #d1d5db;
background: #ffffff;
color: #1f2937;
font-weight: 600;
font-size: 13px;
cursor: pointer;
}
.chat__settings:hover {
background: #f3f4f6;
}
.chat__messages {
flex: 1;
padding: 16px 20px;
@@ -204,18 +189,11 @@ body {
display: flex;
flex-direction: column;
gap: 16px;
transform: translateX(100%);
transition: transform 0.25s ease;
}
.drawer--open {
transform: translateX(0%);
}
.drawer__header {
display: flex;
align-items: center;
justify-content: space-between;
gap: 12px;
}
@@ -224,14 +202,6 @@ body {
font-size: 18px;
}
.drawer__close {
border: 1px solid #d1d5db;
background: #ffffff;
border-radius: 8px;
padding: 4px 8px;
cursor: pointer;
}
.drawer__form {
display: flex;
flex-direction: column;
@@ -261,12 +231,3 @@ body {
font-variant-numeric: tabular-nums;
}
@media (max-width: 900px) {
.drawer {
position: absolute;
right: 0;
top: 0;
height: 100vh;
box-shadow: -12px 0 24px rgba(15, 23, 42, 0.08);
}
}

View File

@@ -4,9 +4,7 @@ const formEl = document.getElementById("chat-form");
const inputEl = document.getElementById("chat-input");
const clearEl = document.getElementById("chat-clear");
const sendButtonEl = formEl?.querySelector("button[type='submit']");
const settingsToggle = document.getElementById("settings-toggle");
const settingsDrawer = document.getElementById("settings-drawer");
const settingsClose = document.getElementById("settings-close");
const rangeInputs = settingsDrawer
? Array.from(settingsDrawer.querySelectorAll("input[type='range']"))
: [];
@@ -16,6 +14,7 @@ const configInputs = {
temperature: document.getElementById("config-temperature"),
retrieverMaxDocs: document.getElementById("config-retriever-max-docs"),
rerankerMaxDocs: document.getElementById("config-reranker-max-docs"),
promptId: document.getElementById("config-prompt"),
};
const scheme = window.location.protocol === "https:" ? "wss" : "ws";
@@ -156,28 +155,6 @@ rangeInputs.forEach((input) => {
input.addEventListener("input", () => updateDrawerValue(input));
});
const closeDrawer = () => {
settingsDrawer?.classList.remove("drawer--open");
if (settingsDrawer) settingsDrawer.setAttribute("aria-hidden", "true");
if (settingsToggle) settingsToggle.setAttribute("aria-expanded", "false");
};
const openDrawer = () => {
settingsDrawer?.classList.add("drawer--open");
if (settingsDrawer) settingsDrawer.setAttribute("aria-hidden", "false");
if (settingsToggle) settingsToggle.setAttribute("aria-expanded", "true");
};
settingsToggle?.addEventListener("click", () => {
if (settingsDrawer?.classList.contains("drawer--open")) {
closeDrawer();
} else {
openDrawer();
}
});
settingsClose?.addEventListener("click", closeDrawer);
clearEl?.addEventListener("click", () => {
clearMessages();
inputEl?.focus();
@@ -201,6 +178,7 @@ formEl.addEventListener("submit", (event) => {
temperature: Number(configInputs.temperature?.value ?? 0),
retriever_max_docs: Number(configInputs.retrieverMaxDocs?.value ?? 0),
reranker_max_results: Number(configInputs.rerankerMaxDocs?.value ?? 0),
prompt_id: Number(configInputs.promptId?.value ?? 0),
};
addMessage(text, "out");
socket.send(JSON.stringify({ message: text, config }));

View File

@@ -12,16 +12,6 @@
<section class="chat" aria-label="Chat">
<div class="chat__header">
<span>AKERN Assistant</span>
<button
class="chat__settings"
id="settings-toggle"
type="button"
aria-haspopup="dialog"
aria-expanded="false"
aria-controls="settings-drawer"
>
Configurazione
</button>
</div>
<div class="chat__messages" id="messages"></div>
<div class="chat__status" id="status">Connecting…</div>
@@ -49,18 +39,9 @@
class="drawer"
id="settings-drawer"
aria-label="Configurazione"
aria-hidden="true"
>
<div class="drawer__header">
<h2 class="drawer__title">Configurazione</h2>
<button
class="drawer__close"
id="settings-close"
type="button"
aria-label="Chiudi"
>
</button>
</div>
<form class="drawer__form">
<label class="drawer__field">
@@ -128,11 +109,13 @@
/>
<output class="drawer__value">20</output>
</label>
<label for="drawer__field">
<label class="drawer__field">
<span class="drawer__label">Prompts</span>
<!-- TODO -->
<!-- Here i need a select -->
<!-- {% for prompt in prompts %} {{ prompt.name }} {{ prompt.id }} {% endfor %} -->
<select class="drawer__select" id="config-prompt">
{% for prompt in prompts %}
<option value="{{ prompt.id }}">{{ prompt.name }}</option>
{% endfor %}
</select>
</label>
</form>
</aside>