basic auth
This commit is contained in:
30
app.py
30
app.py
@@ -1,7 +1,11 @@
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, Request, WebSocket
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, WebSocket, status
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
@@ -9,6 +13,8 @@ from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from chain import RagChain
|
||||
from db import DB
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@@ -19,13 +25,27 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
security = HTTPBasic()
|
||||
|
||||
templates = Jinja2Templates(directory="templates")
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
|
||||
|
||||
def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)):
|
||||
correct_username = secrets.compare_digest(credentials.username, os.environ["AUTH_USER"])
|
||||
correct_password = secrets.compare_digest(
|
||||
credentials.password, os.environ["PASSWORD"]
|
||||
)
|
||||
if not (correct_username and correct_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Unauthorized",
|
||||
headers={"WWW-Authenticate": "Basic"},
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def home(request: Request):
|
||||
async def home(request: Request, _: None = Depends(verify_credentials)):
|
||||
db: DB = request.app.state.db
|
||||
prompts = await db.get_prompts()
|
||||
|
||||
@@ -86,3 +106,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.send_json({"type": "end"})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run("app:app", host="0.0.0.0", port=int(os.getenv("PORT", 8080)))
|
||||
|
||||
Reference in New Issue
Block a user