From 39eb4f4f010c52b7b95a2c2ffcca9265d4201554 Mon Sep 17 00:00:00 2001 From: Matteo Rosati Date: Wed, 18 Feb 2026 16:39:41 +0100 Subject: [PATCH] optimize db connection --- app.py | 24 ++++++++++++++++++------ db.py | 31 ++++++++++++++++++++----------- models/validation.py | 1 + pyproject.toml | 2 ++ requirements.txt | 2 ++ uv.lock | 13 +++++++++++++ 6 files changed, 56 insertions(+), 17 deletions(-) diff --git a/app.py b/app.py index 8b18d7c..77d161a 100644 --- a/app.py +++ b/app.py @@ -1,22 +1,33 @@ import json +from contextlib import asynccontextmanager from fastapi import FastAPI, Request, WebSocket from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates +from sqlalchemy.ext.asyncio import create_async_engine from chain import RagChain from db import DB -app = FastAPI() + +@asynccontextmanager +async def lifespan(app: FastAPI): + engine = create_async_engine("sqlite+aiosqlite:///example.db") + app.state.db = DB(engine) + yield + await engine.dispose() + + +app = FastAPI(lifespan=lifespan) templates = Jinja2Templates(directory="templates") app.mount("/static", StaticFiles(directory="static"), name="static") @app.get("/") -def home(request: Request): - db = DB() - prompts = db.get_prompts() +async def home(request: Request): + db: DB = request.app.state.db + prompts = await db.get_prompts() return templates.TemplateResponse( "index.html", {"request": request, "prompts": prompts} @@ -27,6 +38,8 @@ def home(request: Request): async def websocket_endpoint(websocket: WebSocket): await websocket.accept() + db: DB = websocket.app.state.db + try: while True: raw_message = await websocket.receive_text() @@ -40,9 +53,8 @@ async def websocket_endpoint(websocket: WebSocket): message = raw_message config = {} - db = DB() prompt_id = int(config.get("prompt_id", 0)) - prompt = db.get_prompt_by_id(prompt_id) + prompt = await db.get_prompt_by_id(prompt_id) if prompt is None: await websocket.send_json( diff --git a/db.py b/db.py index a90759e..78ef6f6 100644 --- a/db.py +++ b/db.py @@ -1,19 +1,28 @@ from typing import List -from sqlalchemy import create_engine -from sqlalchemy.orm import Session +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker -from models.orm import Prompt +from models.orm import Prompt as PromptORM +from models.validation import Prompt as PromptSchema class DB: - def __init__(self, db: str = "sqlite:///example.db"): - self.engine = create_engine(db, connect_args={"check_same_thread": False}) + def __init__(self, engine: AsyncEngine): + self._session_factory = async_sessionmaker(engine, expire_on_commit=False) - def get_prompts(self) -> List[Prompt]: - with Session(self.engine) as session: - return session.query(Prompt).all() + async def get_prompts(self) -> List[PromptSchema]: + async with self._session_factory() as session: + result = await session.execute(select(PromptORM)) + prompts = result.scalars().all() + return [PromptSchema(id=p.id, name=p.name, text=p.text) for p in prompts] - def get_prompt_by_id(self, prompt_id: int) -> Prompt | None: - with Session(self.engine) as session: - return session.query(Prompt).filter(Prompt.id == prompt_id).first() + async def get_prompt_by_id(self, prompt_id: int) -> PromptSchema | None: + async with self._session_factory() as session: + result = await session.execute( + select(PromptORM).where(PromptORM.id == prompt_id) + ) + prompt = result.scalar_one_or_none() + if prompt is None: + return None + return PromptSchema(id=prompt.id, name=prompt.name, text=prompt.text) diff --git a/models/validation.py b/models/validation.py index 91d1bf8..a57fc82 100644 --- a/models/validation.py +++ b/models/validation.py @@ -2,5 +2,6 @@ from pydantic import BaseModel class Prompt(BaseModel): + id: int name: str text: str diff --git a/pyproject.toml b/pyproject.toml index 73f58d3..5982a9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,8 @@ dependencies = [ "fastapi[standard]", "pydantic>=2.12.5", "sqlalchemy>=2.0.46", + "aiosqlite>=0.22.1", + "greenlet>=3.3.1", ] [dependency-groups] diff --git a/requirements.txt b/requirements.txt index 2a32256..8001c82 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ aiohappyeyeballs==2.6.1 aiohttp==3.13.3 aiosignal==1.4.0 +aiosqlite==0.22.1 annotated-doc==0.0.4 annotated-types==0.7.0 anyio==4.12.1 @@ -38,6 +39,7 @@ google-crc32c==1.8.0 google-genai==1.63.0 google-resumable-media==2.8.0 googleapis-common-protos==1.72.0 +greenlet==3.3.1 grpc-google-iam-v1==0.14.3 grpcio==1.78.0 grpcio-status==1.78.0 diff --git a/uv.lock b/uv.lock index ea3dc78..eaacb4d 100644 --- a/uv.lock +++ b/uv.lock @@ -95,13 +95,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "aiosqlite" +version = "0.22.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/8a/64761f4005f17809769d23e518d915db74e6310474e733e3593cfc854ef1/aiosqlite-0.22.1.tar.gz", hash = "sha256:043e0bd78d32888c0a9ca90fc788b38796843360c855a7262a532813133a0650", size = 14821, upload-time = "2025-12-23T19:25:43.997Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/b7/e3bf5133d697a08128598c8d0abc5e16377b51465a33756de24fa7dee953/aiosqlite-0.22.1-py3-none-any.whl", hash = "sha256:21c002eb13823fad740196c5a2e9d8e62f6243bd9e7e4a1f87fb5e44ecb4fceb", size = 17405, upload-time = "2025-12-23T19:25:42.139Z" }, +] + [[package]] name = "akern" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "aiosqlite" }, { name = "fastapi", extra = ["standard"] }, { name = "google-cloud-discoveryengine" }, + { name = "greenlet" }, { name = "langchain" }, { name = "langchain-community" }, { name = "langchain-core" }, @@ -121,9 +132,11 @@ dev = [ [package.metadata] requires-dist = [ + { name = "aiosqlite", specifier = ">=0.22.1" }, { name = "fastapi", specifier = ">=0.129.0" }, { name = "fastapi", extras = ["standard"] }, { name = "google-cloud-discoveryengine", specifier = ">=0.17.0" }, + { name = "greenlet", specifier = ">=3.3.1" }, { name = "langchain", specifier = ">=1.2.10" }, { name = "langchain-community", specifier = ">=0.4.1" }, { name = "langchain-core", specifier = ">=1.2.13" },