Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions src/transformers/commands/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,26 @@
from threading import Thread
from typing import Optional

import torch
import yaml
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from transformers.utils import is_rich_available, is_torch_available

from . import BaseTransformersCLICommand


if platform.system() != "Windows":
import pwd

if is_rich_available():
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown

if is_torch_available():
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer


HELP_STRING = """\

Expand Down Expand Up @@ -153,7 +159,7 @@ def parse_settings(user_input, current_args, interface):
return current_args, True


def get_quantization_config(model_args) -> Optional[BitsAndBytesConfig]:
def get_quantization_config(model_args) -> Optional["BitsAndBytesConfig"]:
if model_args.load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
Expand Down Expand Up @@ -433,6 +439,11 @@ def __init__(self, args):
self.args = args

def run(self):
if not is_rich_available():
raise ImportError("You need to install rich to use the chat interface. (`pip install rich`)")
if not is_torch_available():
raise ImportError("You need to install torch to use the chat interface. (`pip install torch`)")

args = self.args
if args.examples_path is None:
examples = DEFAULT_EXAMPLES
Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@
is_pytesseract_available,
is_pytest_available,
is_pytorch_quantization_available,
is_rich_available,
is_rjieba_available,
is_sacremoses_available,
is_safetensors_available,
Expand Down
12 changes: 11 additions & 1 deletion src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_gptqmodel_available = _is_package_available("gptqmodel")
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_quanto_available = _is_package_available("quanto")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(unused var)

_is_optimum_quanto_available = False
try:
importlib.metadata.version("optimum_quanto")
Expand Down Expand Up @@ -203,6 +202,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_liger_kernel_available = _is_package_available("liger_kernel")
_triton_available = _is_package_available("triton")
_spqr_available = _is_package_available("spqr_quant")
_rich_available = _is_package_available("rich")

_torch_version = "N/A"
_torch_available = False
Expand Down Expand Up @@ -1300,6 +1300,10 @@ def is_triton_available():
return _triton_available


def is_rich_available():
return _rich_available


# docstyle-ignore
AV_IMPORT_ERROR = """
{0} requires the PyAv library but it was not found in your environment. You can install it with:
Expand Down Expand Up @@ -1659,6 +1663,11 @@ def is_triton_available():
jinja2`. Please note that you may need to restart your runtime after installation.
"""

RICH_IMPORT_ERROR = """
{0} requires the rich library but it was not found in your environment. You can install it with pip: `pip install
rich`. Please note that you may need to restart your runtime after installation.
"""

BACKENDS_MAPPING = OrderedDict(
[
("av", (is_av_available, AV_IMPORT_ERROR)),
Expand Down Expand Up @@ -1705,6 +1714,7 @@ def is_triton_available():
("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
("jinja", (is_jinja_available, JINJA_IMPORT_ERROR)),
("yt_dlp", (is_yt_dlp_available, YT_DLP_IMPORT_ERROR)),
("rich", (is_rich_available, RICH_IMPORT_ERROR)),
]
)

Expand Down