Skip to content

Commit 57010a3

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
chore: Update custom job psci sample to use python aiplatform SDK
PiperOrigin-RevId: 776332474
1 parent 0c612fb commit 57010a3

File tree

3 files changed

+89
-88
lines changed

3 files changed

+89
-88
lines changed

samples/model-builder/create_custom_job_psci_sample.py

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
# [START aiplatform_sdk_create_custom_job_psci_sample]
1616
from google.cloud import aiplatform
17-
from google.cloud import aiplatform_v1beta1
1817

1918

2019
def create_custom_job_psci_sample(
@@ -26,40 +25,40 @@ def create_custom_job_psci_sample(
2625
replica_count: int,
2726
image_uri: str,
2827
network_attachment: str,
28+
domain: str,
29+
target_project: str,
30+
target_network: str,
2931
):
30-
"""Custom training job sample with PSC-I through aiplatform_v1beta1."""
31-
aiplatform.init(project=project, location=location, staging_bucket=bucket)
32+
"""Custom training job sample with PSC Interface Config."""
33+
aiplatform.init(project=project, location=location, staging_bucket=bucket)
3234

33-
client_options = {"api_endpoint": f"{location}-aiplatform.googleapis.com"}
35+
worker_pool_specs = [{
36+
"machine_spec": {
37+
"machine_type": machine_type,
38+
},
39+
"replica_count": replica_count,
40+
"container_spec": {
41+
"image_uri": image_uri,
42+
"command": [],
43+
"args": [],
44+
},
45+
}]
46+
psc_interface_config = {
47+
"network_attachment": network_attachment,
48+
"dns_peering_configs": [
49+
{
50+
"domain": domain,
51+
"target_project": target_project,
52+
"target_network": target_network,
53+
},
54+
],
55+
}
56+
job = aiplatform.CustomJob(
57+
display_name=display_name,
58+
worker_pool_specs=worker_pool_specs,
59+
)
3460

35-
client = aiplatform_v1beta1.JobServiceClient(client_options=client_options)
36-
37-
request = aiplatform_v1beta1.CreateCustomJobRequest(
38-
parent=f"projects/{project}/locations/{location}",
39-
custom_job=aiplatform_v1beta1.CustomJob(
40-
display_name=display_name,
41-
job_spec=aiplatform_v1beta1.CustomJobSpec(
42-
worker_pool_specs=[
43-
aiplatform_v1beta1.WorkerPoolSpec(
44-
machine_spec=aiplatform_v1beta1.MachineSpec(
45-
machine_type=machine_type,
46-
),
47-
replica_count=replica_count,
48-
container_spec=aiplatform_v1beta1.ContainerSpec(
49-
image_uri=image_uri,
50-
),
51-
)
52-
],
53-
psc_interface_config=aiplatform_v1beta1.PscInterfaceConfig(
54-
network_attachment=network_attachment,
55-
),
56-
),
57-
),
58-
)
59-
60-
response = client.create_custom_job(request=request)
61-
62-
return response
61+
job.run(psc_interface_config=psc_interface_config)
6362

6463

6564
# [END aiplatform_sdk_create_custom_job_psci_sample]

samples/model-builder/create_custom_job_psci_sample_test.py

Lines changed: 31 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -13,65 +13,40 @@
1313
# limitations under the License.
1414

1515
import create_custom_job_psci_sample
16-
from google.cloud import aiplatform_v1beta1
1716
import test_constants as constants
1817

1918

2019
def test_create_custom_job_psci_sample(
2120
mock_sdk_init,
22-
mock_get_job_service_client_v1beta1,
23-
mock_get_create_custom_job_request_v1beta1,
24-
mock_create_custom_job_v1beta1,
21+
mock_get_custom_job,
22+
mock_run_custom_job,
2523
):
26-
"""Custom training job sample with PSC-I through aiplatform_v1beta1."""
27-
create_custom_job_psci_sample.create_custom_job_psci_sample(
28-
project=constants.PROJECT,
29-
location=constants.LOCATION,
30-
bucket=constants.STAGING_BUCKET,
31-
display_name=constants.DISPLAY_NAME,
32-
machine_type=constants.MACHINE_TYPE,
33-
replica_count=1,
34-
image_uri=constants.TRAIN_IMAGE,
35-
network_attachment=constants.NETWORK_ATTACHMENT_NAME,
36-
)
37-
38-
mock_sdk_init.assert_called_once_with(
39-
project=constants.PROJECT,
40-
location=constants.LOCATION,
41-
staging_bucket=constants.STAGING_BUCKET,
42-
)
43-
44-
mock_get_job_service_client_v1beta1.assert_called_once_with(
45-
client_options={
46-
"api_endpoint": f"{constants.LOCATION}-aiplatform.googleapis.com"
47-
}
48-
)
49-
50-
mock_get_create_custom_job_request_v1beta1.assert_called_once_with(
51-
parent=f"projects/{constants.PROJECT}/locations/{constants.LOCATION}",
52-
custom_job=aiplatform_v1beta1.CustomJob(
53-
display_name=constants.DISPLAY_NAME,
54-
job_spec=aiplatform_v1beta1.CustomJobSpec(
55-
worker_pool_specs=[
56-
aiplatform_v1beta1.WorkerPoolSpec(
57-
machine_spec=aiplatform_v1beta1.MachineSpec(
58-
machine_type=constants.MACHINE_TYPE,
59-
),
60-
replica_count=constants.REPLICA_COUNT,
61-
container_spec=aiplatform_v1beta1.ContainerSpec(
62-
image_uri=constants.TRAIN_IMAGE,
63-
),
64-
)
65-
],
66-
psc_interface_config=aiplatform_v1beta1.PscInterfaceConfig(
67-
network_attachment=constants.NETWORK_ATTACHMENT_NAME,
68-
),
69-
),
70-
),
71-
)
72-
73-
request = aiplatform_v1beta1.CreateCustomJobRequest(
74-
mock_get_create_custom_job_request_v1beta1.return_value
75-
)
76-
77-
mock_create_custom_job_v1beta1.assert_called_once_with(request=request)
24+
"""Custom training job sample with PSC-I through aiplatform_v1beta1."""
25+
create_custom_job_psci_sample.create_custom_job_psci_sample(
26+
project=constants.PROJECT,
27+
location=constants.LOCATION,
28+
bucket=constants.STAGING_BUCKET,
29+
display_name=constants.DISPLAY_NAME,
30+
machine_type=constants.MACHINE_TYPE,
31+
replica_count=1,
32+
image_uri=constants.CONTAINER_URI,
33+
network_attachment=constants.NETWORK_ATTACHMENT_NAME,
34+
domain=constants.DOMAIN,
35+
target_project=constants.TARGET_PROJECT,
36+
target_network=constants.TARGET_NETWORK,
37+
)
38+
39+
mock_sdk_init.assert_called_once_with(
40+
project=constants.PROJECT,
41+
location=constants.LOCATION,
42+
staging_bucket=constants.STAGING_BUCKET,
43+
)
44+
45+
mock_get_custom_job.assert_called_once_with(
46+
display_name=constants.DISPLAY_NAME,
47+
worker_pool_specs=constants.CUSTOM_JOB_WORKER_POOL_SPECS_WITHOUT_ACCELERATOR,
48+
)
49+
50+
mock_run_custom_job.assert_called_once_with(
51+
psc_interface_config=constants.PSC_INTERFACE_CONFIG,
52+
)

samples/model-builder/test_constants.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,19 @@
116116
ACCELERATOR_TYPE = "ACCELERATOR_TYPE_UNSPECIFIED"
117117
ACCELERATOR_COUNT = 0
118118
NETWORK_ATTACHMENT_NAME = "network-attachment-name"
119+
DOMAIN = "test.com"
120+
TARGET_PROJECT = "target-project"
121+
TARGET_NETWORK = "target-network"
122+
PSC_INTERFACE_CONFIG = {
123+
"network_attachment": NETWORK_ATTACHMENT_NAME,
124+
"dns_peering_configs": [
125+
{
126+
"domain": DOMAIN,
127+
"target_project": TARGET_PROJECT,
128+
"target_network": TARGET_NETWORK,
129+
},
130+
],
131+
}
119132

120133
# Model constants
121134
MODEL_RESOURCE_NAME = f"{PARENT}/models/1234"
@@ -387,6 +400,20 @@
387400
}
388401
]
389402

403+
CUSTOM_JOB_WORKER_POOL_SPECS_WITHOUT_ACCELERATOR = [
404+
{
405+
"machine_spec": {
406+
"machine_type": "n1-standard-4",
407+
},
408+
"replica_count": 1,
409+
"container_spec": {
410+
"image_uri": CONTAINER_URI,
411+
"command": [],
412+
"args": [],
413+
},
414+
}
415+
]
416+
390417
VERSION_ID = "test-version"
391418
IS_DEFAULT_VERSION = False
392419
VERSION_ALIASES = ["test-version-alias"]

0 commit comments

Comments
 (0)