"""Google Gemini API integration for Akern-Genai project.
|
|
|
|
This module provides functionality to generate content using Google's Gemini model
|
|
with Vertex AI RAG (Retrieval-Augmented Generation) support.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import os
|
|
import threading
|
|
from google import genai
|
|
from google.genai import types
|
|
from google.oauth2 import service_account
|
|
from dotenv import load_dotenv
|
|
|
|
from llm_config import generate_content_config
|
|
|
|
# Load environment variables from .env file
|
|
load_dotenv()
|
|
|
|
|
|
def get_credentials():
    """Get Google Cloud credentials from environment.

    Supports two methods:
    1. GOOGLE_CREDENTIALS_JSON: Direct JSON content as string (production)
    2. GOOGLE_APPLICATION_CREDENTIALS: Path to JSON file (local development)

    Returns:
        service_account.Credentials | None: The loaded credentials, or None
        when neither variable is set so the client can fall back to
        application-default credentials (e.g. a service account attached to
        the Google Cloud runtime).

    Raises:
        ValueError: If GOOGLE_CREDENTIALS_JSON contains invalid JSON.
    """
    # Preferred in production: key material injected directly into the
    # environment, avoiding a credentials file on disk.
    credentials_json = os.getenv("GOOGLE_CREDENTIALS_JSON")
    if credentials_json:
        # Keep the try body minimal: only json.loads can raise JSONDecodeError.
        try:
            credentials_info = json.loads(credentials_json)
        except json.JSONDecodeError as e:
            # Chain the original decode error so the root cause survives.
            raise ValueError(
                f"Invalid JSON in GOOGLE_CREDENTIALS_JSON: {e}"
            ) from e
        return service_account.Credentials.from_service_account_info(
            credentials_info
        )

    # Fall back to file-based credentials (standard local-dev behavior).
    credentials_path = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
    if credentials_path and os.path.exists(credentials_path):
        return service_account.Credentials.from_service_account_file(
            credentials_path
        )

    # If neither is provided, return None to let the client use default
    # credentials (useful when running on Google Cloud with a service
    # account attached).
    return None
|
|
|
|
|
|
# Gemini model identifier passed to client.models.generate_content_stream().
GEMINI_MODEL: str = "gemini-3-pro-preview"
|
|
|
|
|
|
async def generate(prompt: str):
    """Generate content using the Gemini model with RAG retrieval.

    This function creates a streaming response from the Gemini model,
    augmented with content from the configured RAG corpus.

    The blocking API call runs in a daemon thread so multiple WebSocket
    connections can be processed concurrently; chunks are handed back to
    the event loop through an asyncio queue.

    Args:
        prompt: The user's input prompt to generate content for.

    Yields:
        str: Text chunks from the generated response.
    """
    # Unique end-of-stream sentinel. Unlike a magic string such as
    # "<<END>>", an object() compared with `is` can never collide with
    # real model output and silently truncate the stream.
    end_of_stream = object()
    chunk_queue: asyncio.Queue[object] = asyncio.Queue()
    # get_running_loop() is the correct, non-deprecated way to obtain the
    # loop from inside a coroutine (get_event_loop() is deprecated here
    # since Python 3.10).
    loop = asyncio.get_running_loop()

    def run_streaming():
        """Run the synchronous streaming API call in a separate thread."""
        try:
            credentials = get_credentials()
            client = genai.Client(vertexai=True, credentials=credentials)

            contents = [
                types.Content(
                    role="user",
                    parts=[types.Part.from_text(text=prompt)],
                ),
            ]

            for chunk in client.models.generate_content_stream(
                model=GEMINI_MODEL,
                contents=contents,
                config=generate_content_config,
            ):
                # Skip chunks that carry no content parts.
                if (
                    chunk.candidates
                    and chunk.candidates[0].content
                    and chunk.candidates[0].content.parts
                ):
                    # Hand the text to the event loop thread-safely.
                    future = asyncio.run_coroutine_threadsafe(
                        chunk_queue.put(chunk.text),
                        loop,
                    )
                    # The queue is unbounded, so the put completes quickly;
                    # the timeout only guards against a dead event loop.
                    future.result(timeout=1)
        except Exception as e:
            # Best-effort: report the failure, then fall through so the
            # consumer is always unblocked by the sentinel below.
            print(f"[ERROR] Streaming error: {e}")
        finally:
            asyncio.run_coroutine_threadsafe(
                chunk_queue.put(end_of_stream),
                loop,
            )

    # Daemon thread: a hung stream cannot block interpreter shutdown.
    stream_thread = threading.Thread(target=run_streaming, daemon=True)
    stream_thread.start()

    # Relay chunks to the caller until the producer signals completion.
    while True:
        chunk = await chunk_queue.get()
        if chunk is end_of_stream:
            break
        yield chunk
|