Akern-Genai/lib.py
"""Google Gemini API integration for Akern-Genai project.
This module provides functionality to generate content using Google's Gemini model
with Vertex AI RAG (Retrieval-Augmented Generation) support.
"""
from llm_config import generate_content_config
import logging
import asyncio
import json
import os
import threading
from google import genai
from google.genai import types
from google.oauth2 import service_account
from dotenv import load_dotenv
logger = logging.getLogger(__name__)
# Load environment variables from .env file
load_dotenv()


def get_credentials():
    """Get Google Cloud credentials from the environment.

    Supports two methods:
    1. GOOGLE_CREDENTIALS_JSON: direct JSON content as a string (production)
    2. GOOGLE_APPLICATION_CREDENTIALS: path to a JSON key file (local development)

    Returns:
        service_account.Credentials: The loaded credentials, or None if neither
        variable is set (the client then falls back to default credentials).
    """
    # OAuth scopes required for the Vertex AI API
    SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]

    # Try to load credentials from JSON content provided directly in the environment
    credentials_json = os.getenv("GOOGLE_CREDENTIALS_JSON")
    if credentials_json:
        # Do not log the raw JSON: it contains the service account's private key
        logger.info("Loading credentials from GOOGLE_CREDENTIALS_JSON")
        try:
            credentials_info = json.loads(credentials_json)
            return service_account.Credentials.from_service_account_info(
                credentials_info, scopes=SCOPES
            )
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in GOOGLE_CREDENTIALS_JSON: {e}") from e

    # Fall back to file-based credentials (standard behavior)
    credentials_path = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
    logger.info("Credentials file path: %s", credentials_path)
    if credentials_path and os.path.exists(credentials_path):
        return service_account.Credentials.from_service_account_file(
            credentials_path, scopes=SCOPES
        )

    # If neither is provided, return None to let the client use default credentials
    # (useful when running on Google Cloud with a service account attached)
    return None
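
# Example .env configuration for the two methods above (illustrative values
# only; the path and JSON fields are placeholders, not real credentials).
# When vertexai=True, the google-genai client also typically reads
# GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION if not passed explicitly.
#
#   # Local development: point at a downloaded service-account key file
#   GOOGLE_APPLICATION_CREDENTIALS=/path/to/service-account.json
#
#   # Production: inline the key JSON as a single line
#   GOOGLE_CREDENTIALS_JSON={"type": "service_account", "project_id": "my-project", ...}
#
#   GOOGLE_CLOUD_PROJECT=my-project
#   GOOGLE_CLOUD_LOCATION=europe-west1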
# Gemini model name
GEMINI_MODEL: str = "gemini-3-pro-preview"
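
# `generate_content_config` is defined in llm_config.py (not shown here). For RAG
# retrieval against a Vertex AI RAG corpus it is presumably built along these
# lines with the google-genai types (sketch only; the corpus resource name is a
# placeholder, not the project's actual configuration):
#
#   generate_content_config = types.GenerateContentConfig(
#       tools=[
#           types.Tool(
#               retrieval=types.Retrieval(
#                   vertex_rag_store=types.VertexRagStore(
#                       rag_resources=[
#                           types.VertexRagStoreRagResource(
#                               rag_corpus="projects/PROJECT/locations/LOCATION/ragCorpora/CORPUS_ID",
#                           )
#                       ],
#                   )
#               )
#           )
#       ],
#   )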


async def generate(prompt: str):
    """Generate content using the Gemini model with RAG retrieval.

    This function creates a streaming response from the Gemini model,
    augmented with content from the configured RAG corpus.

    The blocking API call runs in a separate thread so that multiple
    WebSocket connections can be processed concurrently.

    Args:
        prompt: The user's input prompt to generate content for.

    Yields:
        str: Text chunks from the generated response.
    """
    # Queue used to hand streamed chunks from the worker thread to this coroutine
    chunk_queue: asyncio.Queue[str] = asyncio.Queue()
    loop = asyncio.get_running_loop()

    def run_streaming():
        """Run the synchronous streaming call in a separate thread."""
        try:
            credentials = get_credentials()
            client = genai.Client(vertexai=True, credentials=credentials)
            contents = [
                types.Content(role="user", parts=[types.Part.from_text(text=prompt)]),
            ]
            for chunk in client.models.generate_content_stream(
                model=GEMINI_MODEL,
                contents=contents,
                config=generate_content_config,
            ):
                if (
                    chunk.candidates
                    and chunk.candidates[0].content
                    and chunk.candidates[0].content.parts
                ):
                    # Schedule the put operation on the event loop thread
                    future = asyncio.run_coroutine_threadsafe(
                        chunk_queue.put(chunk.text),
                        loop,
                    )
                    # Wait for the put to complete (quick operation)
                    future.result(timeout=1)
        except Exception as e:
            logger.error("Streaming error: %s", e)
        finally:
            # Always signal the consumer that the stream has ended
            asyncio.run_coroutine_threadsafe(
                chunk_queue.put("<<END>>"),
                loop,
            )

    # Start the streaming in a daemon thread
    stream_thread = threading.Thread(target=run_streaming, daemon=True)
    stream_thread.start()

    # Yield chunks as they become available
    while True:
        chunk = await chunk_queue.get()
        if chunk == "<<END>>":
            break
        yield chunk
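

# --- Usage sketch ------------------------------------------------------------
# Minimal, illustrative consumer of `generate()` for manual testing from the
# command line. In production the generator is assumed to be consumed by a
# WebSocket handler; this block only runs when the module is executed directly
# and requires the credentials/environment described above. The prompt text is
# a placeholder.
if __name__ == "__main__":
    async def _demo() -> None:
        async for chunk in generate("Give a one-paragraph overview of the corpus."):
            print(chunk, end="", flush=True)
        print()

    asyncio.run(_demo())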