I have working RAG code using LangChain and Milvus. Now I'd like to add a feature that looks at the metadata of each of the k retrieved documents and does the following:
- find the paragraph_id of each of the k documents in the vector store collection
- get all the documents with that same paragraph_id from the vector store collection (I guess this means accessing the same vector store collection again?)
- merge the text field of all occurrences of each paragraph_id into a single text (one per paragraph_id), ordered by the metadata value "chunk_index" (ascending, starting at 0)
- use these merged texts to generate the response, not just the text of the original k documents.
Is this the right way to go? The issue is that when creating the collection, I have to split the source into Documents that are usually smaller than the paragraph they come from. (In the code below the metadata field is called section_id, but it plays exactly the role of paragraph_id above.)
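For context, this is roughly how the chunks get their metadata at indexing time (simplified sketch; section_chunks, sec_id and vector_store are hypothetical placeholders, but the metadata keys match the ones used in the snippets below):

# Simplified, hypothetical sketch of the indexing side: each paragraph/section
# is split into chunks, and every chunk carries its section id plus its
# position within the section.
from langchain_core.documents import Document

docs = [
    Document(
        page_content=chunk_text,
        metadata={"section_id": sec_id, "chunk_index": i},
    )
    for i, chunk_text in enumerate(section_chunks)
]
vector_store.add_documents(docs)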
The relevant part of the original, working code looks like this:
prompt = PromptTemplate(
    input_variables=["context", "question", "chat_history"],
    template=session_data.template,
)
rag_chain_from_docs = (
    RunnablePassthrough.assign(
        context=(lambda x: format_docs(x["context"])),
        chat_history=(lambda x: formatted_chat_history),
    )
    | prompt
    | hf
    | StrOutputParser()
)
rag_chain = RunnableParallel(
    {"context": session_data.retriever, "question": RunnablePassthrough()}
).assign(answer=rag_chain_from_docs)
result = rag_chain.invoke(question)
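(format_docs is not shown above; it's just the usual helper that concatenates the retrieved documents into one context string, roughly:)

# Assumed helper, not shown in the original snippet.
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)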
But the code I created, which "takes out" the session_data.retriever piece and makes a second round trip to the vector store collection, is not working properly. Specifically, the system no longer finds the correct documents in the vector store (it seems to find only the last one), and the "merging" never happens. As a result, the reply I get from the system is always "sorry, I don't know the answer".
The snippet I have at the moment is (notice the commented steps):
from collections import defaultdict

# Step 1: Retrieve the initial context (top-k documents)
relevant_context = session_data.retriever.get_relevant_documents(question)

grouped_sections = defaultdict(list)
# Step 2: Identify all unique section_ids from the retrieved documents
section_ids = set()
for element in relevant_context:
    metadata = element.metadata
    section_id = metadata.get("section_id")
    if section_id is not None:
        section_ids.add(section_id)
# Step 3: Query the vector store for all chunks with the same section_id
all_chunks = []
for section_id in section_ids:
    try:
        # Suspected problem: as far as I can tell, search_kwargs passed at
        # call time is ignored (it is normally fixed when the retriever is
        # created via as_retriever()), the LangChain Milvus store filters via
        # an "expr" string rather than a "filter" dict, and an empty query
        # string gives meaningless similarity scores.
        search_results = session_data.retriever.get_relevant_documents(
            query="",
            search_kwargs={"filter": {"section_id": section_id}},
        )
        all_chunks.extend(search_results)
    except Exception as e:
        logger.error(f"Error querying retriever for section_id {section_id}: {e}")
# Step 4: Group the retrieved chunks by section_id and sort by chunk_index
for chunk in all_chunks:
    metadata = chunk.metadata
    section_id = metadata.get("section_id")
    chunk_index = metadata.get("chunk_index", 0)  # Default to 0 if chunk_index is missing
    text = chunk.page_content
    if section_id is not None:
        grouped_sections[section_id].append((chunk_index, text))
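If the same chunk can end up in all_chunks more than once (for example, if I also extend it with the initially retrieved relevant_context), I would deduplicate each section before merging; a minimal sketch:

# Sketch: collapse duplicate (chunk_index, text) pairs, keeping one text per
# chunk_index within each section (later entries overwrite earlier ones).
for sid, pairs in grouped_sections.items():
    grouped_sections[sid] = sorted(dict(pairs).items())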
# Step 5: Merge text for each section_id and log the merged results
merged_sections = {}
for section_id, chunks in grouped_sections.items():
    sorted_chunks = sorted(chunks, key=lambda x: x[0])
    merged_text = " ".join(chunk[1] for chunk in sorted_chunks)
    merged_sections[section_id] = merged_text
    logger.info(f"Section {section_id}: {merged_text}")
# Step 6: Use merged sections to generate the response
logger.info("Generating response using merged sections.")
merged_context = "\n".join(merged_sections.values())
prompt = PromptTemplate(
    input_variables=["context", "question", "chat_history"],
    template=session_data.template,
)
prompt_with_merged_context = prompt.format(
    context=merged_context,
    question=question,
    chat_history=formatted_chat_history,
)
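and then the plan is to feed this through the same model and parser as the working chain above, roughly:

# Sketch: hf and StrOutputParser are the same objects used in the working
# chain; prompt_with_merged_context is the string built in Step 6.
answer = StrOutputParser().invoke(hf.invoke(prompt_with_merged_context))
logger.info(f"Final answer: {answer}")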