Load the prompt template from the DB

Author: Matteo Rosati
Date: 2026-02-18 16:09:31 +01:00
parent b64e97c9d0
commit 3c6c367600
6 changed files with 36 additions and 108 deletions

37
app.py
View File

@@ -1,13 +1,11 @@
import json import json
from fastapi import FastAPI, WebSocket from fastapi import FastAPI, Request, WebSocket
from fastapi.templating import Jinja2Templates
from fastapi import Request
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from db import DB
from chain import RagChain from chain import RagChain
from pprint import pprint from db import DB
app = FastAPI() app = FastAPI()
@@ -20,15 +18,9 @@ def home(request: Request):
db = DB() db = DB()
prompts = db.get_prompts() prompts = db.get_prompts()
return templates.TemplateResponse("index.html", {"request": request, "prompts": prompts}) return templates.TemplateResponse(
"index.html", {"request": request, "prompts": prompts}
)
@app.get("/prova")
def prova(request: Request):
cursor = DB()
prompts = cursor.get_prompts()
return prompts
@app.websocket("/ws") @app.websocket("/ws")
@@ -48,14 +40,23 @@ async def websocket_endpoint(websocket: WebSocket):
message = raw_message message = raw_message
config = {} config = {}
db = DB()
prompt_id = int(config.get("prompt_id", 0))
prompt = db.get_prompt_by_id(prompt_id)
if prompt is None:
await websocket.send_json(
{"type": "chunk", "data": "Error: prompt not found."}
)
continue
rag_chain = RagChain( rag_chain = RagChain(
question_template=str(prompt.text),
top_k=int(config.get("top_k", 40)), top_k=int(config.get("top_k", 40)),
top_p=float(config.get("top_p", 0.0)), top_p=float(config.get("top_p", 0.0)),
temperature=float(config.get("temperature", 0.0)), temperature=float(config.get("temperature", 0.0)),
retriever_max_docs=int( retriever_max_docs=int(config.get("retriever_max_docs", 40)),
config.get("retriever_max_docs", 40)), reranker_max_results=int(config.get("reranker_max_results", 20)),
reranker_max_results=int(
config.get("reranker_max_results", 20)),
) )
async for chunk in rag_chain.stream(message): async for chunk in rag_chain.stream(message):

View File

