-
Notifications
You must be signed in to change notification settings - Fork 17
[WIP] feat: implement streaming methods for chat models #197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,15 +1,15 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||
import json | ||||||||||||||||||||||||||||||||||||||||||||||||||||
import logging | ||||||||||||||||||||||||||||||||||||||||||||||||||||
from typing import Any, Dict, List, Literal, Optional, Union | ||||||||||||||||||||||||||||||||||||||||||||||||||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
from langchain_core.callbacks import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
AsyncCallbackManagerForLLMRun, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
CallbackManagerForLLMRun, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
from langchain_core.language_models import LanguageModelInput | ||||||||||||||||||||||||||||||||||||||||||||||||||||
from langchain_core.messages import AIMessage, BaseMessage | ||||||||||||||||||||||||||||||||||||||||||||||||||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage | ||||||||||||||||||||||||||||||||||||||||||||||||||||
from langchain_core.messages.ai import UsageMetadata | ||||||||||||||||||||||||||||||||||||||||||||||||||||
from langchain_core.outputs import ChatGeneration, ChatResult | ||||||||||||||||||||||||||||||||||||||||||||||||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult | ||||||||||||||||||||||||||||||||||||||||||||||||||||
from langchain_core.runnables import Runnable | ||||||||||||||||||||||||||||||||||||||||||||||||||||
from langchain_openai.chat_models import AzureChatOpenAI | ||||||||||||||||||||||||||||||||||||||||||||||||||||
from pydantic import BaseModel | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -49,6 +49,54 @@ async def _agenerate( | |||||||||||||||||||||||||||||||||||||||||||||||||||
response = await self._acall(self.url, payload, self.auth_headers) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return self._create_chat_result(response) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
def _stream( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
messages: List[BaseMessage], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
stop: Optional[List[str]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
run_manager: Optional[CallbackManagerForLLMRun] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
**kwargs: Any, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> Iterator[ChatGenerationChunk]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if "tools" in kwargs and not kwargs["tools"]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
del kwargs["tools"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
payload = self._get_request_payload(messages, stop=stop, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
response = self._call(self.url, payload, self.auth_headers) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# For non-streaming response, yield single chunk | ||||||||||||||||||||||||||||||||||||||||||||||||||||
chat_result = self._create_chat_result(response) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
chunk = ChatGenerationChunk( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
message=AIMessageChunk( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
content=chat_result.generations[0].message.content, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
additional_kwargs=chat_result.generations[0].message.additional_kwargs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
response_metadata=chat_result.generations[0].message.response_metadata, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
usage_metadata=chat_result.generations[0].message.usage_metadata, # type: ignore | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
yield chunk | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
async def _astream( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
messages: List[BaseMessage], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
stop: Optional[List[str]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
**kwargs: Any, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> AsyncIterator[ChatGenerationChunk]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if "tools" in kwargs and not kwargs["tools"]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
del kwargs["tools"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
payload = self._get_request_payload(messages, stop=stop, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
response = await self._acall(self.url, payload, self.auth_headers) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# For non-streaming response, yield single chunk | ||||||||||||||||||||||||||||||||||||||||||||||||||||
chat_result = self._create_chat_result(response) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
chunk = ChatGenerationChunk( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
message=AIMessageChunk( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
content=chat_result.generations[0].message.content, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
additional_kwargs=chat_result.generations[0].message.additional_kwargs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
response_metadata=chat_result.generations[0].message.response_metadata, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
usage_metadata=chat_result.generations[0].message.usage_metadata, # type: ignore | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
yield chunk | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+86
to
+99
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The async streaming implementation also returns a single chunk instead of true streaming. This duplicates the same non-streaming behavior as the sync version.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||||||||||||||||||||||||||||||||||||||||
def with_structured_output( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
schema: Optional[Any] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -217,6 +265,92 @@ async def _agenerate( | |||||||||||||||||||||||||||||||||||||||||||||||||||
response = await self._acall(self.url, payload, self.auth_headers) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
return self._create_chat_result(response) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
def _stream( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
messages: List[BaseMessage], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
stop: Optional[List[str]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
run_manager: Optional[CallbackManagerForLLMRun] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
**kwargs: Any, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> Iterator[ChatGenerationChunk]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Stream the LLM on a given prompt. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
messages: the prompt composed of a list of messages. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
stop: a list of strings on which the model should stop generating. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
run_manager: A run manager with callbacks for the LLM. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
**kwargs: Additional keyword arguments. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
An iterator of ChatGenerationChunk objects. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if kwargs.get("tools"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
kwargs["tools"] = [tool["function"] for tool in kwargs["tools"]] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if "tool_choice" in kwargs and kwargs["tool_choice"]["type"] == "function": | ||||||||||||||||||||||||||||||||||||||||||||||||||||
kwargs["tool_choice"] = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
"type": "tool", | ||||||||||||||||||||||||||||||||||||||||||||||||||||
"name": kwargs["tool_choice"]["function"]["name"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
payload = self._get_request_payload(messages, stop=stop, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
response = self._call(self.url, payload, self.auth_headers) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# For non-streaming response, yield single chunk | ||||||||||||||||||||||||||||||||||||||||||||||||||||
chat_result = self._create_chat_result(response) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
chunk = ChatGenerationChunk( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
message=AIMessageChunk( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
content=chat_result.generations[0].message.content, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
additional_kwargs=chat_result.generations[0].message.additional_kwargs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
response_metadata=chat_result.generations[0].message.response_metadata, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
usage_metadata=chat_result.generations[0].message.usage_metadata, # type: ignore | ||||||||||||||||||||||||||||||||||||||||||||||||||||
tool_calls=getattr( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
chat_result.generations[0].message, "tool_calls", None | ||||||||||||||||||||||||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
yield chunk | ||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+294
to
+309
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the third instance of duplicated non-streaming logic in streaming methods. The code pattern is repeated across multiple methods with only minor variations. Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
async def _astream( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
messages: List[BaseMessage], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
stop: Optional[List[str]] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
**kwargs: Any, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) -> AsyncIterator[ChatGenerationChunk]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Async stream the LLM on a given prompt. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
messages: the prompt composed of a list of messages. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
stop: a list of strings on which the model should stop generating. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
run_manager: A run manager with callbacks for the LLM. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
**kwargs: Additional keyword arguments. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||
An async iterator of ChatGenerationChunk objects. | ||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if kwargs.get("tools"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||
kwargs["tools"] = [tool["function"] for tool in kwargs["tools"]] | ||||||||||||||||||||||||||||||||||||||||||||||||||||
if "tool_choice" in kwargs and kwargs["tool_choice"]["type"] == "function": | ||||||||||||||||||||||||||||||||||||||||||||||||||||
kwargs["tool_choice"] = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||
"type": "tool", | ||||||||||||||||||||||||||||||||||||||||||||||||||||
"name": kwargs["tool_choice"]["function"]["name"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||
payload = self._get_request_payload(messages, stop=stop, **kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
response = await self._acall(self.url, payload, self.auth_headers) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
# For non-streaming response, yield single chunk | ||||||||||||||||||||||||||||||||||||||||||||||||||||
chat_result = self._create_chat_result(response) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
chunk = ChatGenerationChunk( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
message=AIMessageChunk( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
content=chat_result.generations[0].message.content, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
additional_kwargs=chat_result.generations[0].message.additional_kwargs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
response_metadata=chat_result.generations[0].message.response_metadata, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
usage_metadata=chat_result.generations[0].message.usage_metadata, # type: ignore | ||||||||||||||||||||||||||||||||||||||||||||||||||||
tool_calls=getattr( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
chat_result.generations[0].message, "tool_calls", None | ||||||||||||||||||||||||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||
yield chunk | ||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+337
to
+352
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fourth instance of the same duplicated non-streaming implementation. Consider extracting this chunk creation logic into a shared helper method to reduce code duplication. Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||
def with_structured_output( | ||||||||||||||||||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
schema: Optional[Any] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The streaming implementation returns a single chunk instead of true streaming. Consider implementing actual streaming by making a streaming request to the API or clearly document this as a fallback implementation.
Copilot uses AI. Check for mistakes.