fix docstrings, update env dist

This commit is contained in:
Matteo Rosati
2026-01-21 09:42:13 +01:00
parent b91c09504d
commit 9f5d8a0a88
3 changed files with 77 additions and 14 deletions

View File

@@ -1 +1,5 @@
GOOGLE_APPLICATION_CREDENTIALS=credentials.json GOOGLE_APPLICATION_CREDENTIALS=credentials.json
PORT=8000
HOST=0.0.0.0
BASIC_AUTH_USERNAME=admin
BASIC_AUTH_PASSWORD=admin

50
app.py
View File

@@ -1,12 +1,21 @@
"""FastAPI application for Akern-Genai project.
This module provides the web application with WebSocket support
for streaming responses from the Gemini model.
"""
import os import os
import logging import logging
from typing import Annotated from typing import Annotated
from fastapi import FastAPI, Request, WebSocket, Depends, HTTPException, status from fastapi import FastAPI, Request, WebSocket, Depends, HTTPException, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from main import generate from main import generate
# Configure logging format and level
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
@@ -15,20 +24,26 @@ logging.basicConfig(
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Static files configuration
STATIC_DIR: str = "static"
TEMPLATES_DIR: str = "templates"
STATIC_DIR = "static" # Security configuration
TEMPLATES_DIR = "templates"
security = HTTPBasic() security = HTTPBasic()
app = FastAPI() def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)) -> str:
app.mount(f"/{STATIC_DIR}", StaticFiles(directory=STATIC_DIR), name="static") """Verify HTTP Basic credentials against environment variables.
templates = Jinja2Templates(directory=os.path.join(STATIC_DIR, TEMPLATES_DIR)) Args:
credentials: HTTP Basic authentication credentials.
Returns:
str: The authenticated username.
def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)): Raises:
HTTPException: If credentials are invalid.
"""
correct_username = os.getenv("BASIC_AUTH_USERNAME") correct_username = os.getenv("BASIC_AUTH_USERNAME")
correct_password = os.getenv("BASIC_AUTH_PASSWORD") correct_password = os.getenv("BASIC_AUTH_PASSWORD")
@@ -44,13 +59,34 @@ def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)):
return credentials.username return credentials.username
# Initialize FastAPI application
app = FastAPI()
app.mount(f"/{STATIC_DIR}", StaticFiles(directory=STATIC_DIR), name="static")
templates = Jinja2Templates(directory=os.path.join(STATIC_DIR, TEMPLATES_DIR))
@app.get("/") @app.get("/")
async def home(request: Request, username: Annotated[str, Depends(verify_credentials)]): async def home(request: Request, username: Annotated[str, Depends(verify_credentials)]):
"""Render the main index page.
Args:
request: The incoming request object.
username: The authenticated username from HTTP Basic auth.
Returns:
TemplateResponse: The rendered HTML template.
"""
return templates.TemplateResponse("index.html", {"request": request}) return templates.TemplateResponse("index.html", {"request": request})
@app.websocket("/ws") @app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
"""Handle WebSocket connections for streaming responses.
Args:
websocket: The WebSocket connection.
"""
await websocket.accept() await websocket.accept()
while True: while True:
data = await websocket.receive_text() data = await websocket.receive_text()

37
main.py
View File

@@ -1,19 +1,39 @@
"""Google Gemini API integration for Akern-Genai project.
This module provides functionality to generate content using Google's Gemini model
with Vertex AI RAG (Retrieval-Augmented Generation) support.
"""
from google import genai from google import genai
from google.genai import types from google.genai import types
from dotenv import load_dotenv from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv() load_dotenv()
CORPUS = "projects/520464122471/locations/europe-west3/ragCorpora/2305843009213693952" # Vertex AI RAG Corpus resource path
CORPUS: str = "projects/520464122471/locations/europe-west3/ragCorpora/2305843009213693952"
# Gemini model name
GEMINI_MODEL: str = "gemini-3-pro-preview"
def generate(prompt: str): def generate(prompt: str):
"""Generate content using Gemini model with RAG retrieval.
This function creates a streaming response from the Gemini model,
augmented with content from the configured RAG corpus.
Args:
prompt: The user's input prompt to generate content for.
Yields:
str: Text chunks from the generated response.
"""
client = genai.Client( client = genai.Client(
vertexai=True, vertexai=True,
) )
model = "gemini-3-pro-preview"
contents = [ contents = [
types.Content(role="user", parts=[types.Part.from_text(text=prompt)]), types.Content(role="user", parts=[types.Part.from_text(text=prompt)]),
] ]
@@ -21,7 +41,8 @@ def generate(prompt: str):
types.Tool( types.Tool(
retrieval=types.Retrieval( retrieval=types.Retrieval(
vertex_rag_store=types.VertexRagStore( vertex_rag_store=types.VertexRagStore(
rag_resources=[types.VertexRagStoreRagResource(rag_corpus=CORPUS)], rag_resources=[
types.VertexRagStoreRagResource(rag_corpus=CORPUS)],
) )
) )
) )
@@ -32,14 +53,16 @@ def generate(prompt: str):
top_p=0.95, top_p=0.95,
max_output_tokens=65535, max_output_tokens=65535,
safety_settings=[ safety_settings=[
types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"), types.SafetySetting(
category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"),
types.SafetySetting( types.SafetySetting(
category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF" category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF"
), ),
types.SafetySetting( types.SafetySetting(
category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF" category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF"
), ),
types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"), types.SafetySetting(
category="HARM_CATEGORY_HARASSMENT", threshold="OFF"),
], ],
tools=tools, tools=tools,
thinking_config=types.ThinkingConfig( thinking_config=types.ThinkingConfig(
@@ -48,7 +71,7 @@ def generate(prompt: str):
) )
for chunk in client.models.generate_content_stream( for chunk in client.models.generate_content_stream(
model=model, model=GEMINI_MODEL,
contents=contents, contents=contents,
config=generate_content_config, config=generate_content_config,
): ):