prompt template from db
This commit is contained in:
37
app.py
37
app.py
@@ -1,13 +1,11 @@
|
||||
import json
|
||||
|
||||
from fastapi import FastAPI, WebSocket
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi import Request
|
||||
from fastapi import FastAPI, Request, WebSocket
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from db import DB
|
||||
from chain import RagChain
|
||||
from pprint import pprint
|
||||
from db import DB
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@@ -20,15 +18,9 @@ def home(request: Request):
|
||||
db = DB()
|
||||
prompts = db.get_prompts()
|
||||
|
||||
return templates.TemplateResponse("index.html", {"request": request, "prompts": prompts})
|
||||
|
||||
|
||||
@app.get("/prova")
|
||||
def prova(request: Request):
|
||||
cursor = DB()
|
||||
prompts = cursor.get_prompts()
|
||||
|
||||
return prompts
|
||||
return templates.TemplateResponse(
|
||||
"index.html", {"request": request, "prompts": prompts}
|
||||
)
|
||||
|
||||
|
||||
@app.websocket("/ws")
|
||||
@@ -48,14 +40,23 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
message = raw_message
|
||||
config = {}
|
||||
|
||||
db = DB()
|
||||
prompt_id = int(config.get("prompt_id", 0))
|
||||
prompt = db.get_prompt_by_id(prompt_id)
|
||||
|
||||
if prompt is None:
|
||||
await websocket.send_json(
|
||||
{"type": "chunk", "data": "Error: prompt not found."}
|
||||
)
|
||||
continue
|
||||
|
||||
rag_chain = RagChain(
|
||||
question_template=str(prompt.text),
|
||||
top_k=int(config.get("top_k", 40)),
|
||||
top_p=float(config.get("top_p", 0.0)),
|
||||
temperature=float(config.get("temperature", 0.0)),
|
||||
retriever_max_docs=int(
|
||||
config.get("retriever_max_docs", 40)),
|
||||
reranker_max_results=int(
|
||||
config.get("reranker_max_results", 20)),
|
||||
retriever_max_docs=int(config.get("retriever_max_docs", 40)),
|
||||
reranker_max_results=int(config.get("reranker_max_results", 20)),
|
||||
)
|
||||
|
||||
async for chunk in rag_chain.stream(message):
|
||||
|
||||
Reference in New Issue
Block a user