def initialize_llm(
save_dir: str,
max_new_tokens: int = 500,
temperature: float = 0.1,
repetition_penalty: float = 1.2,
top_p: float = 0.95,
do_sample: bool = True
):
"""
Initializes a retrieval-optimized LLM for product ID extraction.
"""
try:
logger.info(f"Initializing retrieval LLM from: {save_dir}")
if not os.path.exists(save_dir):
raise FileNotFoundError(f"Model directory not found: {save_dir}")
# Load tokenizer
logger.info("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(save_dir)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info("Tokenizer loaded with pad_token set")
# Load model
logger.info("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
save_dir,
device_map="auto",
torch_dtype=torch.float16,
low_cpu_mem_usage=True
)
# Check model context length
if hasattr(model.config, "max_position_embeddings"):
logger.info(f"Model context length: {model.config.max_position_embeddings}")
# Configure text-generation pipeline
logger.info("Configuring text-generation pipeline")
llama_pipeline = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
do_sample=do_sample,
pad_token_id=tokenizer.eos_token_id,
            truncation=True,  # Note: this does not appear to be enforced (see the 6944 > 1024 warning in the logs below)
)
# Wrap in LangChain pipeline
logger.info("Creating HuggingFacePipeline")
hf_pipeline = HuggingFacePipeline(
pipeline=llama_pipeline,
model_kwargs={"temperature": temperature}
)
print("Actual pipeline max input:", hf_pipeline.pipeline.model.config.max_position_embeddings)
return HuggingFacePipeline(
pipeline=llama_pipeline,
model_kwargs={"temperature": temperature}
)
except Exception as e:
logger.error(f"LLM initialization failed: {str(e)}", exc_info=True)
raise
I have initialized the LLM with this function; it's Llama 3.2 1B Instruct.
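For reference, the object it returns is a LangChain HuggingFacePipeline, so a quick standalone smoke test looks roughly like this (the prompt text is just an example; the path is the save_dir from my config):

# Quick smoke test of the wrapper returned by initialize_llm ("output/BGI-llama" is my save_dir)
llm = initialize_llm("output/BGI-llama", max_new_tokens=100)
# .invoke() is the standard LangChain LLM entry point
print(llm.invoke("List three popular product categories."))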
Now I have a product retrieval function:
import os
import time
import re
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from utils.utils import load_config, initialize_embeddings, initialize_llm, load_faiss_store
from logger.logger import get_logger
# Initialize logger
logger = get_logger(__name__)
# Define Map-Reduce Prompts
map_prompt = PromptTemplate(
input_variables=["context", "question"],
template="""
You have the following chunk of data (could be a product or service):
{context}
User question: {question}
- Summarize any relevant items here, referencing ID and name.
- If nothing is relevant, say so.
""",
)
reduce_prompt = PromptTemplate(
input_variables=["summaries", "question"],
template="""
We have partial answers from multiple chunks:
{summaries}
Combine them into a single, cohesive answer to: "{question}"
Requirements:
1) Start with a short summary referencing relevant products or services (by ID and name).
2) Provide bullet points referencing IDs and names.
3) If no relevant items are found, say "No relevant items found."
""",
)
def semantic_search_tool(query: str) -> str:
"""
Enhanced product search that utilizes the LLM's answer and extracts product IDs.
"""
max_retries = 3
for attempt in range(max_retries):
try:
response = qa_chain.invoke(query)
llm_answer = response.get("result", "")
source_docs = response.get("source_documents", [])
if not source_docs:
logger.info("No source documents retrieved from QA chain.")
return llm_answer if llm_answer else "No matching products found."
seen_ids = set()
product_ids = []
for doc in source_docs:
pid = str(doc.metadata.get("product_id", "")).strip()
if pid and pid not in seen_ids:
seen_ids.add(pid)
product_ids.append(pid)
# Construct the final response
if product_ids:
final_response = f"{llm_answer}\n\nMatching Product IDs:\n" + "\n".join(product_ids)
else:
final_response = llm_answer if llm_answer else "No matching products found."
return final_response
except Exception as e:
logger.error(f"Attempt {attempt+1} failed: {str(e)}", exc_info=True)
if "429" in str(e):
sleep_time = 2 ** attempt
time.sleep(sleep_time)
else:
break
return "Error processing request after multiple attempts."
def main():
try:
# Configuration setup
this_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(this_dir, "..", "config.yaml")
config = load_config(config_path)
# Initialize LLM
llm = initialize_llm(
config["save_dir"],
max_new_tokens=500,
temperature=0.1,
repetition_penalty=1.3
)
# Initialize Embeddings
embeddings = initialize_embeddings(config["output_dir"])
# Load FAISS Store
faiss_store = load_faiss_store(config["product_store_path"], embeddings)
# Verify FAISS Store Content
num_vectors = faiss_store.index.ntotal
if num_vectors == 0:
logger.warning("FAISS store is empty. Ensure documents are indexed correctly.")
return
# Configure retriever with expanded search
retriever = faiss_store.as_retriever(
search_kwargs={"k": 10} # Reduce to 10 for better relevance
)
# Initialize QA chain with Map-Reduce
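        # qa_chain is declared global so semantic_search_tool() above can use it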
global qa_chain
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="map_reduce",
retriever=retriever,
return_source_documents=True,
chain_type_kwargs={
"question_prompt": map_prompt,
"combine_prompt": reduce_prompt,
                # Additional kwargs could go here (e.g., to handle a larger context)
                "verbose": True,  # Optional: for more detailed logging
}
)
# User input for query
query = input("Enter your product search query: ")
result = semantic_search_tool(query)
print("\n" + result + "\n")
except Exception as e:
logger.error(f"Main execution failed: {str(e)}", exc_info=True)
raise
if __name__ == "__main__":
main()
I hope the code details are understandable. Here is the run and the full traceback:
$ python -m mlscripts.product_retrieval
2025-01-29 18:47:33,016 - utils - INFO - Configuration loaded from D:\Anand\Jstore_Ai\usecase1\mlscripts\..\config.yaml
2025-01-29 18:47:33,016 - utils - INFO - Initializing retrieval LLM from: output/BGI-llama
2025-01-29 18:47:33,016 - utils - INFO - Loading tokenizer...
2025-01-29 18:47:33,391 - utils - INFO - Tokenizer loaded with pad_token set
2025-01-29 18:47:33,391 - utils - INFO - Loading model...
2025-01-29 18:47:34,694 - utils - INFO - Model context length: 131072
2025-01-29 18:47:34,694 - utils - INFO - Configuring text-generation pipeline
Device set to use cuda:0
2025-01-29 18:47:34,695 - utils - INFO - Creating HuggingFacePipeline
Actual pipeline max input: 131072
2025-01-29 18:47:35,996 - utils - INFO - Embeddings initialized using model at: output/sbert_finetuned on device: cuda
2025-01-29 18:47:35,997 - utils - INFO - Loading FAISS vector store from: output/product_vector_store
2025-01-29 18:47:36,282 - utils - INFO - FAISS vector store loaded successfully.
Enter your product search query: show me some products under 2000
Token indices sequence length is longer than the specified maximum sequence length for this model (6944 > 1024). Running this sequence through the model will result in indexing errors
2025-01-29 18:48:41,656 - __main__ - ERROR - Attempt 1 failed: A single document was longer than the context length, we cannot handle this.
Traceback (most recent call last):
File "D:\Anand\Jstore_Ai\usecase1\mlscripts\product_retrieval.py", line 48, in semantic_search_tool
response = qa_chain.invoke(query)
^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\visionary\AppData\Local\miniconda3\envs\usecase1\Lib\site-packages\langchain\chains\base.py", line 170, in invoke
raise e
File "C:\Users\visionary\AppData\Local\miniconda3\envs\usecase1\Lib\site-packages\langchain\chains\base.py", line 160, in invoke
self._call(inputs, run_manager=run_manager)
File "C:\Users\visionary\AppData\Local\miniconda3\envs\usecase1\Lib\site-packages\langchain\chains\retrieval_qa\base.py", line 154, in _call
    answer = self.combine_documents_chain.run(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\visionary\AppData\Local\miniconda3\envs\usecase1\Lib\site-packages\langchain_core\_api\deprecation.py", line 181, in warning_emitting_wrapper
return wrapped(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\visionary\AppData\Local\miniconda3\envs\usecase1\Lib\site-packages\langchain\chains\base.py", line 611, in run
return self(kwargs, callbacks=callbacks, tags=tags, metadata=metadata)[
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\visionary\AppData\Local\miniconda3\envs\usecase1\Lib\site-packages\langchain_core\_api\deprecation.py", line 181, in warning_emitting_wrapper
return wrapped(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\visionary\AppData\Local\miniconda3\envs\usecase1\Lib\site-packages\langchain\chains\base.py", line 389, in __call__
return self.invoke(
^^^^^^^^^^^^
File "C:\Users\visionary\AppData\Local\miniconda3\envs\usecase1\Lib\site-packages\langchain\chains\base.py", line 170, in invoke
raise e
File "C:\Users\visionary\AppData\Local\miniconda3\envs\usecase1\Lib\site-packages\langchain\chains\base.py", line 160, in invoke
self._call(inputs, run_manager=run_manager)
File "C:\Users\visionary\AppData\Local\miniconda3\envs\usecase1\Lib\site-packages\langchain\chains\combine_documents\base.py", line 138, in _call
    output, extra_return_dict = self.combine_docs(
^^^^^^^^^^^^^^^^^^
File "C:\Users\visionary\AppData\Local\miniconda3\envs\usecase1\Lib\site-packages\langchain\chains\combine_documents\map_reduce.py", line 251, in combine_docs
    result, extra_return_dict = self.reduce_documents_chain.combine_docs(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Users\visionary\AppData\Local\miniconda3\envs\usecase1\Lib\site-packages\langchain\chains\combine_documents\reduce.py", line 252, in combine_docs
result_docs, extra_return_dict = self._collapse(
^^^^^^^^^^^^^^^
File "C:\Users\visionary\AppData\Local\miniconda3\envs\usecase1\Lib\site-packages\langchain\chains\combine_documents\reduce.py", line 307, in _collapse
new_result_doc_list = split_list_of_docs(
^^^^^^^^^^^^^^^^^^^
File "C:\Users\visionary\AppData\Local\miniconda3\envs\usecase1\Lib\site-packages\langchain\chains\combine_documents\reduce.py", line 51, in split_list_of_docs
raise ValueError(
ValueError: A single document was longer than the context length, we cannot handle this.
Error processing request after multiple attempts.
I am trying to retrieve the products priced under 2000. I am using FAISS as my vector store, with a LangChain RetrievalQA (map_reduce) chain on top of the Llama pipeline above. Why does the chain fail with "A single document was longer than the context length" and warn about a 1024-token limit, when the model config reports a 131072-token context, and how do I fix it?
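For completeness, here is a minimal length check I can run after building the retriever (a rough sketch; it assumes access to the tokenizer and model loaded in initialize_llm and the retriever from main):

# Rough diagnostic: compare the tokenizer limit, the model limit,
# and the token length of each retrieved document for the failing query.
print("tokenizer.model_max_length:", tokenizer.model_max_length)        # the warning reports 1024
print("max_position_embeddings:", model.config.max_position_embeddings) # the logs report 131072
docs = retriever.get_relevant_documents("show me some products under 2000")
for doc in docs:
    print(doc.metadata.get("product_id"), len(tokenizer.encode(doc.page_content)))

From the ValueError it looks like at least one retrieved chunk exceeds the 1024-token limit the chain is applying, even though the model itself supports a much longer context.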