add db
This commit is contained in:
21
app.py
21
app.py
@@ -5,6 +5,7 @@ from fastapi.templating import Jinja2Templates
|
||||
from fastapi import Request
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from db import DB
|
||||
from chain import RagChain
|
||||
from pprint import pprint
|
||||
|
||||
@@ -16,7 +17,18 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
|
||||
@app.get("/")
|
||||
def home(request: Request):
|
||||
return templates.TemplateResponse("index.html", {"request": request})
|
||||
db = DB()
|
||||
prompts = db.get_prompts()
|
||||
|
||||
return templates.TemplateResponse("index.html", {"request": request, "prompts": prompts})
|
||||
|
||||
|
||||
@app.get("/prova")
|
||||
def prova(request: Request):
|
||||
cursor = DB()
|
||||
prompts = cursor.get_prompts()
|
||||
|
||||
return prompts
|
||||
|
||||
|
||||
@app.websocket("/ws")
|
||||
@@ -40,8 +52,10 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
top_k=int(config.get("top_k", 40)),
|
||||
top_p=float(config.get("top_p", 0.0)),
|
||||
temperature=float(config.get("temperature", 0.0)),
|
||||
retriever_max_docs=int(config.get("retriever_max_docs", 40)),
|
||||
reranker_max_results=int(config.get("reranker_max_results", 20)),
|
||||
retriever_max_docs=int(
|
||||
config.get("retriever_max_docs", 40)),
|
||||
reranker_max_results=int(
|
||||
config.get("reranker_max_results", 20)),
|
||||
)
|
||||
|
||||
async for chunk in rag_chain.stream(message):
|
||||
@@ -52,6 +66,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
"type": "report",
|
||||
"sources": rag_chain.getSources(),
|
||||
"reranked_sources": rag_chain.getRankedSources(),
|
||||
"rephrased_question": rag_chain.getRephrasedQuestion(),
|
||||
}
|
||||
)
|
||||
finally:
|
||||
|
||||
Reference in New Issue
Block a user