/*
 * Copyright 2020 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package aiplatform;

// [START aiplatform_create_training_pipeline_sample]

import com.google.cloud.aiplatform.v1beta1.DeployedModelRef;
import com.google.cloud.aiplatform.v1beta1.EnvVar;
import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata;
import com.google.cloud.aiplatform.v1beta1.ExplanationParameters;
import com.google.cloud.aiplatform.v1beta1.ExplanationSpec;
import com.google.cloud.aiplatform.v1beta1.FilterSplit;
import com.google.cloud.aiplatform.v1beta1.FractionSplit;
import com.google.cloud.aiplatform.v1beta1.InputDataConfig;
import com.google.cloud.aiplatform.v1beta1.LocationName;
import com.google.cloud.aiplatform.v1beta1.Model;
import com.google.cloud.aiplatform.v1beta1.Model.ExportFormat;
import com.google.cloud.aiplatform.v1beta1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1beta1.Port;
import com.google.cloud.aiplatform.v1beta1.PredefinedSplit;
import com.google.cloud.aiplatform.v1beta1.PredictSchemata;
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
import com.google.protobuf.Any;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import com.google.rpc.Status;
import java.io.IOException;
import java.util.List;

public class CreateTrainingPipelineSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String trainingTaskDefinition = "YOUR_TRAINING_TASK_DEFINITION";
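    // Note (assumption, for illustration): the training task definition is a schema URI. For an
    // AutoML image classification job, which is what the trainingTaskInputs JSON below matches,
    // this is typically
    // "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml".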
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
    createTrainingPipelineSample(
        project, trainingPipelineDisplayName, datasetId, trainingTaskDefinition, modelDisplayName);
  }

  static void createTrainingPipelineSample(
      String project,
      String trainingPipelineDisplayName,
      String datasetId,
      String trainingTaskDefinition,
      String modelDisplayName)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();
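    // The regional endpoint above must match the location used for the request ("us-central1"
    // below); change both together to target a different region.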

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      LocationName locationName = LocationName.of(project, location);

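      // The training task inputs must conform to the schema referenced by trainingTaskDefinition.
      // This example JSON matches AutoML image classification inputs: a single-label,
      // cloud-hosted model, an 8 node-hour training budget, and early stopping left enabled.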
      String jsonString =
          "{\"multiLabel\": false, \"modelType\": \"CLOUD\", \"budgetMilliNodeHours\": 8000,"
              + " \"disableEarlyStopping\": false}";
      Value.Builder trainingTaskInputs = Value.newBuilder();
      JsonFormat.parser().merge(jsonString, trainingTaskInputs);

      InputDataConfig trainingInputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
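      // Only the dataset ID is set here, so the service falls back to its default data split
      // when dividing the dataset into training, validation, and test sets.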
      Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(trainingPipelineDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(trainingTaskInputs)
              .setInputDataConfig(trainingInputDataConfig)
              .setModelToUpload(model)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);
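      // createTrainingPipeline returns the new pipeline resource immediately; training itself
      // runs asynchronously. Poll pipelineServiceClient.getTrainingPipeline (or check getState())
      // to track progress; many of the fields printed below are only populated as training runs.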

      System.out.println("Create Training Pipeline Response");
      System.out.format("Name: %s\n", trainingPipelineResponse.getName());
      System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName());

      System.out.format(
          "Training Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
      System.out.format("State: %s\n", trainingPipelineResponse.getState());

      System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("Start Time: %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
      System.out.println("Input Data Config");
      System.out.format("Dataset Id: %s\n", inputDataConfig.getDatasetId());
      System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
      System.out.println("Fraction Split");
      System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfig.getFilterSplit();
      System.out.println("Filter Split");
      System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("Test Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
      System.out.println("Predefined Split");
      System.out.format("Key: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
      System.out.println("Timestamp Split");
      System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("Key: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("Model To Upload");
      System.out.format("Name: %s\n", modelResponse.getName());
      System.out.format("Display Name: %s\n", modelResponse.getDisplayName());
      System.out.format("Description: %s\n", modelResponse.getDescription());

      System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("Metadata: %s\n", modelResponse.getMetadata());
      System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "Supported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList());
      System.out.format(
          "Supported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList());
      System.out.format(
          "Supported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList());

      System.out.format("Create Time: %s\n", modelResponse.getCreateTime());
      System.out.format("Update Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("Labels: %s\n", modelResponse.getLabelsMap());

      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
      System.out.println("Predict Schemata");
      System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
        System.out.println("Supported Export Format");
        System.out.format("Id: %s\n", exportFormat.getId());
      }

      ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
      System.out.println("Container Spec");
      System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri());
      System.out.format("Command: %s\n", modelContainerSpec.getCommandList());
      System.out.format("Args: %s\n", modelContainerSpec.getArgsList());
      System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute());
      System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute());

      for (EnvVar envVar : modelContainerSpec.getEnvList()) {
        System.out.println("Env");
        System.out.format("Name: %s\n", envVar.getName());
        System.out.format("Value: %s\n", envVar.getValue());
      }

      for (Port port : modelContainerSpec.getPortsList()) {
        System.out.println("Port");
        System.out.format("Container Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("Deployed Model");
        System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      ExplanationSpec explanationSpec = modelResponse.getExplanationSpec();
      System.out.println("Explanation Spec");

      ExplanationParameters explanationParameters = explanationSpec.getParameters();
      System.out.println("Parameters");

      SampledShapleyAttribution sampledShapleyAttribution =
          explanationParameters.getSampledShapleyAttribution();
      System.out.println("Sampled Shapley Attribution");
      System.out.format("Path Count: %s\n", sampledShapleyAttribution.getPathCount());

      ExplanationMetadata explanationMetadata = explanationSpec.getMetadata();
      System.out.println("Metadata");
      System.out.format("Inputs: %s\n", explanationMetadata.getInputsMap());
      System.out.format("Outputs: %s\n", explanationMetadata.getOutputsMap());
      System.out.format(
          "Feature Attributions Schema Uri: %s\n",
          explanationMetadata.getFeatureAttributionsSchemaUri());

      Status status = trainingPipelineResponse.getError();
      System.out.println("Error");
      System.out.format("Code: %s\n", status.getCode());
      System.out.format("Message: %s\n", status.getMessage());
    }
  }
}
// [END aiplatform_create_training_pipeline_sample]