optimize db connection
This commit is contained in:
24
app.py
24
app.py
@@ -1,22 +1,33 @@
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, Request, WebSocket
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from chain import RagChain
|
||||
from db import DB
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
engine = create_async_engine("sqlite+aiosqlite:///example.db")
|
||||
app.state.db = DB(engine)
|
||||
yield
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
templates = Jinja2Templates(directory="templates")
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def home(request: Request):
|
||||
db = DB()
|
||||
prompts = db.get_prompts()
|
||||
async def home(request: Request):
|
||||
db: DB = request.app.state.db
|
||||
prompts = await db.get_prompts()
|
||||
|
||||
return templates.TemplateResponse(
|
||||
"index.html", {"request": request, "prompts": prompts}
|
||||
@@ -27,6 +38,8 @@ def home(request: Request):
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
|
||||
db: DB = websocket.app.state.db
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw_message = await websocket.receive_text()
|
||||
@@ -40,9 +53,8 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
message = raw_message
|
||||
config = {}
|
||||
|
||||
db = DB()
|
||||
prompt_id = int(config.get("prompt_id", 0))
|
||||
prompt = db.get_prompt_by_id(prompt_id)
|
||||
prompt = await db.get_prompt_by_id(prompt_id)
|
||||
|
||||
if prompt is None:
|
||||
await websocket.send_json(
|
||||
|
||||
Reference in New Issue
Block a user