# 42 lines, 1.0 KiB, Python
import logging
import os

from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.concurrency import iterate_in_threadpool
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from main import generate


# Configure the root logger at import time: INFO and above, one timestamped
# line per record.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

# Logger named after this module.
logger = logging.getLogger(name=__name__)


# Relative directory names; as relative paths they resolve against the
# process's current working directory, not this file's location.
STATIC_DIR = "static"
TEMPLATES_DIR = "templates"

app = FastAPI()

# Expose the contents of STATIC_DIR under the "/static" URL prefix.
app.mount(
    "/" + STATIC_DIR,
    StaticFiles(directory=STATIC_DIR),
    name="static",
)

# Jinja2 templates are loaded from static/templates.
# NOTE(review): that directory sits inside the mounted static tree, so raw
# template source is also downloadable at /static/templates/<file>; move the
# templates out of STATIC_DIR if that exposure is not intended.
templates = Jinja2Templates(directory=os.path.join(STATIC_DIR, TEMPLATES_DIR))


@app.get("/")
async def home(request: Request):
    """Render ``index.html`` for the site root.

    The incoming ``request`` is placed in the template context, as the
    name-first ``TemplateResponse`` form used here requires.
    """
    # NOTE(review): newer Starlette releases prefer
    # TemplateResponse(request, name, context); switch once the pinned
    # version supports it.
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Stream ``generate`` output back to the client, one prompt at a time.

    Protocol: the client sends a text message. The server replies with every
    text chunk yielded by ``generate(message)``, followed by the literal
    sentinel ``"<<END>>"``. This repeats until the client disconnects.
    """
    await websocket.accept()
    client = websocket.client
    logger.info("WebSocket connected: %s", client)
    try:
        while True:
            data = await websocket.receive_text()
            # ``generate`` returns a synchronous iterable. Advancing it
            # directly would block the event loop (and every other
            # connection) for as long as each chunk takes to produce
            # (presumably model inference — confirm). iterate_in_threadpool
            # runs each next() call in a worker thread instead.
            async for chunk in iterate_in_threadpool(generate(data)):
                await websocket.send_text(chunk)
            await websocket.send_text("<<END>>")
    except WebSocketDisconnect:
        # Normal end of a session: the client closed the socket. Catching it
        # here stops the disconnect from surfacing as an unhandled error.
        logger.info("WebSocket disconnected: %s", client)