diff --git a/.env.dist b/.env.dist index e9743d9..72768ba 100644 --- a/.env.dist +++ b/.env.dist @@ -1 +1,5 @@ GOOGLE_APPLICATION_CREDENTIALS=credentials.json +PORT=8000 +HOST=0.0.0.0 +BASIC_AUTH_USERNAME=admin +BASIC_AUTH_PASSWORD=admin \ No newline at end of file diff --git a/app.py b/app.py index 2c5dd2d..9de7fe9 100644 --- a/app.py +++ b/app.py @@ -1,12 +1,21 @@ +"""FastAPI application for Akern-Genai project. + +This module provides the web application with WebSocket support +for streaming responses from the Gemini model. +""" + import os import logging from typing import Annotated + from fastapi import FastAPI, Request, WebSocket, Depends, HTTPException, status from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.templating import Jinja2Templates from fastapi.staticfiles import StaticFiles + from main import generate +# Configure logging format and level logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", @@ -15,20 +24,26 @@ logging.basicConfig( logger = logging.getLogger(__name__) +# Static files configuration +STATIC_DIR: str = "static" +TEMPLATES_DIR: str = "templates" -STATIC_DIR = "static" -TEMPLATES_DIR = "templates" - +# Security configuration security = HTTPBasic() -app = FastAPI() -app.mount(f"/{STATIC_DIR}", StaticFiles(directory=STATIC_DIR), name="static") +def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)) -> str: + """Verify HTTP Basic credentials against environment variables. -templates = Jinja2Templates(directory=os.path.join(STATIC_DIR, TEMPLATES_DIR)) + Args: + credentials: HTTP Basic authentication credentials. + Returns: + str: The authenticated username. -def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)): + Raises: + HTTPException: If credentials are invalid. + """ correct_username = os.getenv("BASIC_AUTH_USERNAME") correct_password = os.getenv("BASIC_AUTH_PASSWORD") @@ -44,13 +59,34 @@ def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)): return credentials.username +# Initialize FastAPI application +app = FastAPI() +app.mount(f"/{STATIC_DIR}", StaticFiles(directory=STATIC_DIR), name="static") + +templates = Jinja2Templates(directory=os.path.join(STATIC_DIR, TEMPLATES_DIR)) + + @app.get("/") async def home(request: Request, username: Annotated[str, Depends(verify_credentials)]): + """Render the main index page. + + Args: + request: The incoming request object. + username: The authenticated username from HTTP Basic auth. + + Returns: + TemplateResponse: The rendered HTML template. + """ return templates.TemplateResponse("index.html", {"request": request}) @app.websocket("/ws") async def websocket_endpoint(websocket: WebSocket): + """Handle WebSocket connections for streaming responses. + + Args: + websocket: The WebSocket connection. + """ await websocket.accept() while True: data = await websocket.receive_text() diff --git a/main.py b/main.py index f10a67e..eb2ab9a 100644 --- a/main.py +++ b/main.py @@ -1,19 +1,39 @@ +"""Google Gemini API integration for Akern-Genai project. + +This module provides functionality to generate content using Google's Gemini model +with Vertex AI RAG (Retrieval-Augmented Generation) support. +""" + from google import genai from google.genai import types from dotenv import load_dotenv - +# Load environment variables from .env file load_dotenv() -CORPUS = "projects/520464122471/locations/europe-west3/ragCorpora/2305843009213693952" +# Vertex AI RAG Corpus resource path +CORPUS: str = "projects/520464122471/locations/europe-west3/ragCorpora/2305843009213693952" + +# Gemini model name +GEMINI_MODEL: str = "gemini-3-pro-preview" def generate(prompt: str): + """Generate content using Gemini model with RAG retrieval. + + This function creates a streaming response from the Gemini model, + augmented with content from the configured RAG corpus. + + Args: + prompt: The user's input prompt to generate content for. + + Yields: + str: Text chunks from the generated response. + """ client = genai.Client( vertexai=True, ) - model = "gemini-3-pro-preview" contents = [ types.Content(role="user", parts=[types.Part.from_text(text=prompt)]), ] @@ -21,7 +41,8 @@ def generate(prompt: str): types.Tool( retrieval=types.Retrieval( vertex_rag_store=types.VertexRagStore( - rag_resources=[types.VertexRagStoreRagResource(rag_corpus=CORPUS)], + rag_resources=[ + types.VertexRagStoreRagResource(rag_corpus=CORPUS)], ) ) ) @@ -32,14 +53,16 @@ def generate(prompt: str): top_p=0.95, max_output_tokens=65535, safety_settings=[ - types.SafetySetting(category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"), + types.SafetySetting( + category="HARM_CATEGORY_HATE_SPEECH", threshold="OFF"), types.SafetySetting( category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="OFF" ), types.SafetySetting( category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="OFF" ), - types.SafetySetting(category="HARM_CATEGORY_HARASSMENT", threshold="OFF"), + types.SafetySetting( + category="HARM_CATEGORY_HARASSMENT", threshold="OFF"), ], tools=tools, thinking_config=types.ThinkingConfig( @@ -48,7 +71,7 @@ def generate(prompt: str): ) for chunk in client.models.generate_content_stream( - model=model, + model=GEMINI_MODEL, contents=contents, config=generate_content_config, ):