optimize db connection

This commit is contained in:
Matteo Rosati
2026-02-18 16:39:41 +01:00
parent 3c6c367600
commit 39eb4f4f01
6 changed files with 56 additions and 17 deletions

24
app.py
View File

@@ -1,22 +1,33 @@
import json
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, WebSocket
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sqlalchemy.ext.asyncio import create_async_engine
from chain import RagChain
from db import DB
app = FastAPI()
@asynccontextmanager
async def lifespan(app: FastAPI):
engine = create_async_engine("sqlite+aiosqlite:///example.db")
app.state.db = DB(engine)
yield
await engine.dispose()
app = FastAPI(lifespan=lifespan)
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/")
def home(request: Request):
db = DB()
prompts = db.get_prompts()
async def home(request: Request):
db: DB = request.app.state.db
prompts = await db.get_prompts()
return templates.TemplateResponse(
"index.html", {"request": request, "prompts": prompts}
@@ -27,6 +38,8 @@ def home(request: Request):
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
db: DB = websocket.app.state.db
try:
while True:
raw_message = await websocket.receive_text()
@@ -40,9 +53,8 @@ async def websocket_endpoint(websocket: WebSocket):
message = raw_message
config = {}
db = DB()
prompt_id = int(config.get("prompt_id", 0))
prompt = db.get_prompt_by_id(prompt_id)
prompt = await db.get_prompt_by_id(prompt_id)
if prompt is None:
await websocket.send_json(

31
db.py
View File

@@ -1,19 +1,28 @@
from typing import List
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker
from models.orm import Prompt
from models.orm import Prompt as PromptORM
from models.validation import Prompt as PromptSchema
class DB:
def __init__(self, db: str = "sqlite:///example.db"):
self.engine = create_engine(db, connect_args={"check_same_thread": False})
def __init__(self, engine: AsyncEngine):
self._session_factory = async_sessionmaker(engine, expire_on_commit=False)
def get_prompts(self) -> List[Prompt]:
with Session(self.engine) as session:
return session.query(Prompt).all()
async def get_prompts(self) -> List[PromptSchema]:
async with self._session_factory() as session:
result = await session.execute(select(PromptORM))
prompts = result.scalars().all()
return [PromptSchema(id=p.id, name=p.name, text=p.text) for p in prompts]
def get_prompt_by_id(self, prompt_id: int) -> Prompt | None:
with Session(self.engine) as session:
return session.query(Prompt).filter(Prompt.id == prompt_id).first()
async def get_prompt_by_id(self, prompt_id: int) -> PromptSchema | None:
async with self._session_factory() as session:
result = await session.execute(
select(PromptORM).where(PromptORM.id == prompt_id)
)
prompt = result.scalar_one_or_none()
if prompt is None:
return None
return PromptSchema(id=prompt.id, name=prompt.name, text=prompt.text)

View File

@@ -2,5 +2,6 @@ from pydantic import BaseModel
class Prompt(BaseModel):
id: int
name: str
text: str

View File

@@ -18,6 +18,8 @@ dependencies = [
"fastapi[standard]",
"pydantic>=2.12.5",
"sqlalchemy>=2.0.46",
"aiosqlite>=0.22.1",
"greenlet>=3.3.1",
]
[dependency-groups]

View File

@@ -1,6 +1,7 @@
aiohappyeyeballs==2.6.1
aiohttp==3.13.3
aiosignal==1.4.0
aiosqlite==0.22.1
annotated-doc==0.0.4
annotated-types==0.7.0
anyio==4.12.1
@@ -38,6 +39,7 @@ google-crc32c==1.8.0
google-genai==1.63.0
google-resumable-media==2.8.0
googleapis-common-protos==1.72.0
greenlet==3.3.1
grpc-google-iam-v1==0.14.3
grpcio==1.78.0
grpcio-status==1.78.0

13
uv.lock generated
View File

@@ -95,13 +95,24 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" },
]
[[package]]
name = "aiosqlite"
version = "0.22.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/4e/8a/64761f4005f17809769d23e518d915db74e6310474e733e3593cfc854ef1/aiosqlite-0.22.1.tar.gz", hash = "sha256:043e0bd78d32888c0a9ca90fc788b38796843360c855a7262a532813133a0650", size = 14821, upload-time = "2025-12-23T19:25:43.997Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/00/b7/e3bf5133d697a08128598c8d0abc5e16377b51465a33756de24fa7dee953/aiosqlite-0.22.1-py3-none-any.whl", hash = "sha256:21c002eb13823fad740196c5a2e9d8e62f6243bd9e7e4a1f87fb5e44ecb4fceb", size = 17405, upload-time = "2025-12-23T19:25:42.139Z" },
]
[[package]]
name = "akern"
version = "0.1.0"
source = { virtual = "." }
dependencies = [
{ name = "aiosqlite" },
{ name = "fastapi", extra = ["standard"] },
{ name = "google-cloud-discoveryengine" },
{ name = "greenlet" },
{ name = "langchain" },
{ name = "langchain-community" },
{ name = "langchain-core" },
@@ -121,9 +132,11 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "aiosqlite", specifier = ">=0.22.1" },
{ name = "fastapi", specifier = ">=0.129.0" },
{ name = "fastapi", extras = ["standard"] },
{ name = "google-cloud-discoveryengine", specifier = ">=0.17.0" },
{ name = "greenlet", specifier = ">=3.3.1" },
{ name = "langchain", specifier = ">=1.2.10" },
{ name = "langchain-community", specifier = ">=0.4.1" },
{ name = "langchain-core", specifier = ">=1.2.13" },