Files
AKERN-Langchain/app.py
2026-02-18 18:09:45 +01:00

117 lines
3.7 KiB
Python

import json
import logging
import os
import secrets
from contextlib import asynccontextmanager

from dotenv import load_dotenv
from fastapi import (
    Depends,
    FastAPI,
    HTTPException,
    Request,
    WebSocket,
    WebSocketDisconnect,
    status,
)
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sqlalchemy.ext.asyncio import create_async_engine

from chain import RagChain
from db import DB
# Load variables from a .env file into os.environ at import time, so they are
# in place before AUTH_USER, PASSWORD and PORT are read from the environment
# further down. Kept as a bare module-level call: with no path argument,
# python-dotenv searches for .env relative to the calling module.
load_dotenv()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the database engine for the lifetime of the application.

    On startup this creates an async SQLAlchemy engine for the SQLite file
    ``example.db`` (a relative path, so it resolves against the current
    working directory). It wraps the engine in ``DB`` and publishes that as
    ``app.state.db``, where the route handlers read it. On shutdown it
    disposes the engine.

    Args:
        app: The FastAPI application whose ``state`` receives the ``DB``.
    """
    engine = create_async_engine("sqlite+aiosqlite:///example.db")
    try:
        app.state.db = DB(engine)
        yield
    finally:
        # Runs on normal shutdown, and also when DB() fails or an exception /
        # cancellation is thrown in at the yield, so the engine is always
        # disposed.
        await engine.dispose()
# Shared module-level objects used by the routes below.
# "templates" and "static" are relative paths, so they resolve against the
# process's current working directory rather than this file's directory.
security = HTTPBasic()
templates = Jinja2Templates(directory="templates")

app = FastAPI(lifespan=lifespan)
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)
def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)):
    """Reject the request unless its HTTP Basic credentials match the configured ones.

    The expected username and password are read from the ``AUTH_USER`` and
    ``PASSWORD`` environment variables on every call. A missing variable
    raises ``KeyError``.

    Args:
        credentials: Username/password extracted from the ``Authorization``
            header by the ``HTTPBasic`` security scheme.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when the
            username or the password does not match.
    """
    # compare_digest rejects str arguments containing non-ASCII characters with
    # TypeError, which would surface as a 500 instead of a 401. Compare UTF-8
    # bytes instead. Both comparisons always run, so the response time does not
    # reveal which of the two fields was wrong.
    username_ok = secrets.compare_digest(
        credentials.username.encode("utf-8"),
        os.environ["AUTH_USER"].encode("utf-8"),
    )
    password_ok = secrets.compare_digest(
        credentials.password.encode("utf-8"),
        os.environ["PASSWORD"].encode("utf-8"),
    )
    if not (username_ok and password_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Unauthorized",
            headers={"WWW-Authenticate": "Basic"},
        )
@app.get("/")
async def home(request: Request, _: None = Depends(verify_credentials)):
    """Render ``index.html`` listing every stored prompt.

    Access is gated by the ``verify_credentials`` dependency (HTTP Basic auth).
    """
    store: DB = request.app.state.db
    context = {"request": request, "prompts": await store.get_prompts()}
    # NOTE(review): recent Starlette releases prefer
    # TemplateResponse(request, name, context) and warn on the (name, context)
    # form used here. Confirm the installed version before switching.
    return templates.TemplateResponse("index.html", context)
# Module logger for failures the websocket handler recovers from or terminates on.
_logger = logging.getLogger(__name__)


def _parse_client_message(raw_message: str) -> tuple:
    """Split one client frame into ``(message, config)``.

    A frame that is a JSON object is read as
    ``{"message": <question>, "config": {...}}``. A missing ``message``
    becomes ``""``, and a missing or non-object ``config`` becomes ``{}``.
    Any other frame is treated as a plain-text question with an empty config.
    That covers invalid JSON as well as JSON that is not an object, such as
    ``42`` or ``[1, 2]``.
    """
    try:
        payload = json.loads(raw_message)
    except json.JSONDecodeError:
        return raw_message, {}
    if not isinstance(payload, dict):
        return raw_message, {}
    config = payload.get("config", {})
    if not isinstance(config, dict):
        config = {}
    return payload.get("message", ""), config


def _read_chain_options(config: dict) -> dict:
    """Convert the client's RagChain tuning values to the expected numeric types.

    Keys missing from ``config`` fall back to this endpoint's defaults.
    Raises ``TypeError``, ``ValueError`` or ``OverflowError`` when a supplied
    value cannot be converted (e.g. ``None``, ``"abc"``, or an infinite float).
    """
    return {
        "top_k": int(config.get("top_k", 40)),
        "top_p": float(config.get("top_p", 0.0)),
        "temperature": float(config.get("temperature", 0.0)),
        "retriever_max_docs": int(config.get("retriever_max_docs", 40)),
        "reranker_max_results": int(config.get("reranker_max_results", 20)),
    }


async def _answer_message(websocket: WebSocket, db: DB, raw_message: str) -> None:
    """Stream the answer to one client frame as "chunk" frames plus a "report".

    Invalid config values and unknown prompt ids are reported to the client as
    an ``"Error: ..."`` chunk in place of an answer. Exceptions raised by the
    database or the RAG chain propagate to the caller.
    """
    message, config = _parse_client_message(raw_message)
    try:
        prompt_id = int(config.get("prompt_id", 0))
        chain_options = _read_chain_options(config)
    except (TypeError, ValueError, OverflowError):
        await websocket.send_json({"type": "chunk", "data": "Error: invalid config."})
        return

    prompt = await db.get_prompt_by_id(prompt_id)
    if prompt is None:
        await websocket.send_json(
            {"type": "chunk", "data": "Error: prompt not found."}
        )
        return

    rag_chain = RagChain(question_template=str(prompt.text), **chain_options)
    async for chunk in rag_chain.stream(message):
        await websocket.send_json({"type": "chunk", "data": chunk})
    await websocket.send_json(
        {
            "type": "report",
            "sources": rag_chain.getSources(),
            "reranked_sources": rag_chain.getRankedSources(),
            "rephrased_question": rag_chain.getRephrasedQuestion(),
        }
    )


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve the chat protocol over a WebSocket.

    For every text frame received, the server replies in this order:

    1. ``{"type": "start"}``
    2. zero or more ``{"type": "chunk", "data": ...}`` frames
    3. a ``{"type": "report", ...}`` frame on success
    4. ``{"type": "end"}``

    Problems with a single message are sent to the client as an
    ``"Error: ..."`` chunk, and the session stays open. This covers a bad
    config, an unknown prompt, and failures inside the database or RAG chain.
    The loop ends when the client disconnects.
    """
    # NOTE(review): unlike "/", this endpoint performs no authentication, so
    # anyone who can reach the server can run the RAG chain. Consider checking
    # credentials during the handshake (confirm what the frontend sends).
    await websocket.accept()
    db: DB = websocket.app.state.db
    try:
        while True:
            raw_message = await websocket.receive_text()
            await websocket.send_json({"type": "start"})
            try:
                await _answer_message(websocket, db, raw_message)
            except WebSocketDisconnect:
                raise
            except Exception:
                # One failing request must not tear down the whole session.
                _logger.exception("Failed to answer websocket message")
                await websocket.send_json(
                    {"type": "chunk", "data": "Error: failed to process the message."}
                )
            # Not in a `finally`: after a disconnect there is nobody to send to.
            await websocket.send_json({"type": "end"})
    except WebSocketDisconnect:
        # Normal client hang-up.
        pass
    except Exception:
        # Top-level boundary: record the failure instead of silently dropping it.
        _logger.exception("WebSocket session ended unexpectedly")
if __name__ == "__main__":
    # Script entry point: serve this module's `app` with uvicorn on all
    # interfaces. The port comes from $PORT, defaulting to 8080.
    import uvicorn

    listen_port = int(os.getenv("PORT", 8080))
    uvicorn.run("app:app", host="0.0.0.0", port=listen_port)