fix async with threads
This commit is contained in:
98
lib.py
98
lib.py
@@ -4,17 +4,17 @@ This module provides functionality to generate content using Google's Gemini mod
|
||||
with Vertex AI RAG (Retrieval-Augmented Generation) support.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from llm_config import generate_content_config
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
# Vertex AI RAG Corpus resource path (project 520464122471, europe-west3).
# NOTE(review): reconstructed from a garbled diff rendering — value preserved
# byte-for-byte from the original.
CORPUS: str = (
    "projects/520464122471/locations/europe-west3/ragCorpora/2305843009213693952"
)

# Gemini model name used for all generation calls.
GEMINI_MODEL: str = "gemini-3-pro-preview"
|
||||
async def generate(prompt: str):
    """Stream generated text for *prompt* from the Gemini model.

    The blocking google-genai streaming call runs in a daemon thread
    (producer); text chunks are handed to this async generator (consumer)
    through an ``asyncio.Queue`` so the event loop is never blocked.

    Args:
        prompt: User prompt to send to the model.

    Yields:
        str: Text chunks from the generated response.
    """
    # Queue bridging the worker thread (producer) and this coroutine (consumer).
    chunk_queue: asyncio.Queue[str] = asyncio.Queue()
    # get_running_loop() instead of the deprecated get_event_loop(): we are
    # guaranteed to be inside a running loop when the generator body executes.
    loop = asyncio.get_running_loop()

    def run_streaming() -> None:
        """Run the synchronous streaming call in a separate thread."""
        try:
            client = genai.Client(vertexai=True)

            contents = [
                types.Content(role="user", parts=[types.Part.from_text(text=prompt)]),
            ]

            for chunk in client.models.generate_content_stream(
                model=GEMINI_MODEL,
                contents=contents,
                config=generate_content_config,
            ):
                # Skip keep-alive/metadata chunks that carry no text parts.
                if (
                    chunk.candidates
                    and chunk.candidates[0].content
                    and chunk.candidates[0].content.parts
                ):
                    # Hand the chunk to the event loop thread-safely.
                    future = asyncio.run_coroutine_threadsafe(
                        chunk_queue.put(chunk.text),
                        loop,
                    )
                    # Wait for the put to complete (quick operation).
                    future.result(timeout=1)
        except Exception as e:
            # Best-effort: report and fall through so the consumer is unblocked.
            print(f"[ERROR] Streaming error: {e}")
        finally:
            # Always deliver the sentinel, even on error, so the consumer
            # loop below terminates instead of awaiting forever.
            asyncio.run_coroutine_threadsafe(
                chunk_queue.put("<<END>>"),
                loop,
            )

    # Start the streaming in a daemon thread so the event loop stays free.
    stream_thread = threading.Thread(target=run_streaming, daemon=True)
    stream_thread.start()

    # Yield chunks as they become available until the producer signals the end.
    while True:
        chunk = await chunk_queue.get()
        if chunk == "<<END>>":
            break
        yield chunk
|
||||
|
||||
Reference in New Issue
Block a user