15 changes: 9 additions & 6 deletions src/transformers/cli/add_new_model_like.py
@@ -23,12 +23,6 @@

import typer

from ..models.auto.configuration_auto import CONFIG_MAPPING_NAMES, MODEL_NAMES_MAPPING
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING_NAMES
from ..models.auto.processing_auto import PROCESSOR_MAPPING_NAMES
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES
from ..models.auto.video_processing_auto import VIDEO_PROCESSOR_MAPPING_NAMES
from ..utils import is_libcst_available
from .add_fast_image_processor import add_fast_image_processor

@@ -128,6 +122,13 @@ class ModelInfos:
"""

def __init__(self, lowercase_name: str):
from ..models.auto.configuration_auto import CONFIG_MAPPING_NAMES, MODEL_NAMES_MAPPING
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING_NAMES
from ..models.auto.processing_auto import PROCESSOR_MAPPING_NAMES
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES
from ..models.auto.video_processing_auto import VIDEO_PROCESSOR_MAPPING_NAMES

# Just to make sure it's indeed lowercase
self.lowercase_name = lowercase_name.lower().replace(" ", "_").replace("-", "_")
if self.lowercase_name not in CONFIG_MAPPING_NAMES:
@@ -676,6 +677,8 @@ def get_user_input():
"""
Ask the user for the necessary inputs to add the new model.
"""
from transformers.models.auto.configuration_auto import MODEL_NAMES_MAPPING

model_types = list(MODEL_NAMES_MAPPING.keys())

# Get old model type
97 changes: 59 additions & 38 deletions src/transformers/cli/serve.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import asyncio
import base64
import copy
@@ -29,19 +31,15 @@
from contextlib import asynccontextmanager
from io import BytesIO
from threading import Thread
from typing import Annotated, Optional, TypedDict, Union
from typing import TYPE_CHECKING, Annotated, Optional, TypedDict, Union

import typer
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE
from openai.types.chat.chat_completion import Choice
from tokenizers.decoders import DecodeStream

import transformers
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
)
from transformers import BitsAndBytesConfig, GenerationConfig
from transformers.utils.import_utils import (
is_fastapi_available,
is_librosa_available,
@@ -52,26 +50,21 @@
)

from .. import (
AutoConfig,
LogitsProcessorList,
PreTrainedTokenizerFast,
ProcessorMixin,
TextIteratorStreamer,
)
from ..utils import is_torch_available, logging

from ..utils import logging

if is_torch_available():
import torch

if TYPE_CHECKING:
from transformers import (
AutoProcessor,
BitsAndBytesConfig,
GenerationConfig,
PreTrainedModel,
PreTrainedTokenizerFast,
ProcessorMixin,
)

from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus
from ..generation.continuous_batching import ContinuousBatchingManager


if is_librosa_available():
import librosa
@@ -90,6 +83,7 @@
from openai.types.audio.transcription import Transcription
from openai.types.audio.transcription_create_params import TranscriptionCreateParamsBase
from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk,
ChoiceDelta,
@@ -215,6 +209,25 @@ class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total
X_REQUEST_ID = "x-request-id"


def set_torch_seed(_seed):
import torch

torch.manual_seed(_seed)


def reset_torch_cache():
import torch

if torch.cuda.is_available():
torch.cuda.empty_cache()


def torch_ones_like(_input_tensor):
import torch

return torch.ones_like(_input_tensor)


class Modality(enum.Enum):
LLM = "LLM"
VLM = "VLM"
@@ -224,9 +237,9 @@ class Modality(enum.Enum):

def create_generation_config_from_req(
req: dict,
model_generation_config: "GenerationConfig",
model_generation_config: GenerationConfig,
**kwargs,
) -> "GenerationConfig":
) -> GenerationConfig:
"""
Creates a generation config from the parameters of the request. If a generation config is passed in the request,
it will be used as a baseline for parameterization. Otherwise, we will use the model's default generation config.
@@ -276,7 +289,7 @@ def create_generation_config_from_req(
if req.get("top_p") is not None:
generation_config.top_p = float(req["top_p"])
if req.get("seed") is not None:
torch.manual_seed(req["seed"])
set_torch_seed(req["seed"])

return generation_config

@@ -303,9 +316,9 @@ class TimedModel:

def __init__(
self,
model: "PreTrainedModel",
model: PreTrainedModel,
timeout_seconds: int,
processor: Union["ProcessorMixin", "PreTrainedTokenizerFast"] | None = None,
processor: Union[ProcessorMixin, PreTrainedTokenizerFast] | None = None,
):
self.model = model
self._name_or_path = str(model.name_or_path)
@@ -330,8 +343,7 @@ def delete_model(self):
gc.collect()

# Clear CUDA cache if available
if torch.cuda.is_available():
torch.cuda.empty_cache()
reset_torch_cache()

# XXX: in case we manually delete the model, like on server shutdown
self._timer.cancel()
@@ -433,7 +445,7 @@ def __init__(

# Seed
if default_seed is not None:
torch.manual_seed(default_seed)
set_torch_seed(default_seed)

# Set up logging
transformers_logger = logging.get_logger("transformers")
@@ -462,7 +474,7 @@ def __init__(
self.load_model_and_processor(model_id_and_revision)

@asynccontextmanager
async def lifespan(app: "FastAPI"):
async def lifespan(app: FastAPI):
yield
for model in self.loaded_models.values():
model.delete_model()
@@ -576,7 +588,7 @@ def _validate_request(
self,
request: dict,
schema: TypedDict,
validator: "TypeAdapter",
validator: TypeAdapter,
unused_fields: set,
):
"""
@@ -652,7 +664,7 @@ def build_chat_completion_chunk(
model: str | None = None,
role: str | None = None,
finish_reason: str | None = None,
tool_calls: list["ChoiceDeltaToolCall"] | None = None,
tool_calls: list[ChoiceDeltaToolCall] | None = None,
decode_stream: DecodeStream | None = None,
tokenizer: PreTrainedTokenizerFast | None = None,
) -> ChatCompletionChunk:
@@ -816,6 +828,8 @@ def continuous_batching_chat_completion(self, req: dict, request_id: str) -> Str
)["input_ids"][0]

def stream_chat_completion(request_id, decode_stream):
from ..generation.continuous_batching import RequestStatus

try:
# Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit
# they come from the assistant.
@@ -900,7 +914,12 @@ def cancellation_wrapper_buffer(_request_id):
return JSONResponse(json_chunk, media_type="application/json")

@staticmethod
def get_model_modality(model: "PreTrainedModel") -> Modality:
def get_model_modality(model: PreTrainedModel) -> Modality:
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
)

model_classname = model.__class__.__name__
if model_classname in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
modality = Modality.VLM
@@ -1241,7 +1260,7 @@ def generate_response(self, req: dict) -> Generator[str, None, None]:
inputs = inputs.to(model.device)
request_id = req.get("previous_response_id", "req_0")

# Temporary hack for GPTOSS 1: don't filter special tokens
# Temporary hack for GPT-OSS 1: don't filter special tokens
skip_special_tokens = True
if "gptoss" in model.config.architectures[0].lower():
skip_special_tokens = False
@@ -1261,15 +1280,15 @@ def generate_response(self, req: dict) -> Generator[str, None, None]:

generation_kwargs = {
"inputs": inputs,
"attention_mask": torch.ones_like(inputs),
"attention_mask": torch_ones_like(inputs),
"streamer": generation_streamer,
"generation_config": generation_config,
"return_dict_in_generate": True,
"past_key_values": last_kv_cache,
}

def stream_response(streamer, _request_id):
# Temporary hack for GPTOS 2: filter out the CoT tokens. Full solution here implies defining new output
# Temporary hack for GPT-OSS 2: filter out the CoT tokens. Full solution here implies defining new output
Contributor

Suggested change
# Temporary hack for GPT-OSS 2: filter out the CoT tokens. Full solution here implies defining new output
# Temporary hack for GPT-OSS: filter out the CoT tokens. Full solution here implies defining new output

Member Author

It's fix 2 out of 3 for the GPT OSS family of models :)

# classes and piping the reasoning trace into a new field
filter_cot = False
cot_trace_end = None
@@ -1560,7 +1579,7 @@ def generate_response_non_streaming(self, req: dict) -> dict:

generate_output = model.generate(
inputs=inputs,
attention_mask=torch.ones_like(inputs),
attention_mask=torch_ones_like(inputs),
generation_config=generation_config,
return_dict_in_generate=True,
past_key_values=last_kv_cache,
@@ -1674,7 +1693,7 @@ def is_continuation(self, req: dict) -> bool:
self.last_messages = messages
return req_continues_last_messages

def get_quantization_config(self) -> Optional["BitsAndBytesConfig"]:
def get_quantization_config(self) -> Optional[BitsAndBytesConfig]:
"""
Returns the quantization config for the given CLI arguments.

@@ -1729,6 +1748,10 @@ def _load_model_and_data_processor(self, model_id_and_revision: str):
`tuple[PreTrainedModel, Union[ProcessorMixin, PreTrainedTokenizerFast]]`: The loaded model and
data processor (tokenizer, audio processor, etc.).
"""
import torch

from transformers import AutoConfig, AutoProcessor

logger.info(f"Loading {model_id_and_revision}")

if "@" in model_id_and_revision:
@@ -1769,9 +1792,7 @@ def _load_model_and_data_processor(self, model_id_and_revision: str):
logger.info(f"Loaded model {model_id_and_revision}")
return model, data_processor

def load_model_and_processor(
self, model_id_and_revision: str
) -> tuple["PreTrainedModel", PreTrainedTokenizerFast]:
def load_model_and_processor(self, model_id_and_revision: str) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
"""
Loads the text model and processor from the given model ID and revision into the ServeCommand instance.

@@ -1796,7 +1817,7 @@ def load_model_and_processor(

return model, processor

def load_audio_model_and_processor(self, model_id_and_revision: str) -> tuple["PreTrainedModel", ProcessorMixin]:
def load_audio_model_and_processor(self, model_id_and_revision: str) -> tuple[PreTrainedModel, ProcessorMixin]:
"""
Loads the audio model and processor from the given model ID and revision into the ServeCommand instance.

2 changes: 0 additions & 2 deletions src/transformers/cli/transformers.py
@@ -19,7 +19,6 @@
from transformers.cli.add_new_model_like import add_new_model_like
from transformers.cli.chat import Chat, ChatCommand
from transformers.cli.download import download
from transformers.cli.run import run
from transformers.cli.serve import Serve
from transformers.cli.system import env, version

@@ -31,7 +30,6 @@
app.command(name="chat", cls=ChatCommand)(Chat)
app.command()(download)
app.command()(env)
app.command()(run)
Contributor

Is this intentional? If yes, can the run.py be entirely removed since it's not used anymore?

Contributor

And if removed, would be good to mention it in PR description + migration guide.

Apart from that the PR looks good to me (haven't run it locally though)

Member Author

Done! Thanks @Wauplin

app.command(name="serve")(Serve)
app.command()(version)

13 changes: 9 additions & 4 deletions src/transformers/utils/import_utils.py
@@ -33,6 +33,7 @@
from types import ModuleType
from typing import Any

import packaging.version
from packaging import version

from . import logging
@@ -92,10 +93,14 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> tuple[

@lru_cache
def is_torch_available() -> bool:
is_available, torch_version = _is_package_available("torch", return_version=True)
if is_available and version.parse(torch_version) < version.parse("2.2.0"):
logger.warning_once(f"Disabling PyTorch because PyTorch >= 2.2 is required but found {torch_version}")
return is_available and version.parse(torch_version) >= version.parse("2.2.0")
try:
is_available, torch_version = _is_package_available("torch", return_version=True)
parsed_version = version.parse(torch_version)
if is_available and parsed_version < version.parse("2.2.0"):
logger.warning_once(f"Disabling PyTorch because PyTorch >= 2.2 is required but found {torch_version}")
return is_available and version.parse(torch_version) >= version.parse("2.2.0")
except packaging.version.InvalidVersion:
return False
Comment on lines +96 to +103
Member Author

Happy to revert this; I ran into an edge case when uninstalling torch where it would still show up as available yet report 'N/A' as its version.

This is a bit out of scope of that PR so happy to revert it
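To make the edge case above concrete, here is a minimal, self-contained sketch (illustrative only; `is_torch_usable` is a made-up helper, not the actual transformers internals). A half-uninstalled torch can report a placeholder version string such as `"N/A"`, which `packaging` refuses to parse; catching `InvalidVersion` turns that into "not available" instead of a crash:

```python
# Hypothetical helper mirroring the guarded version check above (not transformers code).
from packaging import version
from packaging.version import InvalidVersion


def is_torch_usable(torch_version: str) -> bool:
    try:
        # Same threshold as is_torch_available(): require torch >= 2.2.0
        return version.parse(torch_version) >= version.parse("2.2.0")
    except InvalidVersion:
        # e.g. a half-uninstalled torch reporting "N/A" as its version
        return False


print(is_torch_usable("2.4.1"))  # True
print(is_torch_usable("N/A"))    # False instead of raising InvalidVersion
```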

Contributor

just saw that, isn't that caused by the lru_cache maybe?
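
For reference, a small sketch of the caching behaviour being asked about (an assumed, illustrative example; not the real `_is_package_available`): because the check is decorated with `@lru_cache`, the first answer is memoized per package name, so uninstalling the package afterwards is not reflected until the cache is cleared.

```python
# Illustrative only: how lru_cache memoizes an availability check.
import importlib.util
from functools import lru_cache


@lru_cache
def is_pkg_available(pkg_name: str) -> bool:
    # Computed on the first call for each pkg_name, then served from the cache.
    return importlib.util.find_spec(pkg_name) is not None


print(is_pkg_available("torch"))  # fresh check
print(is_pkg_available("torch"))  # cached result, even if torch was removed in between
is_pkg_available.cache_clear()    # forces a fresh check on the next call
```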



@lru_cache