Skip to content

Commit 9035c44

Browse files
committed
lint
1 parent cacb977 commit 9035c44

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

src/sagemaker/jumpstart/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1680,7 +1680,7 @@ def get_model_access_config(accept_eula: Optional[bool], environment: Optional[d
16801680

16811681
model_access_config = None
16821682
if env_var_eula is not None:
1683-
model_access_config = {"AcceptEula": True if env_var_eula == "true" else False}
1683+
model_access_config = {"AcceptEula": env_var_eula == "true"}
16841684
if accept_eula is not None:
16851685
model_access_config = {"AcceptEula": accept_eula}
16861686

src/test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from sagemaker.jumpstart.hub.hub import Hub
2+
from sagemaker import hyperparameters
3+
from sagemaker.session import Session
4+
from sagemaker.jumpstart.estimator import JumpStartEstimator
5+
6+
7+
hub = Hub(hub_name="temp-bencrab-hub", sagemaker_session=Session())
8+
9+
# hub.create(description="hello haha")
10+
11+
model_id = "meta-vlm-llama-3-2-11b-vision"
12+
model_version = "*"
13+
hub_arn = hub.hub_name
14+
15+
my_hyperparameters = hyperparameters.retrieve_default(
16+
model_id=model_id, model_version=model_version, hub_arn=hub_arn
17+
)
18+
print(my_hyperparameters)
19+
hyperparameters.validate(
20+
model_id=model_id,
21+
model_version=model_version,
22+
hyperparameters=my_hyperparameters,
23+
hub_arn=hub_arn,
24+
)
25+
estimator = JumpStartEstimator(
26+
model_id=model_id,
27+
hub_name=hub_arn,
28+
model_version=model_version,
29+
environment={"accept_eula": "true"}, # Please change {"accept_eula": "true"}
30+
disable_output_compression=True,
31+
instance_type="ml.p4d.24xlarge",
32+
hyperparameters=my_hyperparameters,
33+
)
34+
estimator.fit(
35+
{"training": "s3://jumpstart-cache-prod-us-west-2/training-datasets/docVQA-small-3000ex/"}
36+
)

0 commit comments

Comments
 (0)