prompt template from db
This commit is contained in:
37
app.py
37
app.py
@@ -1,13 +1,11 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from fastapi import FastAPI, WebSocket
|
from fastapi import FastAPI, Request, WebSocket
|
||||||
from fastapi.templating import Jinja2Templates
|
|
||||||
from fastapi import Request
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from fastapi.templating import Jinja2Templates
|
||||||
|
|
||||||
from db import DB
|
|
||||||
from chain import RagChain
|
from chain import RagChain
|
||||||
from pprint import pprint
|
from db import DB
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
@@ -20,15 +18,9 @@ def home(request: Request):
|
|||||||
db = DB()
|
db = DB()
|
||||||
prompts = db.get_prompts()
|
prompts = db.get_prompts()
|
||||||
|
|
||||||
return templates.TemplateResponse("index.html", {"request": request, "prompts": prompts})
|
return templates.TemplateResponse(
|
||||||
|
"index.html", {"request": request, "prompts": prompts}
|
||||||
|
)
|
||||||
@app.get("/prova")
|
|
||||||
def prova(request: Request):
|
|
||||||
cursor = DB()
|
|
||||||
prompts = cursor.get_prompts()
|
|
||||||
|
|
||||||
return prompts
|
|
||||||
|
|
||||||
|
|
||||||
@app.websocket("/ws")
|
@app.websocket("/ws")
|
||||||
@@ -48,14 +40,23 @@ async def websocket_endpoint(websocket: WebSocket):
|
|||||||
message = raw_message
|
message = raw_message
|
||||||
config = {}
|
config = {}
|
||||||
|
|
||||||
|
db = DB()
|
||||||
|
prompt_id = int(config.get("prompt_id", 0))
|
||||||
|
prompt = db.get_prompt_by_id(prompt_id)
|
||||||
|
|
||||||
|
if prompt is None:
|
||||||
|
await websocket.send_json(
|
||||||
|
{"type": "chunk", "data": "Error: prompt not found."}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
rag_chain = RagChain(
|
rag_chain = RagChain(
|
||||||
|
question_template=str(prompt.text),
|
||||||
top_k=int(config.get("top_k", 40)),
|
top_k=int(config.get("top_k", 40)),
|
||||||
top_p=float(config.get("top_p", 0.0)),
|
top_p=float(config.get("top_p", 0.0)),
|
||||||
temperature=float(config.get("temperature", 0.0)),
|
temperature=float(config.get("temperature", 0.0)),
|
||||||
retriever_max_docs=int(
|
retriever_max_docs=int(config.get("retriever_max_docs", 40)),
|
||||||
config.get("retriever_max_docs", 40)),
|
reranker_max_results=int(config.get("reranker_max_results", 20)),
|
||||||
reranker_max_results=int(
|
|
||||||
config.get("reranker_max_results", 20)),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async for chunk in rag_chain.stream(message):
|
async for chunk in rag_chain.stream(message):
|
||||||
|
|||||||
5
chain.py
5
chain.py
@@ -20,6 +20,7 @@ MAX_OUTPUT_TOKENS = 65535
|
|||||||
class RagChain:
|
class RagChain:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
question_template: str,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
top_p: float,
|
top_p: float,
|
||||||
temperature: float,
|
temperature: float,
|
||||||
@@ -32,8 +33,8 @@ class RagChain:
|
|||||||
self.retriever_max_docs = retriever_max_docs
|
self.retriever_max_docs = retriever_max_docs
|
||||||
self.reranker_max_results = reranker_max_results
|
self.reranker_max_results = reranker_max_results
|
||||||
|
|
||||||
with open("prompt.md") as f:
|
# with open("prompt.md") as f:
|
||||||
question_template = f.read()
|
# question_template = f.read()
|
||||||
|
|
||||||
with open("question_rewrite_prompt.md") as f:
|
with open("question_rewrite_prompt.md") as f:
|
||||||
question_rewrite_template = f.read()
|
question_rewrite_template = f.read()
|
||||||
|
|||||||
8
db.py
8
db.py
@@ -1,4 +1,5 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -7,9 +8,12 @@ from models.orm import Prompt
|
|||||||
|
|
||||||
class DB:
|
class DB:
|
||||||
def __init__(self, db: str = "sqlite:///example.db"):
|
def __init__(self, db: str = "sqlite:///example.db"):
|
||||||
self.engine = create_engine(
|
self.engine = create_engine(db, connect_args={"check_same_thread": False})
|
||||||
db, connect_args={"check_same_thread": False})
|
|
||||||
|
|
||||||
def get_prompts(self) -> List[Prompt]:
|
def get_prompts(self) -> List[Prompt]:
|
||||||
with Session(self.engine) as session:
|
with Session(self.engine) as session:
|
||||||
return session.query(Prompt).all()
|
return session.query(Prompt).all()
|
||||||
|
|
||||||
|
def get_prompt_by_id(self, prompt_id: int) -> Prompt | None:
|
||||||
|
with Session(self.engine) as session:
|
||||||
|
return session.query(Prompt).filter(Prompt.id == prompt_id).first()
|
||||||
|
|||||||
@@ -53,21 +53,6 @@ body {
|
|||||||
gap: 12px;
|
gap: 12px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.chat__settings {
|
|
||||||
padding: 6px 12px;
|
|
||||||
border-radius: 999px;
|
|
||||||
border: 1px solid #d1d5db;
|
|
||||||
background: #ffffff;
|
|
||||||
color: #1f2937;
|
|
||||||
font-weight: 600;
|
|
||||||
font-size: 13px;
|
|
||||||
cursor: pointer;
|
|
||||||
}
|
|
||||||
|
|
||||||
.chat__settings:hover {
|
|
||||||
background: #f3f4f6;
|
|
||||||
}
|
|
||||||
|
|
||||||
.chat__messages {
|
.chat__messages {
|
||||||
flex: 1;
|
flex: 1;
|
||||||
padding: 16px 20px;
|
padding: 16px 20px;
|
||||||
@@ -204,18 +189,11 @@ body {
|
|||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
gap: 16px;
|
gap: 16px;
|
||||||
transform: translateX(100%);
|
|
||||||
transition: transform 0.25s ease;
|
|
||||||
}
|
|
||||||
|
|
||||||
.drawer--open {
|
|
||||||
transform: translateX(0%);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.drawer__header {
|
.drawer__header {
|
||||||
display: flex;
|
display: flex;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
justify-content: space-between;
|
|
||||||
gap: 12px;
|
gap: 12px;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -224,14 +202,6 @@ body {
|
|||||||
font-size: 18px;
|
font-size: 18px;
|
||||||
}
|
}
|
||||||
|
|
||||||
.drawer__close {
|
|
||||||
border: 1px solid #d1d5db;
|
|
||||||
background: #ffffff;
|
|
||||||
border-radius: 8px;
|
|
||||||
padding: 4px 8px;
|
|
||||||
cursor: pointer;
|
|
||||||
}
|
|
||||||
|
|
||||||
.drawer__form {
|
.drawer__form {
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
@@ -261,12 +231,3 @@ body {
|
|||||||
font-variant-numeric: tabular-nums;
|
font-variant-numeric: tabular-nums;
|
||||||
}
|
}
|
||||||
|
|
||||||
@media (max-width: 900px) {
|
|
||||||
.drawer {
|
|
||||||
position: absolute;
|
|
||||||
right: 0;
|
|
||||||
top: 0;
|
|
||||||
height: 100vh;
|
|
||||||
box-shadow: -12px 0 24px rgba(15, 23, 42, 0.08);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -4,9 +4,7 @@ const formEl = document.getElementById("chat-form");
|
|||||||
const inputEl = document.getElementById("chat-input");
|
const inputEl = document.getElementById("chat-input");
|
||||||
const clearEl = document.getElementById("chat-clear");
|
const clearEl = document.getElementById("chat-clear");
|
||||||
const sendButtonEl = formEl?.querySelector("button[type='submit']");
|
const sendButtonEl = formEl?.querySelector("button[type='submit']");
|
||||||
const settingsToggle = document.getElementById("settings-toggle");
|
|
||||||
const settingsDrawer = document.getElementById("settings-drawer");
|
const settingsDrawer = document.getElementById("settings-drawer");
|
||||||
const settingsClose = document.getElementById("settings-close");
|
|
||||||
const rangeInputs = settingsDrawer
|
const rangeInputs = settingsDrawer
|
||||||
? Array.from(settingsDrawer.querySelectorAll("input[type='range']"))
|
? Array.from(settingsDrawer.querySelectorAll("input[type='range']"))
|
||||||
: [];
|
: [];
|
||||||
@@ -16,6 +14,7 @@ const configInputs = {
|
|||||||
temperature: document.getElementById("config-temperature"),
|
temperature: document.getElementById("config-temperature"),
|
||||||
retrieverMaxDocs: document.getElementById("config-retriever-max-docs"),
|
retrieverMaxDocs: document.getElementById("config-retriever-max-docs"),
|
||||||
rerankerMaxDocs: document.getElementById("config-reranker-max-docs"),
|
rerankerMaxDocs: document.getElementById("config-reranker-max-docs"),
|
||||||
|
promptId: document.getElementById("config-prompt"),
|
||||||
};
|
};
|
||||||
|
|
||||||
const scheme = window.location.protocol === "https:" ? "wss" : "ws";
|
const scheme = window.location.protocol === "https:" ? "wss" : "ws";
|
||||||
@@ -156,28 +155,6 @@ rangeInputs.forEach((input) => {
|
|||||||
input.addEventListener("input", () => updateDrawerValue(input));
|
input.addEventListener("input", () => updateDrawerValue(input));
|
||||||
});
|
});
|
||||||
|
|
||||||
const closeDrawer = () => {
|
|
||||||
settingsDrawer?.classList.remove("drawer--open");
|
|
||||||
if (settingsDrawer) settingsDrawer.setAttribute("aria-hidden", "true");
|
|
||||||
if (settingsToggle) settingsToggle.setAttribute("aria-expanded", "false");
|
|
||||||
};
|
|
||||||
|
|
||||||
const openDrawer = () => {
|
|
||||||
settingsDrawer?.classList.add("drawer--open");
|
|
||||||
if (settingsDrawer) settingsDrawer.setAttribute("aria-hidden", "false");
|
|
||||||
if (settingsToggle) settingsToggle.setAttribute("aria-expanded", "true");
|
|
||||||
};
|
|
||||||
|
|
||||||
settingsToggle?.addEventListener("click", () => {
|
|
||||||
if (settingsDrawer?.classList.contains("drawer--open")) {
|
|
||||||
closeDrawer();
|
|
||||||
} else {
|
|
||||||
openDrawer();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
settingsClose?.addEventListener("click", closeDrawer);
|
|
||||||
|
|
||||||
clearEl?.addEventListener("click", () => {
|
clearEl?.addEventListener("click", () => {
|
||||||
clearMessages();
|
clearMessages();
|
||||||
inputEl?.focus();
|
inputEl?.focus();
|
||||||
@@ -201,6 +178,7 @@ formEl.addEventListener("submit", (event) => {
|
|||||||
temperature: Number(configInputs.temperature?.value ?? 0),
|
temperature: Number(configInputs.temperature?.value ?? 0),
|
||||||
retriever_max_docs: Number(configInputs.retrieverMaxDocs?.value ?? 0),
|
retriever_max_docs: Number(configInputs.retrieverMaxDocs?.value ?? 0),
|
||||||
reranker_max_results: Number(configInputs.rerankerMaxDocs?.value ?? 0),
|
reranker_max_results: Number(configInputs.rerankerMaxDocs?.value ?? 0),
|
||||||
|
prompt_id: Number(configInputs.promptId?.value ?? 0),
|
||||||
};
|
};
|
||||||
addMessage(text, "out");
|
addMessage(text, "out");
|
||||||
socket.send(JSON.stringify({ message: text, config }));
|
socket.send(JSON.stringify({ message: text, config }));
|
||||||
|
|||||||
@@ -12,16 +12,6 @@
|
|||||||
<section class="chat" aria-label="Chat">
|
<section class="chat" aria-label="Chat">
|
||||||
<div class="chat__header">
|
<div class="chat__header">
|
||||||
<span>AKERN Assistant</span>
|
<span>AKERN Assistant</span>
|
||||||
<button
|
|
||||||
class="chat__settings"
|
|
||||||
id="settings-toggle"
|
|
||||||
type="button"
|
|
||||||
aria-haspopup="dialog"
|
|
||||||
aria-expanded="false"
|
|
||||||
aria-controls="settings-drawer"
|
|
||||||
>
|
|
||||||
Configurazione
|
|
||||||
</button>
|
|
||||||
</div>
|
</div>
|
||||||
<div class="chat__messages" id="messages"></div>
|
<div class="chat__messages" id="messages"></div>
|
||||||
<div class="chat__status" id="status">Connecting…</div>
|
<div class="chat__status" id="status">Connecting…</div>
|
||||||
@@ -49,18 +39,9 @@
|
|||||||
class="drawer"
|
class="drawer"
|
||||||
id="settings-drawer"
|
id="settings-drawer"
|
||||||
aria-label="Configurazione"
|
aria-label="Configurazione"
|
||||||
aria-hidden="true"
|
|
||||||
>
|
>
|
||||||
<div class="drawer__header">
|
<div class="drawer__header">
|
||||||
<h2 class="drawer__title">Configurazione</h2>
|
<h2 class="drawer__title">Configurazione</h2>
|
||||||
<button
|
|
||||||
class="drawer__close"
|
|
||||||
id="settings-close"
|
|
||||||
type="button"
|
|
||||||
aria-label="Chiudi"
|
|
||||||
>
|
|
||||||
✕
|
|
||||||
</button>
|
|
||||||
</div>
|
</div>
|
||||||
<form class="drawer__form">
|
<form class="drawer__form">
|
||||||
<label class="drawer__field">
|
<label class="drawer__field">
|
||||||
@@ -128,11 +109,13 @@
|
|||||||
/>
|
/>
|
||||||
<output class="drawer__value">20</output>
|
<output class="drawer__value">20</output>
|
||||||
</label>
|
</label>
|
||||||
<label for="drawer__field">
|
<label class="drawer__field">
|
||||||
<span class="drawer__label">Prompts</span>
|
<span class="drawer__label">Prompts</span>
|
||||||
<!-- TODO -->
|
<select class="drawer__select" id="config-prompt">
|
||||||
<!-- Here i need a select -->
|
{% for prompt in prompts %}
|
||||||
<!-- {% for prompt in prompts %} {{ prompt.name }} {{ prompt.id }} {% endfor %} -->
|
<option value="{{ prompt.id }}">{{ prompt.name }}</option>
|
||||||
|
{% endfor %}
|
||||||
|
</select>
|
||||||
</label>
|
</label>
|
||||||
</form>
|
</form>
|
||||||
</aside>
|
</aside>
|
||||||
|
|||||||
Reference in New Issue
Block a user