diff --git a/app.py b/app.py index 4b50bee..8b18d7c 100644 --- a/app.py +++ b/app.py @@ -1,13 +1,11 @@ import json -from fastapi import FastAPI, WebSocket -from fastapi.templating import Jinja2Templates -from fastapi import Request +from fastapi import FastAPI, Request, WebSocket from fastapi.staticfiles import StaticFiles +from fastapi.templating import Jinja2Templates -from db import DB from chain import RagChain -from pprint import pprint +from db import DB app = FastAPI() @@ -20,15 +18,9 @@ def home(request: Request): db = DB() prompts = db.get_prompts() - return templates.TemplateResponse("index.html", {"request": request, "prompts": prompts}) - - -@app.get("/prova") -def prova(request: Request): - cursor = DB() - prompts = cursor.get_prompts() - - return prompts + return templates.TemplateResponse( + "index.html", {"request": request, "prompts": prompts} + ) @app.websocket("/ws") @@ -48,14 +40,23 @@ async def websocket_endpoint(websocket: WebSocket): message = raw_message config = {} + db = DB() + prompt_id = int(config.get("prompt_id", 0)) + prompt = db.get_prompt_by_id(prompt_id) + + if prompt is None: + await websocket.send_json( + {"type": "chunk", "data": "Error: prompt not found."} + ) + continue + rag_chain = RagChain( + question_template=str(prompt.text), top_k=int(config.get("top_k", 40)), top_p=float(config.get("top_p", 0.0)), temperature=float(config.get("temperature", 0.0)), - retriever_max_docs=int( - config.get("retriever_max_docs", 40)), - reranker_max_results=int( - config.get("reranker_max_results", 20)), + retriever_max_docs=int(config.get("retriever_max_docs", 40)), + reranker_max_results=int(config.get("reranker_max_results", 20)), ) async for chunk in rag_chain.stream(message): diff --git a/chain.py b/chain.py index 6f641be..cb33661 100644 --- a/chain.py +++ b/chain.py @@ -20,6 +20,7 @@ MAX_OUTPUT_TOKENS = 65535 class RagChain: def __init__( self, + question_template: str, top_k: int, top_p: float, temperature: float, @@ -32,8 +33,8 @@ class RagChain: self.retriever_max_docs = retriever_max_docs self.reranker_max_results = reranker_max_results - with open("prompt.md") as f: - question_template = f.read() + # with open("prompt.md") as f: + # question_template = f.read() with open("question_rewrite_prompt.md") as f: question_rewrite_template = f.read() diff --git a/db.py b/db.py index bc2e1cd..a90759e 100644 --- a/db.py +++ b/db.py @@ -1,4 +1,5 @@ from typing import List + from sqlalchemy import create_engine from sqlalchemy.orm import Session @@ -7,9 +8,12 @@ from models.orm import Prompt class DB: def __init__(self, db: str = "sqlite:///example.db"): - self.engine = create_engine( - db, connect_args={"check_same_thread": False}) + self.engine = create_engine(db, connect_args={"check_same_thread": False}) def get_prompts(self) -> List[Prompt]: with Session(self.engine) as session: return session.query(Prompt).all() + + def get_prompt_by_id(self, prompt_id: int) -> Prompt | None: + with Session(self.engine) as session: + return session.query(Prompt).filter(Prompt.id == prompt_id).first() diff --git a/static/css/chat.css b/static/css/chat.css index 8c2a06d..019d0d1 100644 --- a/static/css/chat.css +++ b/static/css/chat.css @@ -53,21 +53,6 @@ body { gap: 12px; } -.chat__settings { - padding: 6px 12px; - border-radius: 999px; - border: 1px solid #d1d5db; - background: #ffffff; - color: #1f2937; - font-weight: 600; - font-size: 13px; - cursor: pointer; -} - -.chat__settings:hover { - background: #f3f4f6; -} - .chat__messages { flex: 1; padding: 16px 20px; @@ -204,18 +189,11 @@ body { display: flex; flex-direction: column; gap: 16px; - transform: translateX(100%); - transition: transform 0.25s ease; -} - -.drawer--open { - transform: translateX(0%); } .drawer__header { display: flex; align-items: center; - justify-content: space-between; gap: 12px; } @@ -224,14 +202,6 @@ body { font-size: 18px; } -.drawer__close { - border: 1px solid #d1d5db; - background: #ffffff; - border-radius: 8px; - padding: 4px 8px; - cursor: pointer; -} - .drawer__form { display: flex; flex-direction: column; @@ -261,12 +231,3 @@ body { font-variant-numeric: tabular-nums; } -@media (max-width: 900px) { - .drawer { - position: absolute; - right: 0; - top: 0; - height: 100vh; - box-shadow: -12px 0 24px rgba(15, 23, 42, 0.08); - } -} diff --git a/static/js/chat.js b/static/js/chat.js index 889e32e..baf75ba 100644 --- a/static/js/chat.js +++ b/static/js/chat.js @@ -4,9 +4,7 @@ const formEl = document.getElementById("chat-form"); const inputEl = document.getElementById("chat-input"); const clearEl = document.getElementById("chat-clear"); const sendButtonEl = formEl?.querySelector("button[type='submit']"); -const settingsToggle = document.getElementById("settings-toggle"); const settingsDrawer = document.getElementById("settings-drawer"); -const settingsClose = document.getElementById("settings-close"); const rangeInputs = settingsDrawer ? Array.from(settingsDrawer.querySelectorAll("input[type='range']")) : []; @@ -16,6 +14,7 @@ const configInputs = { temperature: document.getElementById("config-temperature"), retrieverMaxDocs: document.getElementById("config-retriever-max-docs"), rerankerMaxDocs: document.getElementById("config-reranker-max-docs"), + promptId: document.getElementById("config-prompt"), }; const scheme = window.location.protocol === "https:" ? "wss" : "ws"; @@ -156,28 +155,6 @@ rangeInputs.forEach((input) => { input.addEventListener("input", () => updateDrawerValue(input)); }); -const closeDrawer = () => { - settingsDrawer?.classList.remove("drawer--open"); - if (settingsDrawer) settingsDrawer.setAttribute("aria-hidden", "true"); - if (settingsToggle) settingsToggle.setAttribute("aria-expanded", "false"); -}; - -const openDrawer = () => { - settingsDrawer?.classList.add("drawer--open"); - if (settingsDrawer) settingsDrawer.setAttribute("aria-hidden", "false"); - if (settingsToggle) settingsToggle.setAttribute("aria-expanded", "true"); -}; - -settingsToggle?.addEventListener("click", () => { - if (settingsDrawer?.classList.contains("drawer--open")) { - closeDrawer(); - } else { - openDrawer(); - } -}); - -settingsClose?.addEventListener("click", closeDrawer); - clearEl?.addEventListener("click", () => { clearMessages(); inputEl?.focus(); @@ -201,6 +178,7 @@ formEl.addEventListener("submit", (event) => { temperature: Number(configInputs.temperature?.value ?? 0), retriever_max_docs: Number(configInputs.retrieverMaxDocs?.value ?? 0), reranker_max_results: Number(configInputs.rerankerMaxDocs?.value ?? 0), + prompt_id: Number(configInputs.promptId?.value ?? 0), }; addMessage(text, "out"); socket.send(JSON.stringify({ message: text, config })); diff --git a/templates/index.html b/templates/index.html index f195548..90eb716 100644 --- a/templates/index.html +++ b/templates/index.html @@ -12,16 +12,6 @@
AKERN Assistant -
Connecting…
@@ -49,18 +39,9 @@ class="drawer" id="settings-drawer" aria-label="Configurazione" - aria-hidden="true" >

Configurazione

-
-