Skip to content
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/zenml/config/server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# permissions and limitations under the License.
"""Functionality to support ZenML GlobalConfiguration."""

import json
import os
from secrets import token_hex
from typing import Any, Dict, List, Optional
Expand Down Expand Up @@ -52,6 +53,7 @@ class ServerConfiguration(BaseModel):

Attributes:
deployment_type: The type of ZenML server deployment that is running.
base_url: The base URL of the ZenML server.
root_url_path: The root URL path of the ZenML server.
auth_scheme: The authentication scheme used by the ZenML server.
jwt_token_algorithm: The algorithm used to sign and verify JWT tokens.
Expand Down Expand Up @@ -107,17 +109,22 @@ class ServerConfiguration(BaseModel):
external_server_id: The ID of the ZenML server to use with the
`EXTERNAL` authentication scheme. If not specified, the regular
ZenML server ID is used.
metadata: Additional metadata to be associated with the ZenML server.
rbac_implementation_source: Source pointing to a class implementing
the RBAC interface defined by
`zenml.zen_server.rbac_interface.RBACInterface`. If not specified,
RBAC will not be enabled for this server.
workload_manager_implementation_source: Source pointing to a class
implementing the workload management interface.
pipeline_run_auth_window: The default time window in minutes for which
a pipeline run action is allowed to authenticate with the ZenML
server.
"""

deployment_type: ServerDeploymentType = ServerDeploymentType.OTHER
base_url: str = ""
root_url_path: str = ""
metadata: Dict[str, Any] = {}
auth_scheme: AuthScheme = AuthScheme.OAUTH2_PASSWORD_BEARER
jwt_token_algorithm: str = DEFAULT_ZENML_JWT_TOKEN_ALGORITHM
jwt_token_issuer: Optional[str] = None
Expand Down Expand Up @@ -191,6 +198,15 @@ def _validate_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
else:
values["cors_allow_origins"] = ["*"]

# if metadata is a string, convert it to a dictionary
if isinstance(values.get("metadata"), str):
try:
values["metadata"] = json.loads(values["metadata"])
except json.JSONDecodeError as e:
raise ValueError(
f"The server metadata is not a valid JSON string: {e}"
)

return values

@property
Expand Down
9 changes: 9 additions & 0 deletions src/zenml/models/v2/misc/server_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# permissions and limitations under the License.
"""Model definitions for ZenML servers."""

from typing import Dict
from uuid import UUID, uuid4

from pydantic import BaseModel, Field
Expand Down Expand Up @@ -73,6 +74,14 @@ class ServerModel(BaseModel):
auth_scheme: AuthScheme = Field(
title="The authentication scheme that the server is using.",
)
base_url: str = Field(
"",
title="The Base URL of the server.",
)
metadata: Dict[str, str] = Field(
{},
title="The metadata associated with the server.",
)

def is_local(self) -> bool:
"""Return whether the server is running locally.
Expand Down
3 changes: 3 additions & 0 deletions src/zenml/zen_server/deploy/helm/templates/_environment.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ external_server_id: {{ .ZenML.auth.externalServerID | quote }}
{{- if .ZenML.rootUrlPath }}
root_url_path: {{ .ZenML.rootUrlPath | quote }}
{{- end }}
{{- if .ZenML.baseURL }}
base_url: {{ .ZenML.baseURL | quote }}
{{- end }}
{{- if .ZenML.auth.rbacImplementationSource }}
rbac_implementation_source: {{ .ZenML.auth.rbacImplementationSource | quote }}
{{- end }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ spec:
value: {{ $value | quote }}
{{- end }}
{{- end }}
{{- if .Values.zenml.baseURL }}
- name: ZENML_SERVER_BASE_URL
value: {{ .Values.zenml.baseURL | quote }}
{{- end }}
envFrom:
- secretRef:
name: {{ include "zenml.fullname" . }}
Expand Down
4 changes: 4 additions & 0 deletions src/zenml/zen_server/deploy/helm/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ zenml:
# Overrides the image tag whose default is the chart appVersion.
tag:

# The URL of the ZenML server. This is used to direct users to the correct
# address in log messages when running a pipeline and for other similar tasks.
baseURL:

debug: true

# Flag to enable/disable the tracking process of the analytics
Expand Down
4 changes: 4 additions & 0 deletions src/zenml/zen_stores/base_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ def get_store_info(self) -> ServerModel:
server_config = ServerConfiguration.get_server_config()
deployment_type = server_config.deployment_type
auth_scheme = server_config.auth_scheme
base_url = server_config.base_url
metadata = server_config.metadata
secrets_store_type = SecretsStoreType.NONE
if isinstance(self, SqlZenStore):
secrets_store_type = self.secrets_store.type
Expand All @@ -386,6 +388,8 @@ def get_store_info(self) -> ServerModel:
debug=IS_DEBUG_ENV,
secrets_store_type=secrets_store_type,
auth_scheme=auth_scheme,
base_url=base_url,
metadata=metadata,
)

def is_local_store(self) -> bool:
Expand Down
1 change: 0 additions & 1 deletion src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,7 +1409,6 @@ def get_store_info(self) -> ServerModel:
# Fetch the deployment ID from the database and use it to replace
# the one fetched from the global configuration
model.id = self.get_deployment_id()

return model

def get_deployment_id(self) -> UUID:
Expand Down