diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 7bee7c332f..ac51f5a37a 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1722,6 +1722,20 @@ def create_tuning_job( LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4)) self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request) + def describe_tuning_job(self, job_name): + """Calls the DescribeHyperParameterTuningJob API for the given job name + and returns the response. + + Args: + job_name (str): The name of the hyperparameter tuning job to describe. + + Returns: + dict: A dictionary response with the hyperparameter tuning job description. + """ + return self.sagemaker_client.describe_hyper_parameter_tuning_job( + HyperParameterTuningJobName=job_name + ) + @classmethod def _map_tuning_config( cls, diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 7a74a95e24..b7fc939773 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -750,6 +750,10 @@ def stop_tuning_job(self): self._ensure_last_tuning_job() self.latest_tuning_job.stop() + def describe(self): + """Returns a response from the DescribeHyperParameterTuningJob API call.""" + return self.sagemaker_session.describe_tuning_job(self._current_job_name) + def wait(self): """Wait for latest hyperparameter tuning job to finish.""" self._ensure_last_tuning_job() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 5aaa3f0af0..103a74af54 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2125,3 +2125,11 @@ def test_list_candidates_for_auto_ml_job_with_optional_args(sagemaker_session): sagemaker_session.sagemaker_client.list_candidates_for_auto_ml_job.assert_called_with( **COMPLETE_EXPECTED_LIST_CANDIDATES_ARGS ) + + +def test_describe_tuning_Job(sagemaker_session): + job_name = "hyper-parameter-tuning" + sagemaker_session.describe_tuning_job(job_name=job_name) + sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job.assert_called_with( + HyperParameterTuningJobName=job_name + ) diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index 67de14bc50..344a8ae88e 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -1423,6 +1423,11 @@ def test_create_warm_start_tuner_with_single_estimator_dict( assert tuner.warm_start_config.parents == additional_parents +def test_describe(tuner): + tuner.describe() + tuner.sagemaker_session.describe_tuning_job.assert_called_once() + + def _convert_tuning_job_details(job_details, estimator_name): """Convert a tuning job description using the 'TrainingJobDefinition' field into a new one using a single-item 'TrainingJobDefinitions' field (list).