async def generate(prompt: str):
    """Generate content using Gemini model with RAG retrieval.

    This function creates a streaming response from the Gemini model,
    augmented with content from the configured RAG corpus.

    Each blocking ``next()`` on the underlying synchronous stream is
    dispatched to the thread pool, so the event loop stays free to
    service other WebSocket connections while chunks are produced.

    Args:
        prompt: The user's input prompt to generate content for.

    Yields:
        str: Text chunks from the generated response.
    """
    loop = asyncio.get_running_loop()

    # Creating the generator is cheap and non-blocking: a generator
    # function call runs no body code.  The blocking SDK/network work
    # happens only when the generator is advanced, so it is each
    # next() call -- not the construction -- that must leave the loop.
    sync_gen = _generate_sync(prompt)

    # Sentinel marks exhaustion so next() never raises StopIteration
    # inside the executor; a StopIteration leaking into a coroutine
    # becomes a RuntimeError under PEP 479.
    done = object()

    while True:
        # Run each blocking next() in the thread pool instead of on
        # the event loop.
        chunk = await loop.run_in_executor(_executor, next, sync_gen, done)
        if chunk is done:
            break
        yield chunk