Skip to content

Commit f0670f9

Browse files
alex6499catdanabens
authored andcommitted
fix(tracker.py): add base_trial_component_name parameter to create function
1 parent c0a4e71 commit f0670f9

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

src/smexperiments/tracker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def load(
167167
@classmethod
168168
def create(
169169
cls,
170+
base_trial_component_name="TrialComponent",
170171
display_name=None,
171172
artifact_bucket=None,
172173
artifact_prefix=None,
@@ -185,6 +186,7 @@ def create(
185186
my_tracker = tracker.Tracker.create()
186187
187188
Args:
189+
base_trial_component_name: (str,optional). The name of the trial component resource that will be appended with a timestamp. Defaults to "TrialComponent".
188190
display_name: (str, optional). The display name of the trial component to track.
189191
artifact_bucket: (str, optional) The name of the S3 bucket to store artifacts to.
190192
artifact_prefix: (str, optional) The prefix to write artifacts to within ``artifact_bucket``
@@ -201,7 +203,7 @@ def create(
201203
sagemaker_boto_client = sagemaker_boto_client or _utils.sagemaker_client()
202204

203205
tc = trial_component.TrialComponent.create(
204-
trial_component_name=_utils.name("TrialComponent"),
206+
trial_component_name=_utils.name(base_trial_component_name),
205207
display_name=display_name,
206208
sagemaker_boto_client=sagemaker_boto_client,
207209
)

tests/unit/test_tracker.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,22 @@ def test_create(boto3_session, sagemaker_boto_client):
149149
trial_component_display_name = "foo-trial-component-display-name"
150150
sagemaker_boto_client.create_trial_component.return_value = {"TrialComponentName": trial_component_name}
151151
tracker_created = tracker.Tracker.create(
152-
display_name=trial_component_display_name, sagemaker_boto_client=sagemaker_boto_client
152+
base_trial_component_name="AlexName",
153+
display_name=trial_component_display_name,
154+
sagemaker_boto_client=sagemaker_boto_client,
153155
)
154156
assert trial_component_name == tracker_created.trial_component.trial_component_name
155-
157+
sagemaker_boto_client.create_trial_component.assert_called_with(
158+
DisplayName="foo-trial-component-display-name", TrialComponentName=AnyStringWith("AlexName")
159+
)
156160
assert tracker_created._metrics_writer is None
157161

158162

163+
class AnyStringWith(str):
164+
def __eq__(self, other):
165+
return self in other
166+
167+
159168
@pytest.fixture
160169
def trial_component_obj(sagemaker_boto_client):
161170
return trial_component.TrialComponent(sagemaker_boto_client)

tox.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ deps =
5959
pytest
6060
pytest-cov
6161
docker
62+
scikit-learn==0.24.2
6263

6364
[testenv:flake8]
6465
basepython = python3

0 commit comments

Comments
 (0)