Fix docstrings and update the example environment file (.env.dist)

This commit is contained in:
Matteo Rosati
2026-01-21 09:42:13 +01:00
parent b91c09504d
commit 9f5d8a0a88
3 changed files with 77 additions and 14 deletions

View File

@@ -1 +1,5 @@
GOOGLE_APPLICATION_CREDENTIALS=credentials.json
PORT=8000
HOST=0.0.0.0
BASIC_AUTH_USERNAME=admin
BASIC_AUTH_PASSWORD=admin

50
app.py
View File

@@ -1,12 +1,21 @@
"""FastAPI application for Akern-Genai project.
This module provides the web application with WebSocket support
for streaming responses from the Gemini model.
"""
import os
import logging
from typing import Annotated
from fastapi import FastAPI, Request, WebSocket, Depends, HTTPException, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from main import generate
# Configure logging format and level
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
@@ -15,20 +24,26 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
# Static files configuration
STATIC_DIR: str = "static"
TEMPLATES_DIR: str = "templates"
STATIC_DIR = "static"
TEMPLATES_DIR = "templates"
# Security configuration
security = HTTPBasic()
app = FastAPI()
app.mount(f"/{STATIC_DIR}", StaticFiles(directory=STATIC_DIR), name="static")
def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)) -> str:
"""Verify HTTP Basic credentials against environment variables.
templates = Jinja2Templates(directory=os.path.join(STATIC_DIR, TEMPLATES_DIR))
Args:
credentials: HTTP Basic authentication credentials.
Returns:
str: The authenticated username.
def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)):
Raises:
HTTPException: If credentials are invalid.
"""
correct_username = os.getenv("BASIC_AUTH_USERNAME")
correct_password = os.getenv("BASIC_AUTH_PASSWORD")
@@ -44,13 +59,34 @@ def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)):
return credentials.username
# Initialize FastAPI application
app = FastAPI()
app.mount(f"/{STATIC_DIR}", StaticFiles(directory=STATIC_DIR), name="static")
templates = Jinja2Templates(directory=os.path.join(STATIC_DIR, TEMPLATES_DIR))
@app.get("/")
async def home(request: Request, username: Annotated[str, Depends(verify_credentials)]):
    """Serve the application's index page behind HTTP Basic auth.

    Args:
        request: Incoming HTTP request object.
        username: Username resolved by the ``verify_credentials`` dependency;
            the request is rejected before this handler runs if auth fails.

    Returns:
        TemplateResponse: The rendered ``index.html`` page.
    """
    # Jinja2Templates requires the request in the template context.
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
"""Handle WebSocket connections for streaming responses.
Args:
websocket: The WebSocket connection.
"""
await websocket.accept()
while True:
data = await websocket.receive_text()

37
main.py
View File

@@ -1,19 +1,39 @@
"""Google Gemini API integration for Akern-Genai project.
This module provides functionality to generate content using Google's Gemini model
with Vertex AI RAG (Retrieval-Augmented Generation) support.
"""
from google import genai
from google.genai import types
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
CORPUS = "projects/520464122471/locations/europe-west3/ragCorpora/2305843009213693952"
# Vertex AI RAG Corpus resource path
CORPUS: str = "projects/520464122471/locations/europe-west3/ragCorpora/2305843009213693952"
# Gemini model name
GEMINI_MODEL: str = "gemini-3-pro-preview"
def generate(prompt: str):
"""Generate content using Gemini model with RAG retrieval.
This function creates a streaming response from the Gemini model,
augmented with content from the configured RAG corpus.
Args:
prompt: The user's input prompt to generate content for.
Yields:
str: Text chunks from the generated response.
"""
client = genai.Client(
vertexai=True,
)
model = "gemini-3-pro-preview"
contents = [
types.Content(role="user", parts=[types.Part.from_text(text=prompt)]),
]
@@ -21,7 +41,8 @@ def generate(prompt: str):
types.Tool(
retrieval=types.Retrieval(
vertex_rag_store=types.VertexRagStore(
rag_resources=[types.VertexRagStoreRagResource(rag_corpus=CORPUS)],
rag_resources=[
types.VertexRagStoreRagResource(rag_corpus=CORPUS)],
)
)
)
@@ -32,14 +53,16 @@ def generate(prompt: str):
top_p=0.95,
max_output_tokens=65535,
safety_settings=[
types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
types.SafetySetting(
category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
types.SafetySetting(
category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"
),
types.SafetySetting(
category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"
),
types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
types.SafetySetting(
category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
],
tools=tools,
thinking_config=types.ThinkingConfig(
@@ -48,7 +71,7 @@ def generate(prompt: str):
)
for chunk in client.models.generate_content_stream(
model=model,
model=GEMINI_MODEL,
contents=contents,
config=generate_content_config,
):