一、在langchain框架下实现LLM流式响应
在上篇文章中,我们已经实现了使用langchain的LLMChain完成ChatGLM的调用,主要是重写一个ChatGLM类,并让ChatGLM继承与LLM,然后重写_call()方法,虽然我们在_call方法里面调用了ChatGLM的stream_chat方法让大语言模型实现了流式响应,但是langchain并没有按流式响应,将结果返回到我们自己调用LLMChain的方法中,现在我们就要想办法完成这一需求
网上有很多人说,直接调用LLMChain的stream方法就行,实际情况是在不做任何重写的情况是绝对不可能实现流式相应的
既然langchain是开源的,那么还是重读源码开始....
1.1、langchain源码结构梳理
首先,我们需要确定入口点,按照网上的朋友说的直接调用stream,那么我们就从stream入手
PROMPT = PromptTemplate(template=prompt_template,input_variables=["text"])
chain = Stream_Chain(llm=self.llm,prompt=PROMPT)
content = chain.stream(text=text)
F12跟进去看看stream方法
def stream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
"""
Default implementation of stream, which calls invoke.
Subclasses should override this method if they support streaming output.
"""
yield self.invoke(input, config, **kwargs)
注意,这里发现yield self.invoke(input, config,**kwargs),yield都出现了,好像确实是在采用流式相应...继续看看invoke方法
@abstractmethod
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
...
这里发现这根本就是一个抽象方法,好吧,去找子类看看,目前所在类:Runnable,他的子类是:
class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
那我们在Chain里面找找invoke吧
def invoke(
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
config = config or {}
return self(
input,
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_name=config.get("run_name"),
**kwargs,
)
熟悉python的朋友都知道,这个方法直接调用了self(),这意味着是在调用该类的__call__方法
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_name: Optional[str] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
inputs = self.prep_inputs(inputs)
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
文章来源
发表评论