optimize db connection

This commit is contained in:
Matteo Rosati
2026-02-18 16:39:41 +01:00
parent 3c6c367600
commit 39eb4f4f01
6 changed files with 56 additions and 17 deletions

24
app.py
View File

@@ -1,22 +1,33 @@
import json import json
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, WebSocket from fastapi import FastAPI, Request, WebSocket
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from sqlalchemy.ext.asyncio import create_async_engine
from chain import RagChain from chain import RagChain
from db import DB from db import DB
app = FastAPI()
@asynccontextmanager
async def lifespan(app: FastAPI):
engine = create_async_engine("sqlite+aiosqlite:///example.db")
app.state.db = DB(engine)
yield
await engine.dispose()
app = FastAPI(lifespan=lifespan)
templates = Jinja2Templates(directory="templates") templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static") app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/") @app.get("/")
def home(request: Request): async def home(request: Request):
db = DB() db: DB = request.app.state.db
prompts = db.get_prompts() prompts = await db.get_prompts()
return templates.TemplateResponse( return templates.TemplateResponse(
"index.html", {"request": request, "prompts": prompts} "index.html", {"request": request, "prompts": prompts}
@@ -27,6 +38,8 @@ def home(request: Request):
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
await websocket.accept() await websocket.accept()
db: DB = websocket.app.state.db
try: try:
while True: while True:
raw_message = await websocket.receive_text() raw_message = await websocket.receive_text()
@@ -40,9 +53,8 @@ async def websocket_endpoint(websocket: WebSocket):
message = raw_message message = raw_message
config = {} config = {}
db = DB()
prompt_id = int(config.get("prompt_id", 0)) prompt_id = int(config.get("prompt_id", 0))
prompt = db.get_prompt_by_id(prompt_id) prompt = await db.get_prompt_by_id(prompt_id)
if prompt is None: if prompt is None:
await websocket.send_json( await websocket.send_json(

31
db.py
View File

@@ -1,19 +1,28 @@
from typing import List from typing import List
from sqlalchemy import create_engine from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker
from models.orm import Prompt from models.orm import Prompt as PromptORM
from models.validation import Prompt as PromptSchema
class DB: class DB:
def __init__(self, db: str = "sqlite:///example.db"): def __init__(self, engine: AsyncEngine):
self.engine = create_engine(db, connect_args={"check_same_thread": False}) self._session_factory = async_sessionmaker(engine, expire_on_commit=False)
def get_prompts(self) -> List[Prompt]: async def get_prompts(self) -> List[PromptSchema]:
with Session(self.engine) as session: async with self._session_factory() as session:
return session.query(Prompt).all() result = await session.execute(select(PromptORM))
prompts = result.scalars().all()
return [PromptSchema(id=p.id, name=p.name, text=p.text) for p in prompts]
def get_prompt_by_id(self, prompt_id: int) -> Prompt | None: async def get_prompt_by_id(self, prompt_id: int) -> PromptSchema | None:
with Session(self.engine) as session: async with self._session_factory() as session:
return session.query(Prompt).filter(Prompt.id == prompt_id).first() result = await session.execute(
select(PromptORM).where(PromptORM.id == prompt_id)
)
prompt = result.scalar_one_or_none()
if prompt is None:
return None
return PromptSchema(id=prompt.id, name=prompt.name, text=prompt.text)

View File

@@ -2,5 +2,6 @@ from pydantic import BaseModel
class Prompt(BaseModel): class Prompt(BaseModel):
id: int
name: str name: str
text: str text: str

View File

@@ -18,6 +18,8 @@ dependencies = [
"fastapi[standard]", "fastapi[standard]",
"pydantic>=2.12.5", "pydantic>=2.12.5",
"sqlalchemy>=2.0.46", "sqlalchemy>=2.0.46",
"aiosqlite>=0.22.1",
"greenlet>=3.3.1",
] ]
[dependency-groups] [dependency-groups]

View File

@@ -1,6 +1,7 @@
aiohappyeyeballs==2.6.1 aiohappyeyeballs==2.6.1
aiohttp==3.13.3 aiohttp==3.13.3
aiosignal==1.4.0 aiosignal==1.4.0
aiosqlite==0.22.1
annotated-doc==0.0.4 annotated-doc==0.0.4
annotated-types==0.7.0 annotated-types==0.7.0
anyio==4.12.1 anyio==4.12.1
@@ -38,6 +39,7 @@ google-crc32c==1.8.0
google-genai==1.63.0 google-genai==1.63.0
google-resumable-media==2.8.0 google-resumable-media==2.8.0
googleapis-common-protos==1.72.0 googleapis-common-protos==1.72.0
greenlet==3.3.1
grpc-google-iam-v1==0.14.3 grpc-google-iam-v1==0.14.3
grpcio==1.78.0 grpcio==1.78.0
grpcio-status==1.78.0 grpcio-status==1.78.0

13
uv.lock generated
View File

@@ -95,13 +95,24 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" },
] ]
[[package]]
name = "aiosqlite"
version = "0.22.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/4e/8a/64761f4005f17809769d23e518d915db74e6310474e733e3593cfc854ef1/aiosqlite-0.22.1.tar.gz", hash = "sha256:043e0bd78d32888c0a9ca90fc788b38796843360c855a7262a532813133a0650", size = 14821, upload-time = "2025-12-23T19:25:43.997Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/00/b7/e3bf5133d697a08128598c8d0abc5e16377b51465a33756de24fa7dee953/aiosqlite-0.22.1-py3-none-any.whl", hash = "sha256:21c002eb13823fad740196c5a2e9d8e62f6243bd9e7e4a1f87fb5e44ecb4fceb", size = 17405, upload-time = "2025-12-23T19:25:42.139Z" },
]
[[package]] [[package]]
name = "akern" name = "akern"
version = "0.1.0" version = "0.1.0"
source = { virtual = "." } source = { virtual = "." }
dependencies = [ dependencies = [
{ name = "aiosqlite" },
{ name = "fastapi", extra = ["standard"] }, { name = "fastapi", extra = ["standard"] },
{ name = "google-cloud-discoveryengine" }, { name = "google-cloud-discoveryengine" },
{ name = "greenlet" },
{ name = "langchain" }, { name = "langchain" },
{ name = "langchain-community" }, { name = "langchain-community" },
{ name = "langchain-core" }, { name = "langchain-core" },
@@ -121,9 +132,11 @@ dev = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "aiosqlite", specifier = ">=0.22.1" },
{ name = "fastapi", specifier = ">=0.129.0" }, { name = "fastapi", specifier = ">=0.129.0" },
{ name = "fastapi", extras = ["standard"] }, { name = "fastapi", extras = ["standard"] },
{ name = "google-cloud-discoveryengine", specifier = ">=0.17.0" }, { name = "google-cloud-discoveryengine", specifier = ">=0.17.0" },
{ name = "greenlet", specifier = ">=3.3.1" },
{ name = "langchain", specifier = ">=1.2.10" }, { name = "langchain", specifier = ">=1.2.10" },
{ name = "langchain-community", specifier = ">=0.4.1" }, { name = "langchain-community", specifier = ">=0.4.1" },
{ name = "langchain-core", specifier = ">=1.2.13" }, { name = "langchain-core", specifier = ">=1.2.13" },