86 lines
2.6 KiB
Python
86 lines
2.6 KiB
Python
"""Google Gemini API integration for Akern-Genai project.
|
|
|
|
This module provides functionality to generate content using Google's Gemini model
|
|
with Vertex AI RAG (Retrieval-Augmented Generation) support.
|
|
"""
|
|
|
|
import asyncio
|
|
import threading
|
|
from google import genai
|
|
from google.genai import types
|
|
from dotenv import load_dotenv
|
|
|
|
from llm_config import generate_content_config
|
|
|
|
# Load environment variables from the project's .env file so the
# google-genai client can pick up its configuration at import time.
load_dotenv()


# Name of the Gemini model used for every generation request in this module.
GEMINI_MODEL: str = "gemini-3-pro-preview"
|
|
|
|
|
|
async def generate(prompt: str):
    """Generate content using the Gemini model with RAG retrieval.

    Streams a response from the Gemini model, augmented with content from
    the configured RAG corpus, as an async generator.  The blocking SDK
    call runs in a daemon thread so multiple WebSocket connections can be
    processed concurrently; text chunks are handed back to the event loop
    through an asyncio queue.

    Args:
        prompt: The user's input prompt to generate content for.

    Yields:
        str: Text chunks from the generated response.
    """
    # None marks end-of-stream.  Unlike a magic string sentinel such as
    # "<<END>>", None can never collide with real model output, so a chunk
    # that happens to contain the sentinel text no longer truncates the stream.
    chunk_queue: asyncio.Queue[str | None] = asyncio.Queue()
    # get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated here and may bind the wrong loop.
    loop = asyncio.get_running_loop()

    def run_streaming() -> None:
        """Run the synchronous streaming SDK call in a separate thread."""
        try:
            client = genai.Client(vertexai=True)

            contents = [
                types.Content(role="user", parts=[types.Part.from_text(text=prompt)]),
            ]

            for chunk in client.models.generate_content_stream(
                model=GEMINI_MODEL,
                contents=contents,
                config=generate_content_config,
            ):
                # Skip metadata-only chunks (no candidates/parts) and chunks
                # whose aggregated text is None or empty.
                if (
                    chunk.candidates
                    and chunk.candidates[0].content
                    and chunk.candidates[0].content.parts
                    and chunk.text
                ):
                    # The queue is unbounded, so put_nowait never blocks;
                    # call_soon_threadsafe is the documented way to hand work
                    # to the loop from another thread, and avoids the extra
                    # future/timeout round-trip per chunk.
                    loop.call_soon_threadsafe(chunk_queue.put_nowait, chunk.text)
        except Exception as e:
            # Best-effort logging; the consumer is still unblocked below.
            print(f"[ERROR] Streaming error: {e}")
        finally:
            # Always signal end-of-stream — even on error — so the consumer
            # loop terminates instead of awaiting forever.
            loop.call_soon_threadsafe(chunk_queue.put_nowait, None)

    # Daemon thread: must not keep the process alive after the caller is gone.
    stream_thread = threading.Thread(target=run_streaming, daemon=True)
    stream_thread.start()

    # Relay chunks to the caller as they arrive.
    while True:
        chunk = await chunk_queue.get()
        if chunk is None:
            break
        yield chunk
|