"""Google Gemini API integration for Akern-Genai project.

This module provides functionality to generate content using Google's Gemini
model with Vertex AI RAG (Retrieval-Augmented Generation) support.
"""

import asyncio
import json
import logging
import os
import threading

from dotenv import load_dotenv
from google import genai
from google.genai import types
from google.oauth2 import service_account

from llm_config import generate_content_config

logger = logging.getLogger(__name__)

# Load environment variables from .env file
load_dotenv()

# Gemini model name
GEMINI_MODEL: str = "gemini-3-pro-preview"

# Unique end-of-stream sentinel.  Unlike a magic string (the previous "<>"),
# a fresh object() can never collide with a real text chunk from the model.
_STREAM_END = object()


def get_credentials():
    """Get Google Cloud credentials from environment.

    Supports two methods:
    1. GOOGLE_CREDENTIALS_JSON: Direct JSON content as string (production)
    2. GOOGLE_APPLICATION_CREDENTIALS: Path to JSON file (local development)

    Returns:
        service_account.Credentials | None: The loaded credentials, or None
        to let the client fall back to application-default credentials
        (useful when running on Google Cloud with a service account attached).

    Raises:
        ValueError: If GOOGLE_CREDENTIALS_JSON contains invalid JSON.
    """
    # OAuth scopes required for Vertex AI API
    SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]

    # Try to load credentials from JSON content directly.
    # SECURITY: never log the JSON content or the parsed dict -- they contain
    # the service account's private key.
    credentials_json = os.getenv("GOOGLE_CREDENTIALS_JSON")
    if credentials_json:
        logger.info("Loading credentials from GOOGLE_CREDENTIALS_JSON")
        try:
            credentials_info = json.loads(credentials_json)
        except json.JSONDecodeError as e:
            # Chain the cause so the original parse error stays visible.
            raise ValueError(f"Invalid JSON in GOOGLE_CREDENTIALS_JSON: {e}") from e
        return service_account.Credentials.from_service_account_info(
            credentials_info, scopes=SCOPES
        )

    # Fall back to file-based credentials (standard behavior).
    credentials_path = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
    logger.info("Loading credentials from file: %s", credentials_path)
    if credentials_path and os.path.exists(credentials_path):
        return service_account.Credentials.from_service_account_file(
            credentials_path, scopes=SCOPES
        )

    # If neither is provided, return None to let the client use default
    # credentials (useful when running on Google Cloud with a service
    # account attached).
    return None


async def generate(prompt: str):
    """Generate content using Gemini model with RAG retrieval.

    This function creates a streaming response from the Gemini model,
    augmented with content from the configured RAG corpus.  The blocking API
    call is run in a separate thread to allow concurrent processing of
    multiple WebSocket connections.

    Args:
        prompt: The user's input prompt to generate content for.

    Yields:
        str: Text chunks from the generated response.
    """
    # Queue bridging the worker thread (producer) and this coroutine
    # (consumer).  Carries str chunks plus the _STREAM_END sentinel.
    chunk_queue: asyncio.Queue = asyncio.Queue()
    # get_running_loop() is the correct call inside a coroutine; the
    # deprecated get_event_loop() may create/return the wrong loop when
    # called off the main thread or under newer asyncio policies.
    loop = asyncio.get_running_loop()

    def run_streaming():
        """Run the synchronous streaming in a separate thread."""
        try:
            credentials = get_credentials()
            client = genai.Client(vertexai=True, credentials=credentials)
            contents = [
                types.Content(
                    role="user",
                    parts=[types.Part.from_text(text=prompt)],
                ),
            ]
            for chunk in client.models.generate_content_stream(
                model=GEMINI_MODEL,
                contents=contents,
                config=generate_content_config,
            ):
                if (
                    chunk.candidates
                    and chunk.candidates[0].content
                    and chunk.candidates[0].content.parts
                ):
                    # Hand the chunk to the event loop thread-safely.
                    future = asyncio.run_coroutine_threadsafe(
                        chunk_queue.put(chunk.text), loop
                    )
                    # Wait for the put to complete (quick operation).
                    future.result(timeout=1)
        except Exception:
            # logger.exception records the full traceback; the previous
            # print() bypassed logging configuration and lost it.
            logger.exception("Streaming error")
        finally:
            # Always signal end-of-stream so the consumer loop terminates,
            # even when the API call failed.
            asyncio.run_coroutine_threadsafe(chunk_queue.put(_STREAM_END), loop)

    # Start the streaming in a daemon thread so it cannot block shutdown.
    stream_thread = threading.Thread(target=run_streaming, daemon=True)
    stream_thread.start()

    # Yield chunks as they become available until the sentinel arrives.
    while True:
        chunk = await chunk_queue.get()
        if chunk is _STREAM_END:
            break
        yield chunk