huggingface: add astream for huggingface pipeline #31575

Open · wants to merge 2 commits into base: master

@@ -684,36 +684,50 @@ async def _astream(
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        params = {**params, **kwargs, "stream": True}
-
-        default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
-
-        async for chunk in await self.llm.async_client.chat_completion(
-            messages=message_dicts, **params
-        ):
-            if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-                generation_info["model_name"] = self.model_id
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = message_chunk.__class__
-            generation_chunk = ChatGenerationChunk(
-                message=message_chunk, generation_info=generation_info or None
-            )
-            if run_manager:
-                await run_manager.on_llm_new_token(
-                    token=generation_chunk.text,
-                    chunk=generation_chunk,
-                    logprobs=logprobs,
-                )
-            yield generation_chunk
+        if _is_huggingface_endpoint(self.llm):
+            message_dicts, params = self._create_message_dicts(messages, stop)
+            params = {**params, **kwargs, "stream": True}
+
+            default_chunk_class: type[BaseMessageChunk] = AIMessageChunk
+
+            async for chunk in await self.llm.async_client.chat_completion(
+                messages=message_dicts, **params
+            ):
+                if len(chunk["choices"]) == 0:
+                    continue
+                choice = chunk["choices"][0]
+                message_chunk = _convert_chunk_to_message_chunk(
+                    chunk, default_chunk_class
+                )
+                generation_info = {}
+                if finish_reason := choice.get("finish_reason"):
+                    generation_info["finish_reason"] = finish_reason
+                    generation_info["model_name"] = self.model_id
+                logprobs = choice.get("logprobs")
+                if logprobs:
+                    generation_info["logprobs"] = logprobs
+                default_chunk_class = message_chunk.__class__
+                generation_chunk = ChatGenerationChunk(
+                    message=message_chunk, generation_info=generation_info or None
+                )
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        token=generation_chunk.text,
+                        chunk=generation_chunk,
+                        logprobs=logprobs,
+                    )
+                yield generation_chunk
+        else:
+            llm_input = self._to_chat_prompt(messages)
+            stream_iter = self.llm._astream(
+                llm_input, stop=stop, run_manager=run_manager, **kwargs
+            )
+            async for chunk in stream_iter:  # chunk is a GenerationChunk
+                chat_chunk = ChatGenerationChunk(
+                    message=AIMessageChunk(content=chunk.text),
+                    generation_info=chunk.generation_info,
+                )
+                yield chat_chunk
 
     def _to_chat_prompt(
         self,
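
For reviewers trying the change locally, here is a minimal usage sketch of the new non-endpoint branch (not part of the diff): ChatHuggingFace wrapping a local HuggingFacePipeline, with an illustrative model id and generation kwargs.

import asyncio

from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline


async def main() -> None:
    # Illustrative model id; any local model with a chat template should work.
    llm = HuggingFacePipeline.from_model_id(
        model_id="HuggingFaceH4/zephyr-7b-beta",
        task="text-generation",
        pipeline_kwargs={"max_new_tokens": 32},
    )
    chat = ChatHuggingFace(llm=llm)
    # With this PR, astream on a pipeline-backed ChatHuggingFace takes the
    # else branch above and yields real chunks instead of requiring a
    # HuggingFaceEndpoint.
    async for chunk in chat.astream("How do you say 'hello' in German?"):
        print(chunk.content, end="", flush=True)


asyncio.run(main())
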
@@ -2,10 +2,13 @@
 
 import importlib.util
 import logging
-from collections.abc import Iterator, Mapping
+from collections.abc import AsyncIterator, Iterator, Mapping
 from typing import Any, Optional
 
-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.llms import BaseLLM
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from pydantic import ConfigDict, model_validator
@@ -403,3 +406,62 @@ def __call__(
             run_manager.on_llm_new_token(chunk.text, chunk=chunk)
 
         yield chunk
+
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[list[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        from threading import Thread
+
+        import torch
+        from transformers import (
+            AsyncTextIteratorStreamer,
+            StoppingCriteria,
+            StoppingCriteriaList,
+        )
+
+        pipeline_kwargs = kwargs.get("pipeline_kwargs", {})
+        skip_prompt = kwargs.get("skip_prompt", True)
+
+        if stop is not None:
+            stop = self.pipeline.tokenizer.convert_tokens_to_ids(stop)
+        stopping_ids_list = stop or []
+
+        class StopOnTokens(StoppingCriteria):
+            def __call__(
+                self,
+                input_ids: torch.LongTensor,
+                scores: torch.FloatTensor,
+                **kwargs: Any,
+            ) -> bool:
+                for stop_id in stopping_ids_list:
+                    if input_ids[0][-1] == stop_id:
+                        return True
+                return False
+
+        stopping_criteria = StoppingCriteriaList([StopOnTokens()])
+
+        async_streamer = AsyncTextIteratorStreamer(
+            self.pipeline.tokenizer,
+            timeout=60.0,
+            skip_prompt=skip_prompt,
+            skip_special_tokens=True,
+        )
+        generation_kwargs = dict(
+            text_inputs=prompt,
+            streamer=async_streamer,
+            stopping_criteria=stopping_criteria,
+            **pipeline_kwargs,
+        )
+        t1 = Thread(target=self.pipeline, kwargs=generation_kwargs)
+        t1.start()
+
+        async for char in async_streamer:
+            chunk = GenerationChunk(text=char)
+            if run_manager:
+                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+
+            yield chunk
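
For context, the new _astream follows the standard transformers pattern for asynchronous token streaming: the blocking pipeline call runs on a worker thread while an AsyncTextIteratorStreamer hands decoded tokens back to the event loop. Below is a standalone sketch of that pattern, independent of LangChain (model id and generation parameters are illustrative):

import asyncio
from threading import Thread

from transformers import AsyncTextIteratorStreamer, pipeline


async def demo() -> None:
    pipe = pipeline("text-generation", model="openai-community/gpt2")
    # The streamer is created inside a running event loop so it can schedule
    # token delivery back onto that loop.
    streamer = AsyncTextIteratorStreamer(
        pipe.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    # The pipeline blocks while generating, so it runs on a worker thread;
    # the streamer yields decoded text pieces as they are produced.
    thread = Thread(
        target=pipe,
        kwargs={
            "text_inputs": "Hello, my name is",
            "streamer": streamer,
            "max_new_tokens": 10,
        },
    )
    thread.start()
    async for token in streamer:
        print(token, end="", flush=True)
    thread.join()


asyncio.run(demo())
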
19 changes: 18 additions & 1 deletion libs/partners/huggingface/tests/integration_tests/test_llms.py
@@ -1,4 +1,4 @@
-from collections.abc import Generator
+from collections.abc import AsyncGenerator, Generator
 
 from langchain_huggingface.llms import HuggingFacePipeline
 
@@ -18,3 +18,20 @@ def test_huggingface_pipeline_streaming() -> None:
         assert isinstance(chunk, str)
         stream_results_string = chunk
     assert len(stream_results_string.strip()) > 0
+
+
+async def test_huggingface_pipeline_astreaming() -> None:
+    """Test streaming tokens from huggingface_pipeline using astream."""
+    llm = HuggingFacePipeline.from_model_id(
+        model_id="openai-community/gpt2",
+        task="text-generation",
+        pipeline_kwargs={"max_new_tokens": 10},
+    )
+    agenerator = llm.astream("Q: How do you say 'hello' in German? A:'", stop=["."])
+    stream_results_string = ""
+    assert isinstance(agenerator, AsyncGenerator)
+
+    async for chunk in agenerator:
+        assert isinstance(chunk, str)
+        stream_results_string += chunk
+    assert len(stream_results_string.strip()) > 0
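
Outside of pytest, the same astream API can be exercised directly; a minimal interactive sketch mirroring the test above (model id, prompt, and kwargs are illustrative):

import asyncio

from langchain_huggingface import HuggingFacePipeline


async def main() -> None:
    llm = HuggingFacePipeline.from_model_id(
        model_id="openai-community/gpt2",
        task="text-generation",
        pipeline_kwargs={"max_new_tokens": 10},
    )
    # Each yielded chunk is a piece of decoded text, as asserted in the test.
    prompt = "Q: How do you say 'hello' in German? A:"
    async for token in llm.astream(prompt, stop=["."]):
        print(token, end="", flush=True)


asyncio.run(main())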