@@ -44,14 +44,16 @@
 )
 from transformers.processing_utils import Unpack
 from transformers.utils import TransformersKwargs, auto_docstring, logging
-from transformers.utils.generic import check_model_inputs
 
 try:
     from nemo_automodel.components.models.biencoder.state_dict_adapter import BiencoderStateDictAdapter
 except ImportError:
     BiencoderStateDictAdapter = object
 
+from nemo_automodel.shared.import_utils import get_check_model_inputs_decorator
+
 logger = logging.get_logger(__name__)
+check_model_inputs = get_check_model_inputs_decorator()
 
 
 def contrastive_scores_and_labels(
@@ -177,7 +179,7 @@ def _update_causal_mask(
             return attention_mask
         return None
 
-    @check_model_inputs()
+    @check_model_inputs
     @auto_docstring
     def forward(
         self,
nemo_automodel/components/models/llama/model.py: 6 changes (4 additions & 2 deletions)
@@ -50,18 +50,20 @@
 )
 from transformers.processing_utils import Unpack
 from transformers.utils import TransformersKwargs, can_return_tuple
-from transformers.utils.generic import check_model_inputs
 
 from nemo_automodel.components.models.common.combined_projection import (
     CombinedGateUpMLP,
     CombinedQKVAttentionMixin,
 )
 from nemo_automodel.components.models.llama.state_dict_adapter import LlamaStateDictAdapter
 from nemo_automodel.components.moe.utils import BackendConfig
+from nemo_automodel.shared.import_utils import get_check_model_inputs_decorator
 from nemo_automodel.shared.utils import dtype_from_str
 
 __all__ = ["build_llama_model", "LlamaForCausalLM"]
 
+check_model_inputs = get_check_model_inputs_decorator()
+
 
 class LlamaAttention(CombinedQKVAttentionMixin, nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper with combined QKV projection."""
@@ -279,7 +281,7 @@ def __init__(
         # Initialize weights and apply final processing
         self.post_init()
 
-    @check_model_inputs()
+    @check_model_inputs
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
nemo_automodel/components/models/qwen2/model.py: 6 changes (4 additions & 2 deletions)
@@ -50,18 +50,20 @@
 )
 from transformers.processing_utils import Unpack
 from transformers.utils import TransformersKwargs, can_return_tuple
-from transformers.utils.generic import check_model_inputs
 
 from nemo_automodel.components.models.common.combined_projection import (
     CombinedGateUpMLP,
     CombinedQKVAttentionMixin,
 )
 from nemo_automodel.components.models.qwen2.state_dict_adapter import Qwen2StateDictAdapter
 from nemo_automodel.components.moe.utils import BackendConfig
+from nemo_automodel.shared.import_utils import get_check_model_inputs_decorator
 from nemo_automodel.shared.utils import dtype_from_str
 
 __all__ = ["build_qwen2_model", "Qwen2ForCausalLM"]
 
+check_model_inputs = get_check_model_inputs_decorator()
+
 
 class Qwen2Attention(CombinedQKVAttentionMixin, nn.Module):
     """Multi-headed attention with combined QKV projection.
@@ -252,7 +254,7 @@ def __init__(
         # Initialize weights and apply final processing
         self.post_init()
 
-    @check_model_inputs()
+    @check_model_inputs
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
nemo_automodel/shared/import_utils.py: 47 changes (47 additions & 0 deletions)
@@ -419,3 +419,50 @@ def is_te_min_version(version, check_equality=True):
     if check_equality:
         return get_te_version() >= PkgVersion(version)
     return get_te_version() > PkgVersion(version)
+
+
+def get_transformers_version():
+    """Get transformers version from __version__."""
+    try:
+        import transformers
+
+        if hasattr(transformers, "__version__"):
+            _version = str(transformers.__version__)
+        else:
+            from importlib.metadata import version
+
+            _version = version("transformers")
+    except ImportError:
+        _version = "0.0.0"
+    return PkgVersion(_version)
+
+
+def is_transformers_min_version(version, check_equality=True):
+    """Check if minimum version of `transformers` is installed."""
+    if check_equality:
+        return get_transformers_version() >= PkgVersion(version)
+    return get_transformers_version() > PkgVersion(version)
+
+
+def get_check_model_inputs_decorator():
+    """
+    Get the appropriate check_model_inputs decorator based on transformers version.
+
+    In transformers >= 4.57.3, check_model_inputs became a function that returns a decorator.
+    In older versions, it was directly a decorator.
+
+    Returns:
+        Decorator function to validate model inputs.
+    """
+    try:
+        from transformers.utils.generic import check_model_inputs
+
+        if is_transformers_min_version("4.57.3"):
+            # New API: check_model_inputs() returns a decorator
+            return check_model_inputs()
+        else:
+            # Old API: check_model_inputs is directly a decorator
+            return check_model_inputs
+    except ImportError:
+        # If transformers is not available, return a no-op decorator
+        return null_decorator
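
Note: `null_decorator` is defined earlier in nemo_automodel/shared/import_utils.py and is not touched by this PR. For readers of the diff, a minimal no-op decorator along these lines (an illustrative sketch, not the file's actual definition) is all the fallback path requires:

def null_decorator(func):
    # No-op fallback: hand the wrapped function back unchanged, so that
    # @check_model_inputs still applies cleanly when transformers is absent.
    return func

Because get_check_model_inputs_decorator() always returns a bare decorator, the model files above can uniformly write @check_model_inputs instead of branching on the transformers version at every call site.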
tests/unit_tests/shared/test_import_utils.py: 26 changes (26 additions & 0 deletions)
@@ -174,3 +174,29 @@ def test_is_te_min_version():
     """
     assert si.is_te_min_version("0.0.0") is True
     assert si.is_te_min_version("9999.0.0", check_equality=False) is False
+
+
+def test_get_transformers_version_type():
+    """
+    ``get_transformers_version`` should *never* raise – even when transformers is unavailable
+    while building docs – and must always return a ``packaging.version.Version``.
+    """
+    ver = si.get_transformers_version()
+    assert isinstance(ver, PkgVersion)
+
+
+def test_is_transformers_min_version():
+    """
+    * A ridiculously low requirement must be satisfied.
+    * A far-future version must *not* be satisfied.
+    """
+    assert si.is_transformers_min_version("0.0.0") is True
+    assert si.is_transformers_min_version("9999.0.0", check_equality=False) is False
+
+
+def test_get_check_model_inputs_decorator():
+    """
+    ``get_check_model_inputs_decorator`` should always return a callable decorator.
+    """
+    decorator = si.get_check_model_inputs_decorator()
+    assert callable(decorator)
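
A natural extension (not part of this PR) would pin the reported version so both comparison branches are exercised deterministically; a sketch using pytest's monkeypatch fixture and the module alias `si` from this test file:

def test_is_transformers_min_version_branches(monkeypatch):
    # Pin the reported version so the >= and > branches are both hit,
    # regardless of which transformers release is actually installed.
    monkeypatch.setattr(si, "get_transformers_version", lambda: PkgVersion("4.57.3"))
    assert si.is_transformers_min_version("4.57.3") is True
    assert si.is_transformers_min_version("4.57.3", check_equality=False) is False

PkgVersion here is the packaging.version.Version alias the existing tests already import.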