optimize db connection
This commit is contained in:
31
db.py
31
db.py
@@ -1,19 +1,28 @@
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker
|
||||
|
||||
from models.orm import Prompt
|
||||
from models.orm import Prompt as PromptORM
|
||||
from models.validation import Prompt as PromptSchema
|
||||
|
||||
|
||||
class DB:
|
||||
def __init__(self, db: str = "sqlite:///example.db"):
|
||||
self.engine = create_engine(db, connect_args={"check_same_thread": False})
|
||||
def __init__(self, engine: AsyncEngine):
|
||||
self._session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
def get_prompts(self) -> List[Prompt]:
|
||||
with Session(self.engine) as session:
|
||||
return session.query(Prompt).all()
|
||||
async def get_prompts(self) -> List[PromptSchema]:
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(select(PromptORM))
|
||||
prompts = result.scalars().all()
|
||||
return [PromptSchema(id=p.id, name=p.name, text=p.text) for p in prompts]
|
||||
|
||||
def get_prompt_by_id(self, prompt_id: int) -> Prompt | None:
|
||||
with Session(self.engine) as session:
|
||||
return session.query(Prompt).filter(Prompt.id == prompt_id).first()
|
||||
async def get_prompt_by_id(self, prompt_id: int) -> PromptSchema | None:
|
||||
async with self._session_factory() as session:
|
||||
result = await session.execute(
|
||||
select(PromptORM).where(PromptORM.id == prompt_id)
|
||||
)
|
||||
prompt = result.scalar_one_or_none()
|
||||
if prompt is None:
|
||||
return None
|
||||
return PromptSchema(id=prompt.id, name=prompt.name, text=prompt.text)
|
||||
|
||||
Reference in New Issue
Block a user