I am implementing a custom LLM by subclassing langchain.llms.base.LLM and overriding the _call() method. My goal is to pass additional parameters, such as max_tokens, to _call(). However, the parameters I provide are not reaching the method. Here is my code:
from typing import List, Optional

import requests
from pydantic import Field
from langchain.prompts import PromptTemplate
from langchain.llms.base import LLM
from langchain.chains import LLMChain


class CustomVLLM(LLM):
    endpoint: str = Field(..., description="URL of the generate endpoint")

    def __init__(self, endpoint: str):
        # Pass the field through pydantic validation instead of setting it afterwards.
        super().__init__(endpoint=endpoint)

    # def _call(self, prompt: str, max_tokens: int = 111, stop: Optional[List[str]] = "<eee>") -> str:
    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
        """Call the vLLM generate endpoint."""
        max_tokens = kwargs.get("max_tokens", 1000)
        print(f"Running _call, prompt: {prompt}, max_tokens: {max_tokens}, stop: {stop}")
        request_payload = {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "stop": stop,
        }
        try:
            response = requests.post(self.endpoint, json=request_payload)
            response.raise_for_status()
            return response.json().get("text", "")
        except Exception as e:
            raise ValueError(f"Error during LLM call: {e}")

    @property
    def _llm_type(self) -> str:
        return "custom_vllm"


def create_decomposition_chain(llm):
    prompt = PromptTemplate.from_template("test")
    return LLMChain(llm=llm, prompt=prompt)


if __name__ == "__main__":
    llm = CustomVLLM("http://localhost:8000/generate/")
    decomposition_chain = create_decomposition_chain(llm)
    decomposition_result = decomposition_chain.run(question="test", max_tokens=100, stop="<end>")
    sub_questions = decomposition_result.strip().split("\n")
    print(sub_questions)
When running the code, the _call() method logs the following output:
Running _call, prompt: test, max_tokens: 1000, stop: <end>
It seems the max_tokens parameter I passed (max_tokens=100) is not being forwarded to _call() and is falling back to the default of 1000.
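For reference, here is how I would test _call() directly, bypassing the chain entirely. This is only a sketch based on my assumption that BaseLLM.invoke() forwards extra keyword arguments down to _call(); llm is the instance created above.

    # Direct invocation, assuming invoke() passes **kwargs through to _call().
    # stop is a keyword-only argument of invoke(); max_tokens travels in **kwargs.
    result = llm.invoke("test", stop=["<end>"], max_tokens=100)
    print(result)

If this direct call logs max_tokens: 100 while the chain run logs max_tokens: 1000, then the keyword arguments are being dropped somewhere between LLMChain.run() and _call().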
My Questions:
1. Why is max_tokens=100 not being passed to the _call() method?
2. How can I modify my code so that the max_tokens value is correctly passed to _call() via kwargs?
Additional Context:
LangChain version: 0.2.17
LangChain-Core version: 0.2.43
Python version: 3.8.19
Endpoint setup: The endpoint is a local server receiving POST requests.
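For anyone who wants to reproduce this without a real vLLM server, here is a minimal hypothetical stand-in for the endpoint (it is not the actual vLLM API server; it simply echoes the received parameters back in the "text" field):

    # Hypothetical stub of the generate endpoint, for reproduction only.
    import json
    from http.server import BaseHTTPRequestHandler, HTTPServer

    class GenerateHandler(BaseHTTPRequestHandler):
        def do_POST(self):
            # Read and parse the JSON payload sent by CustomVLLM._call().
            length = int(self.headers.get("Content-Length", 0))
            payload = json.loads(self.rfile.read(length))
            # Echo the parameters back so the client can see what arrived.
            body = json.dumps({"text": f"echo: {payload}"}).encode()
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)

    if __name__ == "__main__":
        HTTPServer(("localhost", 8000), GenerateHandler).serve_forever()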
Expected Behavior:
When running the chain, I expect _call() to log max_tokens: 100 instead of the default max_tokens: 1000:
Running _call, prompt: test, max_tokens: 100, stop: <end>