Skip to content

Commit 683e943

Browse files
wjayeshactions-userschustmi
authored
Add option to add base URL for zenml server (with support for cloud) (#2464)
* add base url as helm input and create env var * add server url as part of zenml info endpoint * Auto-update of Starter template * make org and tenant id available separately * add base url and org id to server config * add base url and org id to server model in base zen store * add base URL to helm template for env vars * make cloud org ID opauqe and part of server metadata * make metadata a Dict type object * if metadata is str, convert to dict * fix linter issue * make empty dict as default metadata value * Add server metadata to analytics --------- Co-authored-by: GitHub Actions <[email protected]> Co-authored-by: Michael Schuster <[email protected]>
1 parent afcaf74 commit 683e943

File tree

8 files changed

+45
-1
lines changed

8 files changed

+45
-1
lines changed

src/zenml/analytics/context.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(self) -> None:
6060
self.client_id: Optional[UUID] = None
6161
self.server_id: Optional[UUID] = None
6262
self.external_server_id: Optional[UUID] = None
63+
self.server_metadata: Optional[Dict[str, str]] = None
6364

6465
self.database_type: Optional["ServerDatabaseType"] = None
6566
self.deployment_type: Optional["ServerDeploymentType"] = None
@@ -118,6 +119,7 @@ def __enter__(self) -> "AnalyticsContext":
118119
self.server_id = store_info.id
119120
self.deployment_type = store_info.deployment_type
120121
self.database_type = store_info.database_type
122+
self.server_metadata = store_info.metadata
121123
except Exception as e:
122124
self.analytics_opt_in = False
123125
logger.debug(f"Analytics initialization failed: {e}")
@@ -272,6 +274,9 @@ def track(
272274
if self.external_server_id:
273275
properties["external_server_id"] = self.external_server_id
274276

277+
if self.server_metadata:
278+
properties.update(self.server_metadata)
279+
275280
for k, v in properties.items():
276281
if isinstance(v, UUID):
277282
properties[k] = str(v)

src/zenml/config/server_config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# permissions and limitations under the License.
1414
"""Functionality to support ZenML GlobalConfiguration."""
1515

16+
import json
1617
import os
1718
from secrets import token_hex
1819
from typing import Any, Dict, List, Optional
@@ -52,6 +53,7 @@ class ServerConfiguration(BaseModel):
5253
5354
Attributes:
5455
deployment_type: The type of ZenML server deployment that is running.
56+
base_url: The base URL of the ZenML server.
5557
root_url_path: The root URL path of the ZenML server.
5658
auth_scheme: The authentication scheme used by the ZenML server.
5759
jwt_token_algorithm: The algorithm used to sign and verify JWT tokens.
@@ -107,17 +109,22 @@ class ServerConfiguration(BaseModel):
107109
external_server_id: The ID of the ZenML server to use with the
108110
`EXTERNAL` authentication scheme. If not specified, the regular
109111
ZenML server ID is used.
112+
metadata: Additional metadata to be associated with the ZenML server.
110113
rbac_implementation_source: Source pointing to a class implementing
111114
the RBAC interface defined by
112115
`zenml.zen_server.rbac_interface.RBACInterface`. If not specified,
113116
RBAC will not be enabled for this server.
117+
workload_manager_implementation_source: Source pointing to a class
118+
implementing the workload management interface.
114119
pipeline_run_auth_window: The default time window in minutes for which
115120
a pipeline run action is allowed to authenticate with the ZenML
116121
server.
117122
"""
118123

119124
deployment_type: ServerDeploymentType = ServerDeploymentType.OTHER
125+
base_url: str = ""
120126
root_url_path: str = ""
127+
metadata: Dict[str, Any] = {}
121128
auth_scheme: AuthScheme = AuthScheme.OAUTH2_PASSWORD_BEARER
122129
jwt_token_algorithm: str = DEFAULT_ZENML_JWT_TOKEN_ALGORITHM
123130
jwt_token_issuer: Optional[str] = None
@@ -191,6 +198,15 @@ def _validate_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
191198
else:
192199
values["cors_allow_origins"] = ["*"]
193200

201+
# if metadata is a string, convert it to a dictionary
202+
if isinstance(values.get("metadata"), str):
203+
try:
204+
values["metadata"] = json.loads(values["metadata"])
205+
except json.JSONDecodeError as e:
206+
raise ValueError(
207+
f"The server metadata is not a valid JSON string: {e}"
208+
)
209+
194210
return values
195211

196212
@property

src/zenml/models/v2/misc/server_models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# permissions and limitations under the License.
1414
"""Model definitions for ZenML servers."""
1515

16+
from typing import Dict
1617
from uuid import UUID, uuid4
1718

1819
from pydantic import BaseModel, Field
@@ -73,6 +74,14 @@ class ServerModel(BaseModel):
7374
auth_scheme: AuthScheme = Field(
7475
title="The authentication scheme that the server is using.",
7576
)
77+
base_url: str = Field(
78+
"",
79+
title="The Base URL of the server.",
80+
)
81+
metadata: Dict[str, str] = Field(
82+
{},
83+
title="The metadata associated with the server.",
84+
)
7685

7786
def is_local(self) -> bool:
7887
"""Return whether the server is running locally.

src/zenml/zen_server/deploy/helm/templates/_environment.tpl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ external_server_id: {{ .ZenML.auth.externalServerID | quote }}
8282
{{- if .ZenML.rootUrlPath }}
8383
root_url_path: {{ .ZenML.rootUrlPath | quote }}
8484
{{- end }}
85+
{{- if .ZenML.baseURL }}
86+
base_url: {{ .ZenML.baseURL | quote }}
87+
{{- end }}
8588
{{- if .ZenML.auth.rbacImplementationSource }}
8689
rbac_implementation_source: {{ .ZenML.auth.rbacImplementationSource | quote }}
8790
{{- end }}

src/zenml/zen_server/deploy/helm/templates/server-deployment.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ spec:
9191
value: {{ $value | quote }}
9292
{{- end }}
9393
{{- end }}
94+
{{- if .Values.zenml.baseURL }}
95+
- name: ZENML_SERVER_BASE_URL
96+
value: {{ .Values.zenml.baseURL | quote }}
97+
{{- end }}
9498
envFrom:
9599
- secretRef:
96100
name: {{ include "zenml.fullname" . }}

src/zenml/zen_server/deploy/helm/values.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ zenml:
1616
# Overrides the image tag whose default is the chart appVersion.
1717
tag:
1818

19+
# The URL of the ZenML server. This is used to direct users to the correct
20+
# address in log messages when running a pipeline and for other similar tasks.
21+
baseURL:
22+
1923
debug: true
2024

2125
# Flag to enable/disable the tracking process of the analytics

src/zenml/zen_stores/base_zen_store.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,8 @@ def get_store_info(self) -> ServerModel:
375375
server_config = ServerConfiguration.get_server_config()
376376
deployment_type = server_config.deployment_type
377377
auth_scheme = server_config.auth_scheme
378+
base_url = server_config.base_url
379+
metadata = server_config.metadata
378380
secrets_store_type = SecretsStoreType.NONE
379381
if isinstance(self, SqlZenStore):
380382
secrets_store_type = self.secrets_store.type
@@ -386,6 +388,8 @@ def get_store_info(self) -> ServerModel:
386388
debug=IS_DEBUG_ENV,
387389
secrets_store_type=secrets_store_type,
388390
auth_scheme=auth_scheme,
391+
base_url=base_url,
392+
metadata=metadata,
389393
)
390394

391395
def is_local_store(self) -> bool:

src/zenml/zen_stores/sql_zen_store.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1409,7 +1409,6 @@ def get_store_info(self) -> ServerModel:
14091409
# Fetch the deployment ID from the database and use it to replace
14101410
# the one fetched from the global configuration
14111411
model.id = self.get_deployment_id()
1412-
14131412
return model
14141413

14151414
def get_deployment_id(self) -> UUID:

0 commit comments

Comments
 (0)