Skip to content

Commit fe91b17

Browse files
committed
add: Support for image pull secrets for Ray Cluster images
1 parent df48547 commit fe91b17

File tree

3 files changed

+23
-0
lines changed

3 files changed

+23
-0
lines changed

src/codeflare_sdk/cluster/cluster.py

+2
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def create_app_wrapper(self):
8585
instance_types = self.config.machine_types
8686
env = self.config.envs
8787
local_interactive = self.config.local_interactive
88+
image_pull_secrets = self.config.image_pull_secrets
8889
return generate_appwrapper(
8990
name=name,
9091
namespace=namespace,
@@ -100,6 +101,7 @@ def create_app_wrapper(self):
100101
instance_types=instance_types,
101102
env=env,
102103
local_interactive=local_interactive,
104+
image_pull_secrets=image_pull_secrets,
103105
)
104106

105107
# creates a new cluster with the provided or default spec

src/codeflare_sdk/cluster/config.py

+1
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,4 @@ class ClusterConfiguration:
4949
envs: dict = field(default_factory=dict)
5050
image: str = "ghcr.io/foundation-model-stack/base:ray2.1.0-py38-gpu-pytorch1.12.0cu116-20221213-193103"
5151
local_interactive: bool = False
52+
image_pull_secrets: list = field(default_factory=list)

src/codeflare_sdk/utils/generate_yaml.py

+20
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,14 @@ def update_image(spec, image):
141141
container["image"] = image
142142

143143

144+
def update_image_pull_secrets(spec, image_pull_secrets):
145+
if image_pull_secrets:
146+
if "imagePullSecrets" not in spec:
147+
spec["imagePullSecrets"] = []
148+
for image_pull_secret in image_pull_secrets:
149+
spec["imagePullSecrets"].append({"name": image_pull_secret})
150+
151+
144152
def update_env(spec, env):
145153
containers = spec.get("containers")
146154
for container in containers:
@@ -178,6 +186,7 @@ def update_nodes(
178186
image,
179187
instascale,
180188
env,
189+
image_pull_secrets,
181190
):
182191
if "generictemplate" in item.keys():
183192
head = item.get("generictemplate").get("spec").get("headGroupSpec")
@@ -193,6 +202,7 @@ def update_nodes(
193202
for comp in [head, worker]:
194203
spec = comp.get("template").get("spec")
195204
update_affinity(spec, appwrapper_name, instascale)
205+
update_image_pull_secrets(spec, image_pull_secrets)
196206
update_image(spec, image)
197207
update_env(spec, env)
198208
if comp == head:
@@ -295,6 +305,7 @@ def generate_appwrapper(
295305
instance_types: list,
296306
env,
297307
local_interactive: bool,
308+
image_pull_secrets: list,
298309
):
299310
user_yaml = read_template(template)
300311
appwrapper_name, cluster_name = gen_names(name)
@@ -318,6 +329,7 @@ def generate_appwrapper(
318329
image,
319330
instascale,
320331
env,
332+
image_pull_secrets,
321333
)
322334
update_dashboard_route(route_item, cluster_name, namespace)
323335
if local_interactive:
@@ -409,6 +421,12 @@ def main(): # pragma: no cover
409421
default=False,
410422
help="Enable local interactive mode",
411423
)
424+
parser.add_argument(
425+
"--image-pull-secrets",
426+
required=False,
427+
default=[],
428+
help="Set image pull secrets for private registries",
429+
)
412430

413431
args = parser.parse_args()
414432
name = args.name
@@ -425,6 +443,7 @@ def main(): # pragma: no cover
425443
namespace = args.namespace
426444
local_interactive = args.local_interactive
427445
env = {}
446+
image_pull_secrets = args.image_pull_secrets
428447

429448
outfile = generate_appwrapper(
430449
name,
@@ -441,6 +460,7 @@ def main(): # pragma: no cover
441460
instance_types,
442461
local_interactive,
443462
env,
463+
image_pull_secrets,
444464
)
445465
return outfile
446466

0 commit comments

Comments
 (0)