add db
This commit is contained in:
11
chain.py
11
chain.py
@@ -45,6 +45,7 @@ class RagChain:
|
||||
|
||||
self._retriever_sources: list[dict] = []
|
||||
self._reranked_sources: list[dict] = []
|
||||
self._rephrased_question: str = ""
|
||||
|
||||
self._llm = ChatGoogleGenerativeAI(
|
||||
model=MODEL,
|
||||
@@ -80,7 +81,7 @@ class RagChain:
|
||||
| question_rewrite_prompt
|
||||
| self._llm
|
||||
| StrOutputParser()
|
||||
| RunnableLambda(self._log_rewritten_question)
|
||||
| RunnableLambda(self._log_rephrased_question)
|
||||
)
|
||||
|
||||
rag_chain = (
|
||||
@@ -95,8 +96,9 @@ class RagChain:
|
||||
|
||||
self._full_chain = question_rewrite_chain | rag_chain
|
||||
|
||||
def _log_rewritten_question(self, rewritten_question: str) -> str:
|
||||
return rewritten_question
|
||||
def _log_rephrased_question(self, rephrased_question: str) -> str:
|
||||
self._rephrased_question = rephrased_question
|
||||
return rephrased_question
|
||||
|
||||
def _format_docs(self, question: str) -> str:
|
||||
retrieved_docs = self._base_retriever.invoke(question)
|
||||
@@ -150,5 +152,8 @@ class RagChain:
|
||||
def getRankedSources(self) -> list[dict]:
|
||||
return list(self._reranked_sources)
|
||||
|
||||
def getRephrasedQuestion(self) -> str:
|
||||
return self._rephrased_question
|
||||
|
||||
def stream(self, message: str):
|
||||
return self._full_chain.astream(message)
|
||||
|
||||
Reference in New Issue
Block a user