diff --git a/app.py b/app.py index fd0c23d..b7f065b 100644 --- a/app.py +++ b/app.py @@ -91,7 +91,7 @@ async def websocket_endpoint(websocket: WebSocket): while True: data = await websocket.receive_text() - async for chunk in generate(data): + for chunk in generate(data): await websocket.send_text(chunk) await websocket.send_text("<>") diff --git a/lib.py b/lib.py index 4683941..06ebaf1 100644 --- a/lib.py +++ b/lib.py @@ -4,8 +4,6 @@ This module provides functionality to generate content using Google's Gemini mod with Vertex AI RAG (Retrieval-Augmented Generation) support. """ -import asyncio -from concurrent.futures import ThreadPoolExecutor from google import genai from google.genai import types from dotenv import load_dotenv @@ -19,15 +17,15 @@ CORPUS: str = "projects/520464122471/locations/europe-west3/ragCorpora/230584300 # Gemini model name GEMINI_MODEL: str = "gemini-3-pro-preview" -# Thread pool for blocking API calls -_executor = ThreadPoolExecutor(max_workers=10) +def generate(prompt: str): +    """Generate content using Gemini model with RAG retrieval. -def _generate_sync(prompt: str): -    """Synchronous wrapper for generate_content_stream. +    This function creates a streaming response from the Gemini model, +    augmented with content from the configured RAG corpus. -    This function contains the blocking Google GenAI SDK call. -    It should be run in a thread pool to avoid blocking the event loop. +    NOTE: this is a synchronous generator — the Google GenAI SDK call blocks +    the calling thread; async callers must offload it (e.g. asyncio.to_thread). Args: prompt: The user's input prompt to generate content for. @@ -86,28 +84,3 @@ def _generate_sync(prompt: str): continue yield chunk.text - - - async def generate(prompt: str): - """Generate content using Gemini model with RAG retrieval. - - This function creates a streaming response from the Gemini model, - augmented with content from the configured RAG corpus. 
- - The blocking API call is run in a thread pool to allow concurrent - processing of multiple WebSocket connections. - - Args: - prompt: The user's input prompt to generate content for. - - Yields: - str: Text chunks from the generated response. - """ - loop = asyncio.get_running_loop() - - # Run the synchronous generator in a thread pool to avoid blocking the event loop - sync_gen = await loop.run_in_executor(_executor, _generate_sync, prompt) - - # Yield from the synchronous generator - for chunk in sync_gen: - yield chunk