29 lines
1.1 KiB
Python
29 lines
1.1 KiB
Python
from typing import List
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker
|
|
|
|
from models.orm import Prompt as PromptORM
|
|
from models.validation import Prompt as PromptSchema
|
|
|
|
|
|
class DB:
    """Async data-access layer for prompts.

    Every public method opens its own session from the factory created in
    ``__init__``, runs a single SELECT against the ``Prompt`` ORM model and
    returns validation-layer ``PromptSchema`` objects (never ORM instances).
    """

    def __init__(self, engine: AsyncEngine):
        """Build the session factory used by all query methods.

        Args:
            engine: Async SQLAlchemy engine that new sessions are bound to.
        """
        # expire_on_commit=False: attribute values loaded on ORM instances are
        # not expired when a session commits.
        self._session_factory = async_sessionmaker(engine, expire_on_commit=False)

    @staticmethod
    def _to_schema(prompt: PromptORM) -> PromptSchema:
        """Convert one ORM ``Prompt`` row into a ``PromptSchema``."""
        # Single place for the ORM -> schema field mapping, shared by all getters.
        return PromptSchema(id=prompt.id, name=prompt.name, text=prompt.text)

    async def get_prompts(self) -> List[PromptSchema]:
        """Return every ``Prompt`` row converted to ``PromptSchema``.

        No ORDER BY is applied, so the order is whatever the database yields.

        Returns:
            A list of ``PromptSchema`` objects; empty when the table has no rows.
        """
        async with self._session_factory() as session:
            prompts = (await session.scalars(select(PromptORM))).all()
            # Convert inside the session block, while the ORM objects are
            # still attached to the open session.
            return [self._to_schema(p) for p in prompts]

    async def get_prompt_by_id(self, prompt_id: int) -> PromptSchema | None:
        """Return the prompt whose ``id`` equals *prompt_id*, or ``None``.

        Args:
            prompt_id: Value compared against ``Prompt.id``.

        Returns:
            The matching prompt as a ``PromptSchema``, or ``None`` when no row
            matches.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: If more than one row matches
                (raised by ``scalar_one_or_none``).
        """
        async with self._session_factory() as session:
            result = await session.execute(
                select(PromptORM).where(PromptORM.id == prompt_id)
            )
            prompt = result.scalar_one_or_none()
            return None if prompt is None else self._to_schema(prompt)
|