Skip to content

Commit 611ea9a

Browse files
beniericpintaoz-aws
authored andcommitted
Fix: codestyles (#1606)
1 parent 1a3330a commit 611ea9a

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

src/sagemaker/image_uris.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import logging
1818
import os
1919
import re
20-
from typing import Optional, Tuple
20+
from typing import Optional
2121
from packaging.version import Version
2222

2323
from sagemaker import utils
@@ -463,6 +463,7 @@ def _get_latest_versions(list_of_versions):
463463
print("SORT")
464464
return sorted(list_of_versions, reverse=True)[0]
465465

466+
466467
def _get_latest_version(framework, version, image_scope):
467468
"""Get the latest version from the input framework"""
468469
if version:
@@ -479,6 +480,7 @@ def _get_latest_version(framework, version, image_scope):
479480
version = _fetch_latest_version_from_config(framework_config, image_scope)
480481
return version
481482

483+
482484
def _validate_accelerator_type(accelerator_type):
483485
"""Raises a ``ValueError`` if ``accelerator_type`` is invalid."""
484486
if not accelerator_type.startswith("ml.eia") and accelerator_type != "local_sagemaker_notebook":
@@ -748,7 +750,7 @@ def get_base_python_image_uri(region, py_version="310") -> str:
748750
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo_and_tag)
749751

750752

751-
def _fetch_latest_version_from_config(
753+
def _fetch_latest_version_from_config( # pylint: disable=R0911
752754
framework_config: dict, image_scope: Optional[str] = None
753755
) -> Optional[str]:
754756
"""Helper function to fetch the latest version as a string from a framework's config

src/sagemaker/serve/builder/model_builder.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -477,10 +477,9 @@ def _get_client_translators(self):
477477

478478
return serializer, deserializer
479479

480-
def _get_predictor(self,
481-
endpoint_name: str,
482-
sagemaker_session: Session,
483-
component_name: Optional[str] = None) -> Predictor:
480+
def _get_predictor(
481+
self, endpoint_name: str, sagemaker_session: Session, component_name: Optional[str] = None
482+
) -> Predictor:
484483
"""Placeholder docstring"""
485484
serializer, deserializer = self._get_client_translators()
486485

@@ -489,7 +488,7 @@ def _get_predictor(self,
489488
sagemaker_session=sagemaker_session,
490489
serializer=serializer,
491490
deserializer=deserializer,
492-
component_name=component_name
491+
component_name=component_name,
493492
)
494493

495494
def _create_model(self):

tests/unit/sagemaker/image_uris/test_retrieve.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import logging
1717

1818
import pytest
19-
from contextlib import nullcontext
2019
from mock import patch
2120

2221
from sagemaker import image_uris
@@ -715,8 +714,8 @@ def test_retrieve_huggingface(config_for_framework):
715714
container_version="cu110-ubuntu18.04",
716715
)
717716
assert (
718-
"564829616587.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:"
719-
"1.6.0-transformers4.3.1-gpu-py37-cu110-ubuntu18.04" == pt_new_version
717+
"564829616587.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training:"
718+
"1.6.0-transformers4.3.1-gpu-py37-cu110-ubuntu18.04" == pt_new_version
720719
)
721720

722721

@@ -769,6 +768,7 @@ def test_retrieve_with_pipeline_variable():
769768
image_scope="training",
770769
)
771770

771+
772772
@patch("sagemaker.image_uris.config_for_framework")
773773
def test_get_latest_version_function_with_invalid_framework(config_for_framework):
774774
config_for_framework.side_effect = FileNotFoundError
@@ -777,6 +777,7 @@ def test_get_latest_version_function_with_invalid_framework(config_for_framework
777777
image_uris.retrieve("xgboost", "inference")
778778
assert "No framework config for framework" in str(e.exception)
779779

780+
780781
@patch("sagemaker.image_uris.config_for_framework")
781782
def test_get_latest_version_function_with_no_framework(config_for_framework):
782783
config_for_framework.side_effect = {}
@@ -785,6 +786,7 @@ def test_get_latest_version_function_with_no_framework(config_for_framework):
785786
image_uris.retrieve("xgboost", "inference")
786787
assert "No framework config for framework" in str(e.exception)
787788

789+
788790
@pytest.mark.parametrize(
789791
"framework",
790792
[
@@ -854,7 +856,6 @@ def test_get_latest_version_function_with_no_framework(config_for_framework):
854856
"sagemaker-base-python",
855857
],
856858
)
857-
858859
@patch("sagemaker.image_uris.config_for_framework")
859860
@patch("sagemaker.image_uris.retrieve")
860861
def test_retrieve_with_parameterized(mock_image_retrieve, mock_config_for_framework, framework):
@@ -867,4 +868,4 @@ def test_retrieve_with_parameterized(mock_image_retrieve, mock_config_for_framew
867868
image_scope="inference",
868869
)
869870
except ValueError as e:
870-
pytest.fail(e.value)
871+
pytest.fail(e.value)

0 commit comments

Comments
 (0)