diff --git a/.gitignore b/.gitignore index a406423..1c28ce1 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ credentials.json .zed .DS_STORE *.sqlite3 +*.db \ No newline at end of file diff --git a/app.py b/app.py index e4c988a..4b50bee 100644 --- a/app.py +++ b/app.py @@ -5,6 +5,7 @@ from fastapi.templating import Jinja2Templates from fastapi import Request from fastapi.staticfiles import StaticFiles +from db import DB from chain import RagChain from pprint import pprint @@ -16,7 +17,18 @@ app.mount("/static", StaticFiles(directory="static"), name="static") @app.get("/") def home(request: Request): - return templates.TemplateResponse("index.html", {"request": request}) + db = DB() + prompts = db.get_prompts() + + return templates.TemplateResponse("index.html", {"request": request, "prompts": prompts}) + + +@app.get("/prova") +def prova(request: Request): + cursor = DB() + prompts = cursor.get_prompts() + + return prompts @app.websocket("/ws") @@ -40,8 +52,10 @@ async def websocket_endpoint(websocket: WebSocket): top_k=int(config.get("top_k", 40)), top_p=float(config.get("top_p", 0.0)), temperature=float(config.get("temperature", 0.0)), - retriever_max_docs=int(config.get("retriever_max_docs", 40)), - reranker_max_results=int(config.get("reranker_max_results", 20)), + retriever_max_docs=int( + config.get("retriever_max_docs", 40)), + reranker_max_results=int( + config.get("reranker_max_results", 20)), ) async for chunk in rag_chain.stream(message): @@ -52,6 +66,7 @@ async def websocket_endpoint(websocket: WebSocket): "type": "report", "sources": rag_chain.getSources(), "reranked_sources": rag_chain.getRankedSources(), + "rephrased_question": rag_chain.getRephrasedQuestion(), } ) finally: diff --git a/chain.py b/chain.py index e99da77..6f641be 100644 --- a/chain.py +++ b/chain.py @@ -45,6 +45,7 @@ class RagChain: self._retriever_sources: list[dict] = [] self._reranked_sources: list[dict] = [] + self._rephrased_question: str = "" self._llm = ChatGoogleGenerativeAI( model=MODEL, @@ -80,7 +81,7 @@ class RagChain: | question_rewrite_prompt | self._llm | StrOutputParser() - | RunnableLambda(self._log_rewritten_question) + | RunnableLambda(self._log_rephrased_question) ) rag_chain = ( @@ -95,8 +96,9 @@ class RagChain: self._full_chain = question_rewrite_chain | rag_chain - def _log_rewritten_question(self, rewritten_question: str) -> str: - return rewritten_question + def _log_rephrased_question(self, rephrased_question: str) -> str: + self._rephrased_question = rephrased_question + return rephrased_question def _format_docs(self, question: str) -> str: retrieved_docs = self._base_retriever.invoke(question) @@ -150,5 +152,8 @@ class RagChain: def getRankedSources(self) -> list[dict]: return list(self._reranked_sources) + def getRephrasedQuestion(self) -> str: + return self._rephrased_question + def stream(self, message: str): return self._full_chain.astream(message) diff --git a/db.py b/db.py new file mode 100644 index 0000000..bc2e1cd --- /dev/null +++ b/db.py @@ -0,0 +1,15 @@ +from typing import List +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from models.orm import Prompt + + +class DB: + def __init__(self, db: str = "sqlite:///example.db"): + self.engine = create_engine( + db, connect_args={"check_same_thread": False}) + + def get_prompts(self) -> List[Prompt]: + with Session(self.engine) as session: + return session.query(Prompt).all() diff --git a/models/orm.py b/models/orm.py new file mode 100644 index 0000000..be223f4 --- /dev/null +++ b/models/orm.py @@ -0,0 +1,17 @@ +from sqlalchemy import create_engine, Column, String, Text, Integer +from sqlalchemy.orm import declarative_base + +Base = declarative_base() + + +class Prompt(Base): + __tablename__ = "prompts" + + id = Column(Integer, primary_key=True) + name = Column(String, nullable=False) + text = Column(Text, nullable=False) + + +if __name__ == "__main__": + engine = create_engine("sqlite:///example.db") + Base.metadata.create_all(engine) diff --git a/models/validation.py b/models/validation.py new file mode 100644 index 0000000..91d1bf8 --- /dev/null +++ b/models/validation.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class Prompt(BaseModel): + name: str + text: str diff --git a/pyproject.toml b/pyproject.toml index 6dc12b0..73f58d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,8 @@ dependencies = [ "python-dotenv>=1.2.1", "fastapi>=0.129.0", "fastapi[standard]", + "pydantic>=2.12.5", + "sqlalchemy>=2.0.46", ] [dependency-groups] diff --git a/requirements.txt b/requirements.txt index 2cf3edc..2a32256 100644 --- a/requirements.txt +++ b/requirements.txt @@ -75,7 +75,6 @@ mdurl==0.1.2 multidict==6.7.1 mypy-extensions==1.1.0 nodeenv==1.10.0 -nuitka==4.0.1 numexpr==2.14.1 numpy==2.4.2 openai==2.21.0 diff --git a/static/css/chat.css b/static/css/chat.css index b8e9efa..8c2a06d 100644 --- a/static/css/chat.css +++ b/static/css/chat.css @@ -1,274 +1,272 @@ :root { - color-scheme: light; - font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif; + color-scheme: light; + font-family: + system-ui, + -apple-system, + BlinkMacSystemFont, + "Segoe UI", + sans-serif; } * { - box-sizing: border-box; + box-sizing: border-box; } html, body { - height: 100%; + height: 100%; } body { - margin: 0; - padding: 0; - background: #f7f8fa; - color: #1f2937; + margin: 0; + padding: 0; + background: #f7f8fa; + color: #1f2937; } .app { - display: flex; - width: 100vw; - height: 100vh; - overflow: hidden; + display: flex; + width: 100vw; + height: 100vh; + overflow: hidden; } .chat { - flex: 1; - height: 100%; - margin: 0; - background: #ffffff; - border: 1px solid #e5e7eb; - border-radius: 0; - display: flex; - flex-direction: column; + flex: 1; + height: 100%; + margin: 0; + background: #ffffff; + border: 1px solid #e5e7eb; + border-radius: 0; + display: flex; + flex-direction: column; } .chat__header { - padding: 16px 20px; - border-bottom: 1px solid #e5e7eb; - font-weight: 600; - background: #f9fafb; - display: flex; - align-items: center; - justify-content: space-between; - gap: 12px; + padding: 16px 20px; + border-bottom: 1px solid #e5e7eb; + font-weight: 600; + background: #f9fafb; + display: flex; + align-items: center; + justify-content: space-between; + gap: 12px; } .chat__settings { - padding: 6px 12px; - border-radius: 999px; - border: 1px solid #d1d5db; - background: #ffffff; - color: #1f2937; - font-weight: 600; - font-size: 13px; - cursor: pointer; + padding: 6px 12px; + border-radius: 999px; + border: 1px solid #d1d5db; + background: #ffffff; + color: #1f2937; + font-weight: 600; + font-size: 13px; + cursor: pointer; } .chat__settings:hover { - background: #f3f4f6; + background: #f3f4f6; } .chat__messages { - flex: 1; - padding: 16px 20px; - overflow-y: auto; - display: flex; - flex-direction: column; - gap: 10px; + flex: 1; + padding: 16px 20px; + overflow-y: auto; + display: flex; + flex-direction: column; + gap: 10px; } .message { - max-width: 70%; - padding: 10px 12px; - border-radius: 12px; - line-height: 1.4; - font-size: 14px; - word-wrap: break-word; - white-space: pre-wrap; + max-width: 70%; + padding: 10px 12px; + border-radius: 12px; + line-height: 1.4; + font-size: 14px; + word-wrap: break-word; + white-space: pre-wrap; } .message--out { - align-self: flex-end; - background: #2563eb; - color: #ffffff; + align-self: flex-end; + background: #2563eb; + color: #ffffff; } .message--in { - align-self: flex-start; - background: #f3f4f6; - color: #111827; + align-self: flex-start; + background: #f3f4f6; + color: #111827; } /* Typing / loading indicator */ .message--loading { - display: flex; - align-items: center; - gap: 5px; - padding: 12px 16px; + display: flex; + align-items: center; + gap: 5px; + padding: 12px 16px; } .typing-dot { - width: 8px; - height: 8px; - border-radius: 50%; - background: #9ca3af; - animation: typing-bounce 1.2s infinite ease-in-out; + width: 8px; + height: 8px; + border-radius: 50%; + background: #9ca3af; + animation: typing-bounce 1.2s infinite ease-in-out; } .typing-dot:nth-child(1) { - animation-delay: 0s; + animation-delay: 0s; } .typing-dot:nth-child(2) { - animation-delay: 0.2s; + animation-delay: 0.2s; } .typing-dot:nth-child(3) { - animation-delay: 0.4s; + animation-delay: 0.4s; } @keyframes typing-bounce { + 0%, + 60%, + 100% { + transform: translateY(0); + background: #9ca3af; + } - 0%, - 60%, - 100% { - transform: translateY(0); - background: #9ca3af; - } - - 30% { - transform: translateY(-6px); - background: #6b7280; - } + 30% { + transform: translateY(-6px); + background: #6b7280; + } } .chat__footer { - padding: 16px 20px; - border-top: 1px solid #e5e7eb; - display: flex; - gap: 12px; - background: #f9fafb; + padding: 16px 20px; + border-top: 1px solid #e5e7eb; + display: flex; + gap: 12px; + background: #f9fafb; } .chat__input { - flex: 1; - padding: 10px 12px; - border-radius: 8px; - border: 1px solid #d1d5db; - background: #ffffff; - color: #111827; + flex: 1; + padding: 10px 12px; + border-radius: 8px; + border: 1px solid #d1d5db; + background: #ffffff; + color: #111827; } .chat__input::placeholder { - color: #9ca3af; + color: #9ca3af; } .chat__button { - padding: 10px 16px; - border-radius: 8px; - border: none; - background: #2563eb; - color: #ffffff; - font-weight: 600; - cursor: pointer; + padding: 10px 16px; + border-radius: 8px; + border: none; + background: #2563eb; + color: #ffffff; + font-weight: 600; + cursor: pointer; } .chat__button:disabled { - background: #cbd5f5; - color: #ffffff; - cursor: not-allowed; - opacity: 0.75; + background: #cbd5f5; + color: #ffffff; + cursor: not-allowed; + opacity: 0.75; } .chat__button--secondary { - background: #e5e7eb; - color: #1f2937; + background: #e5e7eb; + color: #1f2937; } .chat__button--secondary:hover { - background: #d1d5db; + background: #d1d5db; } .chat__status { - padding: 8px 20px 0; - font-size: 12px; - color: #6b7280; + padding: 8px 20px 0; + font-size: 12px; + color: #6b7280; } .drawer { - width: 320px; - max-width: 100%; - height: 100%; - border-left: 1px solid #e5e7eb; - background: #ffffff; - padding: 20px; - display: flex; - flex-direction: column; - gap: 16px; - transform: translateX(100%); - transition: transform 0.25s ease; + width: 320px; + max-width: 100%; + height: 100%; + border-left: 1px solid #e5e7eb; + background: #ffffff; + padding: 20px; + display: flex; + flex-direction: column; + gap: 16px; + transform: translateX(100%); + transition: transform 0.25s ease; } .drawer--open { - transform: translateX(0%); + transform: translateX(0%); } .drawer__header { - display: flex; - align-items: center; - justify-content: space-between; - gap: 12px; + display: flex; + align-items: center; + justify-content: space-between; + gap: 12px; } .drawer__title { - margin: 0; - font-size: 18px; + margin: 0; + font-size: 18px; } .drawer__close { - border: 1px solid #d1d5db; - background: #ffffff; - border-radius: 8px; - padding: 4px 8px; - cursor: pointer; + border: 1px solid #d1d5db; + background: #ffffff; + border-radius: 8px; + padding: 4px 8px; + cursor: pointer; } .drawer__form { - display: flex; - flex-direction: column; - gap: 16px; + display: flex; + flex-direction: column; + gap: 16px; } .drawer__field { - display: grid; - grid-template-columns: 1fr auto; - grid-template-rows: auto auto; - gap: 6px 12px; - align-items: center; - font-size: 13px; - color: #1f2937; + display: grid; + grid-template-columns: 1fr auto; + grid-template-rows: auto auto; + gap: 6px 12px; + align-items: center; + font-size: 13px; + color: #1f2937; } .drawer__label { - font-weight: 600; + font-weight: 600; } .drawer__range { - grid-column: 1 / -1; - width: 100%; + grid-column: 1 / -1; + width: 100%; } .drawer__value { - font-variant-numeric: tabular-nums; -} - -.drawer__hint { - margin: 0; - font-size: 12px; - color: #6b7280; + font-variant-numeric: tabular-nums; } @media (max-width: 900px) { - .drawer { - position: absolute; - right: 0; - top: 0; - height: 100vh; - box-shadow: -12px 0 24px rgba(15, 23, 42, 0.08); - } -} \ No newline at end of file + .drawer { + position: absolute; + right: 0; + top: 0; + height: 100vh; + box-shadow: -12px 0 24px rgba(15, 23, 42, 0.08); + } +} diff --git a/static/js/chat.js b/static/js/chat.js index 0e7f754..889e32e 100644 --- a/static/js/chat.js +++ b/static/js/chat.js @@ -8,14 +8,14 @@ const settingsToggle = document.getElementById("settings-toggle"); const settingsDrawer = document.getElementById("settings-drawer"); const settingsClose = document.getElementById("settings-close"); const rangeInputs = settingsDrawer - ? Array.from(settingsDrawer.querySelectorAll("input[type='range']")) - : []; + ? Array.from(settingsDrawer.querySelectorAll("input[type='range']")) + : []; const configInputs = { - topK: document.getElementById("config-top-k"), - topP: document.getElementById("config-top-p"), - temperature: document.getElementById("config-temperature"), - retrieverMaxDocs: document.getElementById("config-retriever-max-docs"), - rerankerMaxDocs: document.getElementById("config-reranker-max-docs"), + topK: document.getElementById("config-top-k"), + topP: document.getElementById("config-top-p"), + temperature: document.getElementById("config-temperature"), + retrieverMaxDocs: document.getElementById("config-retriever-max-docs"), + rerankerMaxDocs: document.getElementById("config-reranker-max-docs"), }; const scheme = window.location.protocol === "https:" ? "wss" : "ws"; @@ -26,178 +26,189 @@ let loadingBubble = null; let isAwaitingResponse = false; const updateSendButtonState = () => { - if (!sendButtonEl) return; - sendButtonEl.disabled = isAwaitingResponse; + if (!sendButtonEl) return; + sendButtonEl.disabled = isAwaitingResponse; }; const addMessage = (text, direction) => { - const bubble = document.createElement("div"); - bubble.className = `message message--${direction}`; - bubble.textContent = text; - messagesEl.appendChild(bubble); - messagesEl.scrollTop = messagesEl.scrollHeight; + const bubble = document.createElement("div"); + bubble.className = `message message--${direction}`; + bubble.textContent = text; + messagesEl.appendChild(bubble); + messagesEl.scrollTop = messagesEl.scrollHeight; }; const showLoadingBubble = () => { - if (loadingBubble) return; - loadingBubble = document.createElement("div"); - loadingBubble.className = "message message--in message--loading"; - loadingBubble.innerHTML = - ''; - messagesEl.appendChild(loadingBubble); - messagesEl.scrollTop = messagesEl.scrollHeight; + if (loadingBubble) return; + loadingBubble = document.createElement("div"); + loadingBubble.className = "message message--in message--loading"; + loadingBubble.innerHTML = + ''; + messagesEl.appendChild(loadingBubble); + messagesEl.scrollTop = messagesEl.scrollHeight; }; const removeLoadingBubble = () => { - if (!loadingBubble) return; - loadingBubble.remove(); - loadingBubble = null; + if (!loadingBubble) return; + loadingBubble.remove(); + loadingBubble = null; }; const clearMessages = () => { - if (!messagesEl) return; - messagesEl.innerHTML = ""; - streamingBubble = null; - loadingBubble = null; + if (!messagesEl) return; + messagesEl.innerHTML = ""; + streamingBubble = null; + loadingBubble = null; }; const connect = () => { - socket = new WebSocket(socketUrl); + socket = new WebSocket(socketUrl); - socket.addEventListener("open", () => { - statusEl.textContent = "Connected"; - }); + socket.addEventListener("open", () => { + statusEl.textContent = "Connected"; + }); - socket.addEventListener("message", (event) => { - let payload; - try { - payload = JSON.parse(event.data); - } catch (error) { - addMessage(event.data, "in"); - return; - } + socket.addEventListener("message", (event) => { + let payload; + try { + payload = JSON.parse(event.data); + } catch (error) { + addMessage(event.data, "in"); + return; + } - if (payload.type === "start") { - // Keep the loading bubble visible until the first chunk arrives. - // Just prepare the streaming bubble but don't append it yet. - streamingBubble = document.createElement("div"); - streamingBubble.className = "message message--in"; - streamingBubble.textContent = ""; - return; - } + if (payload.type === "start") { + // Keep the loading bubble visible until the first chunk arrives. + // Just prepare the streaming bubble but don't append it yet. + streamingBubble = document.createElement("div"); + streamingBubble.className = "message message--in"; + streamingBubble.textContent = ""; + return; + } - if (payload.type === "chunk") { - if (!streamingBubble) { - streamingBubble = document.createElement("div"); - streamingBubble.className = "message message--in"; - } - // First chunk: swap loading bubble for the streaming bubble. - if (!streamingBubble.isConnected) { - removeLoadingBubble(); - messagesEl.appendChild(streamingBubble); - } - streamingBubble.textContent += payload.data ?? ""; - messagesEl.scrollTop = messagesEl.scrollHeight; - return; - } - - if (payload.type === "report") { - console.log("%cSOURCES", 'border: 2px solid red; padding: 1em; color: red; font-size: 14px; font-weight: bold;') - console.log(payload.sources); - console.log("%cRE-RANKED SOURCES", 'border: 2px solid red; padding: 1em; color: red; font-size: 14px; font-weight: bold;') - console.log(payload.reranked_sources); - return; - } - - if (payload.type === "end") { - // Safety net: remove loading bubble if no chunks were ever received. - removeLoadingBubble(); - streamingBubble = null; - isAwaitingResponse = false; - updateSendButtonState(); - return; - } - }); - - socket.addEventListener("close", () => { - statusEl.textContent = "Disconnected"; + if (payload.type === "chunk") { + if (!streamingBubble) { + streamingBubble = document.createElement("div"); + streamingBubble.className = "message message--in"; + } + // First chunk: swap loading bubble for the streaming bubble. + if (!streamingBubble.isConnected) { removeLoadingBubble(); - isAwaitingResponse = false; - updateSendButtonState(); - }); + messagesEl.appendChild(streamingBubble); + } + streamingBubble.textContent += payload.data ?? ""; + messagesEl.scrollTop = messagesEl.scrollHeight; + return; + } - socket.addEventListener("error", () => { - statusEl.textContent = "Connection error"; - removeLoadingBubble(); - isAwaitingResponse = false; - updateSendButtonState(); - }); + if (payload.type === "report") { + console.log( + "%cREPHRASED QUESTION", + "border: 2px solid red; padding: 1em; color: red; font-size: 14px; font-weight: bold;", + ); + console.log(payload.rephrased_question); + console.log( + "%cSOURCES", + "border: 2px solid red; padding: 1em; color: red; font-size: 14px; font-weight: bold;", + ); + console.log(payload.sources); + console.log( + "%cRE-RANKED SOURCES", + "border: 2px solid red; padding: 1em; color: red; font-size: 14px; font-weight: bold;", + ); + console.log(payload.reranked_sources); + return; + } + + if (payload.type === "end") { + // Safety net: remove loading bubble if no chunks were ever received. + removeLoadingBubble(); + streamingBubble = null; + isAwaitingResponse = false; + updateSendButtonState(); + return; + } + }); + + socket.addEventListener("close", () => { + statusEl.textContent = "Disconnected"; + removeLoadingBubble(); + isAwaitingResponse = false; + updateSendButtonState(); + }); + + socket.addEventListener("error", () => { + statusEl.textContent = "Connection error"; + removeLoadingBubble(); + isAwaitingResponse = false; + updateSendButtonState(); + }); }; const updateDrawerValue = (input) => { - const output = input.parentElement?.querySelector(".drawer__value"); - if (!output) return; - output.textContent = input.value; + const output = input.parentElement?.querySelector(".drawer__value"); + if (!output) return; + output.textContent = input.value; }; rangeInputs.forEach((input) => { - updateDrawerValue(input); - input.addEventListener("input", () => updateDrawerValue(input)); + updateDrawerValue(input); + input.addEventListener("input", () => updateDrawerValue(input)); }); const closeDrawer = () => { - settingsDrawer?.classList.remove("drawer--open"); - if (settingsDrawer) settingsDrawer.setAttribute("aria-hidden", "true"); - if (settingsToggle) settingsToggle.setAttribute("aria-expanded", "false"); + settingsDrawer?.classList.remove("drawer--open"); + if (settingsDrawer) settingsDrawer.setAttribute("aria-hidden", "true"); + if (settingsToggle) settingsToggle.setAttribute("aria-expanded", "false"); }; const openDrawer = () => { - settingsDrawer?.classList.add("drawer--open"); - if (settingsDrawer) settingsDrawer.setAttribute("aria-hidden", "false"); - if (settingsToggle) settingsToggle.setAttribute("aria-expanded", "true"); + settingsDrawer?.classList.add("drawer--open"); + if (settingsDrawer) settingsDrawer.setAttribute("aria-hidden", "false"); + if (settingsToggle) settingsToggle.setAttribute("aria-expanded", "true"); }; settingsToggle?.addEventListener("click", () => { - if (settingsDrawer?.classList.contains("drawer--open")) { - closeDrawer(); - } else { - openDrawer(); - } + if (settingsDrawer?.classList.contains("drawer--open")) { + closeDrawer(); + } else { + openDrawer(); + } }); settingsClose?.addEventListener("click", closeDrawer); clearEl?.addEventListener("click", () => { - clearMessages(); - inputEl?.focus(); - updateSendButtonState(); + clearMessages(); + inputEl?.focus(); + updateSendButtonState(); }); inputEl?.addEventListener("input", updateSendButtonState); formEl.addEventListener("submit", (event) => { - event.preventDefault(); - const text = inputEl.value.trim(); - if (!text) return; - if (!socket || socket.readyState !== WebSocket.OPEN) { - addMessage("Not connected.", "in"); - return; - } - if (isAwaitingResponse) return; - const config = { - top_k: Number(configInputs.topK?.value ?? 0), - top_p: Number(configInputs.topP?.value ?? 0), - temperature: Number(configInputs.temperature?.value ?? 0), - retriever_max_docs: Number(configInputs.retrieverMaxDocs?.value ?? 0), - reranker_max_results: Number(configInputs.rerankerMaxDocs?.value ?? 0), - }; - addMessage(text, "out"); - socket.send(JSON.stringify({ message: text, config })); - inputEl.value = ""; - inputEl.focus(); - isAwaitingResponse = true; - updateSendButtonState(); - showLoadingBubble(); + event.preventDefault(); + const text = inputEl.value.trim(); + if (!text) return; + if (!socket || socket.readyState !== WebSocket.OPEN) { + addMessage("Not connected.", "in"); + return; + } + if (isAwaitingResponse) return; + const config = { + top_k: Number(configInputs.topK?.value ?? 0), + top_p: Number(configInputs.topP?.value ?? 0), + temperature: Number(configInputs.temperature?.value ?? 0), + retriever_max_docs: Number(configInputs.retrieverMaxDocs?.value ?? 0), + reranker_max_results: Number(configInputs.rerankerMaxDocs?.value ?? 0), + }; + addMessage(text, "out"); + socket.send(JSON.stringify({ message: text, config })); + inputEl.value = ""; + inputEl.focus(); + isAwaitingResponse = true; + updateSendButtonState(); + showLoadingBubble(); }); connect(); diff --git a/templates/index.html b/templates/index.html index c6477f7..f195548 100644 --- a/templates/index.html +++ b/templates/index.html @@ -1,72 +1,143 @@ - + - -
+