@@ -20,6 +20,7 @@ MAX_OUTPUT_TOKENS = 65535
class RagChain: class RagChain:
def __init__( def __init__(
self, self,
question_template: str,
top_k: int, top_k: int,
top_p: float, top_p: float,
temperature: float, temperature: float,
@@ -32,8 +33,8 @@ class RagChain:
self.retriever_max_docs = retriever_max_docs self.retriever_max_docs = retriever_max_docs
self.reranker_max_results = reranker_max_results self.reranker_max_results = reranker_max_results
with open("prompt.md") as f: # with open("prompt.md") as f:
question_template = f.read() # question_template = f.read()
with open("question_rewrite_prompt.md") as f: with open("question_rewrite_prompt.md") as f:
question_rewrite_template = f.read() question_rewrite_template = f.read()

8
db.py
View File

@@ -1,4 +1,5 @@
from typing import List from typing import List
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -7,9 +8,12 @@ from models.orm import Prompt
class DB: class DB:
def __init__(self, db: str = "sqlite:///example.db"): def __init__(self, db: str = "sqlite:///example.db"):
self.engine = create_engine( self.engine = create_engine(db, connect_args={"check_same_thread": False})
db, connect_args={"check_same_thread": False})
def get_prompts(self) -> List[Prompt]: def get_prompts(self) -> List[Prompt]:
with Session(self.engine) as session: with Session(self.engine) as session:
return session.query(Prompt).all() return session.query(Prompt).all()
def get_prompt_by_id(self, prompt_id: int) -> Prompt | None:
with Session(self.engine) as session:
return session.query(Prompt).filter(Prompt.id == prompt_id).first()

View File

@@ -53,21 +53,6 @@ body {
gap: 12px; gap: 12px;
} }
.chat__settings {
padding: 6px 12px;
border-radius: 999px;
border: 1px solid #d1d5db;
background: #ffffff;
color: #1f2937;
font-weight: 600;
font-size: 13px;
cursor: pointer;
}
.chat__settings:hover {
background: #f3f4f6;
}
.chat__messages { .chat__messages {
flex: 1; flex: 1;
padding: 16px 20px; padding: 16px 20px;
@@ -204,18 +189,11 @@ body {
display: flex; display: flex;
flex-direction: column; flex-direction: column;
gap: 16px; gap: 16px;
transform: translateX(100%);
transition: transform 0.25s ease;
}
.drawer--open {
transform: translateX(0%);
} }
.drawer__header { .drawer__header {
display: flex; display: flex;
align-items: center; align-items: center;
justify-content: space-between;
gap: 12px; gap: 12px;
} }
@@ -224,14 +202,6 @@ body {
font-size: 18px; font-size: 18px;
} }
.drawer__close {
border: 1px solid #d1d5db;
background: #ffffff;
border-radius: 8px;
padding: 4px 8px;
cursor: pointer;
}
.drawer__form { .drawer__form {
display: flex; display: flex;
flex-direction: column; flex-direction: column;
@@ -261,12 +231,3 @@ body {
font-variant-numeric: tabular-nums; font-variant-numeric: tabular-nums;
} }
@media (max-width: 900px) {
.drawer {
position: absolute;
right: 0;
top: 0;
height: 100vh;
box-shadow: -12px 0 24px rgba(15, 23, 42, 0.08);
}
}

View File

@@ -4,9 +4,7 @@ const formEl = document.getElementById("chat-form");
const inputEl = document.getElementById("chat-input"); const inputEl = document.getElementById("chat-input");
const clearEl = document.getElementById("chat-clear"); const clearEl = document.getElementById("chat-clear");
const sendButtonEl = formEl?.querySelector("button[type='submit']"); const sendButtonEl = formEl?.querySelector("button[type='submit']");
const settingsToggle = document.getElementById("settings-toggle");
const settingsDrawer = document.getElementById("settings-drawer"); const settingsDrawer = document.getElementById("settings-drawer");
const settingsClose = document.getElementById("settings-close");
const rangeInputs = settingsDrawer const rangeInputs = settingsDrawer
? Array.from(settingsDrawer.querySelectorAll("input[type='range']")) ? Array.from(settingsDrawer.querySelectorAll("input[type='range']"))
: []; : [];
@@ -16,6 +14,7 @@ const configInputs = {
temperature: document.getElementById("config-temperature"), temperature: document.getElementById("config-temperature"),
retrieverMaxDocs: document.getElementById("config-retriever-max-docs"), retrieverMaxDocs: document.getElementById("config-retriever-max-docs"),
rerankerMaxDocs: document.getElementById("config-reranker-max-docs"), rerankerMaxDocs: document.getElementById("config-reranker-max-docs"),
promptId: document.getElementById("config-prompt"),
}; };
const scheme = window.location.protocol === "https:" ? "wss" : "ws"; const scheme = window.location.protocol === "https:" ? "wss" : "ws";
@@ -156,28 +155,6 @@ rangeInputs.forEach((input) => {
input.addEventListener("input", () => updateDrawerValue(input)); input.addEventListener("input", () => updateDrawerValue(input));
}); });
const closeDrawer = () => {
settingsDrawer?.classList.remove("drawer--open");
if (settingsDrawer) settingsDrawer.setAttribute("aria-hidden", "true");
if (settingsToggle) settingsToggle.setAttribute("aria-expanded", "false");
};
const openDrawer = () => {
settingsDrawer?.classList.add("drawer--open");
if (settingsDrawer) settingsDrawer.setAttribute("aria-hidden", "false");
if (settingsToggle) settingsToggle.setAttribute("aria-expanded", "true");
};
settingsToggle?.addEventListener("click", () => {
if (settingsDrawer?.classList.contains("drawer--open")) {
closeDrawer();
} else {
openDrawer();
}
});
settingsClose?.addEventListener("click", closeDrawer);
clearEl?.addEventListener("click", () => { clearEl?.addEventListener("click", () => {
clearMessages(); clearMessages();
inputEl?.focus(); inputEl?.focus();
@@ -201,6 +178,7 @@ formEl.addEventListener("submit", (event) => {
temperature: Number(configInputs.temperature?.value ?? 0), temperature: Number(configInputs.temperature?.value ?? 0),
retriever_max_docs: Number(configInputs.retrieverMaxDocs?.value ?? 0), retriever_max_docs: Number(configInputs.retrieverMaxDocs?.value ?? 0),
reranker_max_results: Number(configInputs.rerankerMaxDocs?.value ?? 0), reranker_max_results: Number(configInputs.rerankerMaxDocs?.value ?? 0),
prompt_id: Number(configInputs.promptId?.value ?? 0),
}; };
addMessage(text, "out"); addMessage(text, "out");
socket.send(JSON.stringify({ message: text, config })); socket.send(JSON.stringify({ message: text, config }));

View File

@@ -12,16 +12,6 @@
<section class="chat" aria-label="Chat"> <section class="chat" aria-label="Chat">
<div class="chat__header"> <div class="chat__header">
<span>AKERN Assistant</span> <span>AKERN Assistant</span>
<button
class="chat__settings"
id="settings-toggle"
type="button"
aria-haspopup="dialog"
aria-expanded="false"
aria-controls="settings-drawer"
>
Configurazione
</button>
</div> </div>
<div class="chat__messages" id="messages"></div> <div class="chat__messages" id="messages"></div>
<div class="chat__status" id="status">Connecting…</div> <div class="chat__status" id="status">Connecting…</div>
@@ -49,18 +39,9 @@
class="drawer" class="drawer"
id="settings-drawer" id="settings-drawer"
aria-label="Configurazione" aria-label="Configurazione"
aria-hidden="true"
> >
<div class="drawer__header"> <div class="drawer__header">
<h2 class="drawer__title">Configurazione</h2> <h2 class="drawer__title">Configurazione</h2>
<button
class="drawer__close"
id="settings-close"
type="button"
aria-label="Chiudi"
>
</button>
</div> </div>
<form class="drawer__form"> <form class="drawer__form">
<label class="drawer__field"> <label class="drawer__field">
@@ -128,11 +109,13 @@
/> />
<output class="drawer__value">20</output> <output class="drawer__value">20</output>
</label> </label>
<label for="drawer__field"> <label class="drawer__field">
<span class="drawer__label">Prompts</span> <span class="drawer__label">Prompts</span>
<!-- TODO --> <select class="drawer__select" id="config-prompt">
<!-- Here i need a select --> {% for prompt in prompts %}
<!-- {% for prompt in prompts %} {{ prompt.name }} {{ prompt.id }} {% endfor %} --> <option value="{{ prompt.id }}">{{ prompt.name }}</option>
{% endfor %}
</select>
</label> </label>
</form> </form>
</aside> </aside>